[dasher] Rewrite LearnSymbol/AddSymbolToNode: avoid many calls to find_symbol



commit 788f59575607bbae4953f49edf36a81881516d5d
Author: Alan Lawrence <acl33 inf phy cam ac uk>
Date:   Wed Apr 7 23:22:47 2010 +0100

    Rewrite LearnSymbol/AddSymbolToNode: avoid many calls to find_symbol

 .../LanguageModelling/PPMLanguageModel.cpp         |   97 +++++++------------
 .../LanguageModelling/PPMLanguageModel.h           |    5 +-
 2 files changed, 38 insertions(+), 64 deletions(-)
---
diff --git a/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp b/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
index 39b47ec..ad0a739 100644
--- a/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
+++ b/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
@@ -39,12 +39,11 @@ CPPMLanguageModel::CPPMLanguageModel(Dasher::CEventHandler *pEventHandler, CSett
   m_pRootContext->head = m_pRoot;
   m_pRootContext->order = 0;
 
-  // Cache the result of update exclusion - otherwise we have to look up a lot when training, which is slow
-
-  // FIXME - this should be a boolean parameter
-
+  // Cache parameters once per model rather than re-reading them for every learned symbol (per-symbol lookups made training slow).
   bUpdateExclusion = ( GetLongParameter(LP_LM_UPDATE_EXCLUSION) !=0 );
-
+  // Maximum context order kept by LearnSymbol (previously re-read by AddSymbol on every call).
+  m_iMaxOrder = GetLongParameter( LP_LM_MAX_ORDER );
+  // NOTE(review): if these parameters can change at runtime, an existing model keeps the values read here - confirm intended.
 }
 
 /////////////////////////////////////////////////////////////////////
@@ -157,42 +156,6 @@ void CPPMLanguageModel::GetProbs(Context context, std::vector<unsigned int> &pro
   DASHER_ASSERT(iToSpend == 0);
 }
 
-void CPPMLanguageModel::AddSymbol(CPPMLanguageModel::CPPMContext &context, symbol sym)
-        // add symbol to the context
-        // creates new nodes, updates counts
-        // and leaves 'context' at the new context
-{
-  // Ignore attempts to add the root symbol
-
-  if(sym==0)
-    return;
-
-  DASHER_ASSERT(sym >= 0 && sym < GetSize());
-
-  CPPMnode *vineptr, *temp;
-  int updatecnt = 1;
-
-  temp = context.head->vine;
-  context.head = AddSymbolToNode(context.head, sym, &updatecnt);
-  vineptr = context.head;
-  context.order++;
-
-  while(temp != 0) {
-    vineptr->vine = AddSymbolToNode(temp, sym, &updatecnt);
-    vineptr = vineptr->vine;
-    temp = temp->vine;
-  }
-  vineptr->vine = m_pRoot;
-
-  //m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
-  m_iMaxOrder = GetLongParameter( LP_LM_MAX_ORDER );
-
-  while(context.order > m_iMaxOrder) {
-    context.head = context.head->vine;
-    context.order--;
-  }
-}
-
 /////////////////////////////////////////////////////////////////////
 // Update context with symbol 'Symbol'
 
@@ -237,15 +200,28 @@ void CPPMLanguageModel::EnterSymbol(Context c, int Symbol) {
 }
 
 /////////////////////////////////////////////////////////////////////
-
+// Learn 'Symbol' in context 'c': step from the context's head node to its child for Symbol,
+// creating nodes and updating counts via AddSymbolToNode, then leave 'c' at that new context,
+// falling back along vine pointers until its order is at most m_iMaxOrder.
 void CPPMLanguageModel::LearnSymbol(Context c, int Symbol) {
+  // Symbol 0 is the root symbol; attempts to learn it are ignored.
   if(Symbol==0)
     return;
   
 
   DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
   CPPMLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
-  AddSymbol(context, Symbol);
+  // AddSymbolToNode finds or creates the child (and its vine chain) and updates counts.
+  CPPMnode* n = AddSymbolToNode(context.head, Symbol);
+  DASHER_ASSERT ( n == context.head->find_symbol(Symbol));
+  context.head=n;
+  context.order++;
+  // Contexts longer than m_iMaxOrder are not kept: drop to shorter contexts via vines.
+  while(context.order > m_iMaxOrder) {
+    context.head = context.head->vine;
+    context.order--;
+  }
+  
 }
 
 void CPPMLanguageModel::dumpSymbol(symbol sym) {
@@ -460,32 +436,31 @@ void CPPMLanguageModel::CPPMnode::AddChild(CPPMnode *pNewChild, int numSymbols)
   }
 }
 
-CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode, symbol sym, int *update) {
+// Return pNode's child for 'sym', creating it (and, recursively, its vine chain) if absent; updates counts.
+CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode, symbol sym) {
   CPPMnode *pReturn = pNode->find_symbol(sym);
 
   //      std::cout << sym << ",";
 
   if(pReturn != NULL) {
-    //      std::cout << "Using existing node" << std::endl;
-
-    //            if (*update || (LanguageModelParams()->GetValue("LMUpdateExclusion") == 0) ) 
-    if(*update || !bUpdateExclusion) {  // perform update exclusions
-      pReturn->count++;
-      *update = 0;
+    pReturn->count++;
+    if (!bUpdateExclusion) {
+      //no update exclusion: count shorter (vine) contexts too - guaranteed to exist if child does! - but stop at m_pRoot, which is not a context for sym
+      for (CPPMnode *v = pReturn->vine; v != m_pRoot; v=v->vine) {
+        DASHER_ASSERT(v->sym == sym);
+        v->count++;
+      }
     }
-    return pReturn;
+  } else {
+    //symbol does not exist at this level: create it, then link its vine to sym one order shorter (found or created recursively)
+    pReturn = m_NodeAlloc.Alloc(); //count initialized to 1 but no symbol or vine pointer
+    ++NodesAllocated;
+    pReturn->sym = sym;
+    pNode->AddChild(pReturn, GetSize());
+    pReturn->vine = (pNode==m_pRoot) ? m_pRoot : AddSymbolToNode(pNode->vine,sym);
   }
-
-  //       std::cout << "Creating new node" << std::endl;
-
-  pReturn = m_NodeAlloc.Alloc();        // count is initialized to 1
-  pReturn->sym = sym;
-  pNode->AddChild(pReturn,GetSize());
-
-  ++NodesAllocated;
-
+  // New nodes start at count 1; with update exclusion only the longest pre-existing context gets incremented.
   return pReturn;
-
 }
 
 struct BinaryRecord {
diff --git a/Src/DasherCore/LanguageModelling/PPMLanguageModel.h b/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
index 4edac69..b7edce3 100644
--- a/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
+++ b/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
@@ -111,9 +111,8 @@ namespace Dasher {
     int GetIndex(CPPMnode *pAddr, std::map<CPPMnode *, int> *pmapIdx, int *pNextIdx);
     CPPMnode *GetAddress(int iIndex, std::map<int, CPPMnode*> *pMap);
 
-    CPPMnode *AddSymbolToNode(CPPMnode * pNode, symbol sym, int *update);
-
-    virtual void AddSymbol(CPPMContext & context, symbol sym);
+    // Return pNode's child for sym, creating it (and its vine chain) if needed; updates counts, honouring update exclusion.
+    CPPMnode *AddSymbolToNode(CPPMnode * pNode, symbol sym);
     void dumpSymbol(symbol sym);
     void dumpString(char *str, int pos, int len);
     void dumpTrie(CPPMnode * t, int d);



[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]