[dasher] Rewrite PPMPYLanguageModel Add{,PY}Symbol{,ToNode} & find_{py,}symbol



commit 296ef819b652144e479151a69a2d728e8958ac33
Author: Alan Lawrence <acl33 inf phy cam ac uk>
Date:   Mon Feb 7 23:35:36 2011 +0000

    Rewrite PPMPYLanguageModel Add{,PY}Symbol{,ToNode} & find_{py,}symbol

 .../LanguageModelling/PPMPYLanguageModel.cpp       |  263 +++++---------------
 .../LanguageModelling/PPMPYLanguageModel.h         |   12 +-
 2 files changed, 64 insertions(+), 211 deletions(-)
---
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
index 6115fc9..34a9c52 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
@@ -433,85 +433,6 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
   DASHER_ASSERT(iToSpend == 0);
 }
 
-
-void CPPMPYLanguageModel::AddSymbol(CPPMPYLanguageModel::CPPMPYContext &context, int sym)
-        // add symbol to the context
-        // creates new nodes, updates counts
-        // and leaves 'context' at the new context
-{
-  // Ignore attempts to add the root symbol
-
-  if(sym==0)
-    return;
-
-  DASHER_ASSERT(sym >= 0 && sym < GetSize());
-
-  CPPMPYnode *vineptr, *temp;
-  int updatecnt = 1;
-
-  temp = context.head->vine;
-  context.head = AddSymbolToNode(context.head, sym, &updatecnt);
-  vineptr = context.head;
-  context.order++;
-
-  while(temp != 0) {
-    vineptr->vine = AddSymbolToNode(temp, sym, &updatecnt);
-    vineptr = vineptr->vine;
-    temp = temp->vine;
-  }
-  vineptr->vine = m_pRoot;
-
-  //m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
-  m_iMaxOrder = 2;//GetLongParameter( LP_LM_MAX_ORDER );
-  //std::cout<<"Max Order: "<<m_iMaxOrder<<std::endl;
-  while(context.order > m_iMaxOrder) {
-    context.head = context.head->vine;
-    context.order--;
-  }
-}
-
-void CPPMPYLanguageModel::AddPYSymbol(CPPMPYLanguageModel::CPPMPYContext &context, int pysym)
-        // add symbol to the context
-        // creates new nodes, updates counts
-        // and leaves 'context' at the new context
-{
-  // Ignore attempts to add the root symbol
-
-  if(pysym==0)
-    return;
-
-  DASHER_ASSERT(pysym >= 0 && pysym < m_pPyAlphabet->GetNumberTextSymbols());
-
-  CPPMPYnode *vineptr, *temp, *pytail;
-  int updatecnt = 1;
-
-
-  //update of vine pointers similar to old PPMPYnodes
-  temp = context.head->vine;
-  pytail = AddPYSymbolToNode(context.head, pysym, &updatecnt);
-  vineptr = pytail;
-
-  //no context order increase
-  //context.order++;
-
-  while(temp != 0) {
-    vineptr->vine = AddPYSymbolToNode(temp, pysym, &updatecnt);
-    vineptr = vineptr->vine;
-    temp = temp->vine;
-  }
-
-  //the py tree attached to root should have vine pointers NULL
-  vineptr->vine = NULL;
-
-  //m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
-  /*  m_iMaxOrder = GetLongParameter( LP_LM_MAX_ORDER );
-
-  while(context.order > m_iMaxOrder) {
-    context.head = context.head->vine;
-    context.order--;*/
-}
-
-
 /////////////////////////////////////////////////////////////////////
 // Update context with symbol 'Symbol'
 
@@ -563,18 +484,28 @@ void CPPMPYLanguageModel::LearnSymbol(Context c, int Symbol) {
   if(Symbol==0)
     return;
   
-
   DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
   CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
-  AddSymbol(context, Symbol);
+  CPPMPYnode *n = AddSymbolToNode(context.head, Symbol);
+  DASHER_ASSERT(n == context.head->find_symbol(Symbol));
+  context.head = n;
+  context.order++;
+  //m_iMaxOrder = LanguageModelParams()->GetValue(std::string("LMMaxOrder"));
+  m_iMaxOrder = 2;//GetLongParameter( LP_LM_MAX_ORDER );
+  //std::cout<<"Max Order: "<<m_iMaxOrder<<std::endl;
+  while(context.order > m_iMaxOrder) {
+    context.head = context.head->vine;
+    context.order--;
+  }
 }
 
-void CPPMPYLanguageModel::LearnPYSymbol(Context c, int Symbol) {
-  if(Symbol==0)
+//Do _not_ move on the context...
+void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
+  // Ignore attempts to add the root symbol
+  if(pysym==0)
     return;
-  
 
-  DASHER_ASSERT(Symbol >= 0 && Symbol < m_pPyAlphabet->GetNumberTextSymbols());
+  DASHER_ASSERT(pysym > 0 && pysym <= m_pPyAlphabet->GetNumberTextSymbols());
   CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
  
   //  std::cout<<"py learn context : "<<context.head->symbol<<std::endl;
@@ -586,7 +517,10 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int Symbol) {
     }
      std::cout<<" "<<std::endl;
   */
-     AddPYSymbol(context, Symbol);
+
+  AddPYSymbolToNode(context.head, pysym);
+  //no context order increase
+  //context.order++;
 }
 
 void CPPMPYLanguageModel::dumpSymbol(int sym) {
@@ -681,46 +615,12 @@ CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::CPPMPYnode::find_symbol(i
   //  printf("finding symbol %d at node %d\n",sym,node->id);
 
   //Potentially replace with large scale find algorithm, necessary?
-  CPPMPYnode * found = NULL;
-  bool bFound = 0;
-
-  for (int i=0; i<DIVISION-1; i++){
-    if(sym<(i+1)*UNITALPH){
-      //      std::cout<<"before"<<std::endl;
-      found = child[i];
-      //  std::cout<<"after child: "<<i<<std::endl;
-      //      std::cout<<"i "<<i<<std::endl;
-      //   std::cout<<"sym "<<sym<<std::endl;
-      bFound =1;
-      break;
-    }
-  }
-  
-  if(!bFound){
-    found = child[DIVISION-1];
-
-    //    std::cout<<"in last group "<<std::endl;
-  }
-
-  //  std::cout<<"here?"<<std::endl;
-  while(found) {
-    if(found->symbol == sym){
-      //  std::cout<<"Found!"<<std::endl;
+  for (CPPMPYnode * found = child[ min(DIVISION-1, sym/UNITALPH) ]; found; found=found->next) {
+    if(found->symbol == sym) {
       return found;
     }
-    found = found->next;
-    //  std::cout<<"next successful"<<std::endl;
-    // if((found)){
-
-    //   std::cout<<"next "<<found->symbol<<std::endl;
-    //  if(found->next)
-    //	std::cout<<"next not empty"<<std::endl;
-    // }
-    // else
-    //std::cout<<"found is NULL"<<std::endl;
   }
 
-  //  std::cout<<"end"<<std::endl;
   return 0;
 }
 
@@ -729,30 +629,15 @@ CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::CPPMPYnode::find_pysymbol
 // see if pysymbol is a child of node
 {
 
-  CPPMPYnode * found = NULL;
-  bool bFound = 0;
-
-  for (int i=0; i<DIVISION-1; i++){
-    if(pysym<(i+1)*UNITPY){
-      found = pychild[i];
-      bFound = 1;
-      break;
-    }
-  }
-  
-  if(!bFound)
-    found = pychild[DIVISION-1];
-
-  while(found) {
+  for (CPPMPYnode *found=pychild[ min(DIVISION-1, pysym/UNITPY) ]; found; found=found->next) {
     if(found->symbol == pysym){
       return found;
     }
-    found = found->next;
   }
   return 0;
 }
 
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddSymbolToNode(CPPMPYnode *pNode, int sym, int *update) {
+CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddSymbolToNode(CPPMPYnode *pNode, int sym) {
   //  std::cout<<"Addnode sym "<<sym<<std::endl;
   CPPMPYnode *pReturn = pNode->find_symbol(sym);
 
@@ -760,86 +645,56 @@ CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddSymbolToNode(CPPMPYnod
 
   if(pReturn != NULL) {
     //      std::cout << "Using existing node" << std::endl;
-
-    //            if (*update || (LanguageModelParams()->GetValue("LMUpdateExclusion") == 0) ) 
-    if(*update || !bUpdateExclusion) {  // perform update exclusions
-      pReturn->count++;
-      *update = 0;
-    }
-    return pReturn;
-  }
-
-  //       std::cout << "Creating new node" << std::endl;
-
-  pReturn = m_NodeAlloc.Alloc();        // count is initialized to 1
-  pReturn->symbol = sym;
-  //  pReturn->child.resize(DIVISION, NULL);
-  //  pReturn->pychild.resize(DIVISION, NULL);
-
-  bool bFound =0;
-  for (int i=0; i<DIVISION-1; i++){
-    if(sym<(i+1)*UNITALPH){
-      pReturn->next = pNode->child[i];
-      pNode->child[i]=pReturn;
-      bFound =1;
-      break;
+    pReturn->count++;
+    if (!bUpdateExclusion) {
+      //update vine contexts too. Must exist if higher-order context does!
+      for (CPPMPYnode *v = pReturn->vine; v; v=v->vine) {
+        DASHER_ASSERT(v==m_pRoot || v->symbol == sym);
+        v->count++;
+      }
     }
+  } else {
+    //symbol does not exist at this level
+    pReturn = m_NodeAlloc.Alloc();        // count is initialized to 1 but no symbol or vine ptr
+    ++NodesAllocated;
+    pReturn->symbol = sym;
+    const int childIdx( min(DIVISION-1, sym/UNITALPH) );
+    pReturn->next = pNode->child[childIdx];
+    pNode->child[childIdx] = pReturn;
+    pReturn->vine = (pNode == m_pRoot) ? m_pRoot : AddSymbolToNode(pNode->vine, sym);
   }
 
-  if(!bFound){
-    pReturn->next = pNode->child[DIVISION-1];
-    pNode->child[DIVISION-1]=pReturn;  
-  }
-
-  ++NodesAllocated;
-
   return pReturn;
-
 }
 
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddPYSymbolToNode(CPPMPYnode *pNode, int pysym, int *update) {
+CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddPYSymbolToNode(CPPMPYnode *pNode, int pysym) {
   CPPMPYnode *pReturn = pNode->find_pysymbol(pysym);
 
   //      std::cout << sym << ",";
 
   if(pReturn != NULL) {
     //      std::cout << "Using existing node" << std::endl;
-
-    //            if (*update || (LanguageModelParams()->GetValue("LMUpdateExclusion") == 0) ) 
-    if(*update || !bUpdateExclusion) {  // perform update exclusions
-      pReturn->count++;
-      *update = 0;
-    }
-    return pReturn;
-  }
-
-  //       std::cout << "Creating new node" << std::endl;
-
-  pReturn = m_NodeAlloc.Alloc();        // count is initialized to 1
-  pReturn->symbol = pysym;
-  //  pReturn->child.resize(DIVISION, NULL);
-  //  pReturn->pychild.resize(DIVISION, NULL);
-
-  bool bFound =0;
-  
-  for (int i=0; i<DIVISION-1; i++){
-    if(pysym<(i+1)*UNITPY){
-      pReturn->next = pNode->pychild[i];
-      pNode->pychild[i]=pReturn;
-      bFound = 1;
-      break;
+    pReturn->count++;
+    if (!bUpdateExclusion) {
+      //Update vine contexts too. Guaranteed to exist if higher-order context does!
+      for (CPPMPYnode *v=pReturn->vine; v; v=v->vine) {
+        DASHER_ASSERT(v->symbol==pysym);
+        v->count++;
+      }
     }
+  } else{
+    //       std::cout << "Creating new node" << std::endl;
+
+    pReturn = m_NodeAlloc.Alloc();        // count is initialized to 1, no symbol or vine ptr
+    ++NodesAllocated;
+    pReturn->symbol = pysym;
+    const int childIdx = min(DIVISION-1, pysym/UNITPY);
+    pReturn->next = pNode->pychild[childIdx];
+    pNode->pychild[childIdx] = pReturn;
+    //the py tree attached to root should have vine pointers NULL (not m_pRoot!)
+    pReturn->vine = (pNode == m_pRoot) ? NULL : AddPYSymbolToNode(pNode->vine, pysym);
   }
-
-  if(!bFound){
-    pReturn->next = pNode->pychild[DIVISION-1];
-    pNode->pychild[DIVISION-1]=pReturn;  
-  }
-
-  ++NodesAllocated;
-
   return pReturn;
-
 }
 
 struct BinaryRecord {
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
index 6d71436..b4e4def 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
@@ -28,8 +28,8 @@
 #define PYALPHSIZE 1300
 //Implement a multi-branch tree, instead of a binary tree to gain speed: a trade-off between speed and memory; the choice of branch is implied by ranking of the symbol being searched/added 
 #define DIVISION 5
-#define UNITALPH ALPHSIZE/DIVISION
-#define UNITPY PYALPHSIZE/DIVISION
+#define UNITALPH (ALPHSIZE/DIVISION)
+#define UNITPY (PYALPHSIZE/DIVISION)
 
 
 namespace Dasher {
@@ -85,6 +85,7 @@ namespace Dasher {
 
     virtual void EnterSymbol(Context context, int Symbol);
     virtual void LearnSymbol(Context context, int Symbol);
+    //Learns a pinyin symbol in the current context, but does not move the context on.
     virtual void LearnPYSymbol(Context context, int Symbol);
 
     virtual void GetProbs(Context context, std::vector < unsigned int >&Probs, int norm, int iUniform) const;
@@ -102,11 +103,8 @@ namespace Dasher {
     int GetIndex(CPPMPYnode *pAddr, std::map<CPPMPYnode *, int> *pmapIdx, int *pNextIdx);
     CPPMPYnode *GetAddress(int iIndex, std::map<int, CPPMPYnode*> *pMap);
 
-    CPPMPYnode *AddSymbolToNode(CPPMPYnode * pNode, int sym, int *update);
-    CPPMPYnode *AddPYSymbolToNode(CPPMPYnode * pNode, int pysym, int *update);
-
-    virtual void AddSymbol(CPPMPYContext & context, int sym);
-    void AddPYSymbol(CPPMPYContext & context, int pysym);
+    CPPMPYnode *AddSymbolToNode(CPPMPYnode * pNode, int sym);
+    CPPMPYnode *AddPYSymbolToNode(CPPMPYnode * pNode, int pysym);
 
     void dumpSymbol(int sym);
     void dumpString(char *str, int pos, int len);



[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]