[dasher] Replace pinyin 'leaves' (CPPMPYnode *pychild[DIVISION]) with map<symbol, count>



commit eac2414377540c871ff395761779c79ed9ba5968
Author: Alan Lawrence <acl33 inf phy cam ac uk>
Date:   Tue Feb 8 11:32:17 2011 +0000

    Replace pinyin 'leaves' (CPPMPYnode *pychild[DIVISION]) with map<symbol,count>
    
    ...as they have no children, are never used as contexts, and their vine
    pointers are never read (they only store pinyin symbol + count). std::map may
    not be the fastest thing around, but it's probably faster than the previous
    fixed-5-way hash, and so much more concise!
    
    Also remove the non-functional bodies of ReadFromFile/WriteToFile; both now just return false (i.e. failure)

 .../LanguageModelling/PPMPYLanguageModel.cpp       |  193 ++------------------
 .../LanguageModelling/PPMPYLanguageModel.h         |   36 +++--
 2 files changed, 43 insertions(+), 186 deletions(-)
---
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
index c3053d0..2f1df3e 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
@@ -359,34 +359,25 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
   while(pTemp != 0) {
     int iTotal = 0;
 
-    CPPMPYnode *pSymbol;
-    for(i=0; i<DIVISION; i++){
-      pSymbol  = pTemp->pychild[i];
-      while(pSymbol) {
-	if(!(exclusions[pSymbol->sym] && doExclusion))
-	  iTotal += pSymbol->count;
-	pSymbol = pSymbol->next;
-      }
+    for (map<symbol, unsigned short int>::iterator it=pTemp->pychild.begin(); it!=pTemp->pychild.end(); it++) {
+      if(!(exclusions[it->first] && doExclusion))
+        iTotal += it->second;
     }
 
     if(iTotal) {
       unsigned int size_of_slice = iToSpend;
       
-      for(i=0; i<DIVISION; i++){
-	pSymbol = pTemp->pychild[i];
-	while(pSymbol) {
-	  if(!(exclusions[pSymbol->sym] && doExclusion)) {
-	    exclusions[pSymbol->sym] = 1;
+      for (map<symbol, unsigned short int>::iterator it = pTemp->pychild.begin(); it!=pTemp->pychild.end(); it++) {
+        if(!(exclusions[it->first] && doExclusion)) {
+          exclusions[it->first] = 1;
 	    
-	    unsigned int p = static_cast < myint > (size_of_slice) * (100 * pSymbol->count - beta) / (100 * iTotal + alpha);
+          unsigned int p = static_cast < myint > (size_of_slice) * (100 * it->second - beta) / (100 * iTotal + alpha);
 	    
-	    probs[pSymbol->sym] += p;
-	    iToSpend -= p;
-	  }
+          probs[it->first] += p;
+          iToSpend -= p;
+        }
 	  //                              Usprintf(debug,TEXT("sym %u counts %d p %u tospend %u \n"),sym,s->count,p,tospend);      
-	  //                              DebugOutput(debug);
-	  pSymbol = pSymbol->next;
-	}
+	  //                              DebugOutput( debug);
       }
 
     }
@@ -517,7 +508,13 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
      std::cout<<" "<<std::endl;
   */
 
-  AddPYSymbolToNode(context.head, pysym);
+  for (CPPMPYnode *pNode = context.head; pNode; pNode=pNode->vine) {
+    if (++pNode->pychild[pysym]>1) {
+      //sym was already present
+      if (bUpdateExclusion) break;
+    }
+  }
+  
   //no context order increase
   //context.order++;
 }
@@ -623,19 +620,6 @@ CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::CPPMPYnode::find_symbol(i
   return 0;
 }
 
-// New find pysymbol function, to find the py symbol in nodes attached to character node
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::CPPMPYnode::find_pysymbol(int pysym) const
-// see if pysymbol is a child of node
-{
-
-  for (CPPMPYnode *found=pychild[ min(DIVISION-1, pysym/UNITPY) ]; found; found=found->next) {
-    if(found->sym == pysym){
-      return found;
-    }
-  }
-  return 0;
-}
-
 CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddSymbolToNode(CPPMPYnode *pNode, int sym) {
   //  std::cout<<"Addnode sym "<<sym<<std::endl;
   CPPMPYnode *pReturn = pNode->find_symbol(sym);
@@ -666,151 +650,12 @@ CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddSymbolToNode(CPPMPYnod
   return pReturn;
 }
 
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddPYSymbolToNode(CPPMPYnode *pNode, int pysym) {
-  CPPMPYnode *pReturn = pNode->find_pysymbol(pysym);
-
-  //      std::cout << sym << ",";
-
-  if(pReturn != NULL) {
-    //      std::cout << "Using existing node" << std::endl;
-    pReturn->count++;
-    if (!bUpdateExclusion) {
-      //Update vine contexts too. Guaranteed to exist if higher-order context does!
-      for (CPPMPYnode *v=pReturn->vine; v; v=v->vine) {
-        DASHER_ASSERT(v->sym==pysym);
-        v->count++;
-      }
-    }
-  } else{
-    //       std::cout << "Creating new node" << std::endl;
-
-    pReturn = m_NodeAlloc.Alloc();        // count is initialized to 1, no symbol or vine ptr
-    ++NodesAllocated;
-    pReturn->sym = pysym;
-    const int childIdx = min(DIVISION-1, pysym/UNITPY);
-    pReturn->next = pNode->pychild[childIdx];
-    pNode->pychild[childIdx] = pReturn;
-    //the py tree attached to root should have vine pointers NULL (not m_pRoot!)
-    pReturn->vine = (pNode == m_pRoot) ? NULL : AddPYSymbolToNode(pNode->vine, pysym);
-  }
-  return pReturn;
-}
-
-struct BinaryRecord {
-  int m_iIndex;
-  int m_iChild;
-  int m_iNext;
-  int m_iVine;
-  unsigned short int m_iCount;
-  short int m_iSymbol;
-};
-
+//Mandarin - PY not enabled for these read-write functions
 bool CPPMPYLanguageModel::WriteToFile(std::string strFilename) {
-
-  std::cout<<"WRITE TO FILE USED?"<<std::endl;
-
-  std::map<CPPMPYnode *, int> mapIdx;
-  int iNextIdx(1); // Index of 0 means NULL;
-
-  std::ofstream oOutputFile(strFilename.c_str());
-
-  RecursiveWrite(m_pRoot, &mapIdx, &iNextIdx, &oOutputFile);
-
-  oOutputFile.close();
-
   return false;
 }
 
 //Mandarin - PY not enabled for these read-write functions
-bool CPPMPYLanguageModel::RecursiveWrite(CPPMPYnode *pNode, std::map<CPPMPYnode *, int> *pmapIdx, int *pNextIdx, std::ofstream *pOutputFile) {
-
-  // Dump node here
-
-  BinaryRecord sBR;
-
-  sBR.m_iIndex = GetIndex(pNode, pmapIdx, pNextIdx); 
-  //Note future changes here:
-  sBR.m_iChild = GetIndex(pNode->child[0], pmapIdx, pNextIdx); 
-  sBR.m_iNext = GetIndex(pNode->next, pmapIdx, pNextIdx); 
-  sBR.m_iVine = GetIndex(pNode->vine, pmapIdx, pNextIdx);
-  sBR.m_iCount = pNode->count;
-  sBR.m_iSymbol = pNode->sym;
-
-  pOutputFile->write(reinterpret_cast<char*>(&sBR), sizeof(BinaryRecord));
-
-  CPPMPYnode *pCurrentChild(pNode->child[0]);
-  
-  while(pCurrentChild != NULL) {
-    RecursiveWrite(pCurrentChild, pmapIdx, pNextIdx, pOutputFile);
-    pCurrentChild = pCurrentChild->next;
-  }
-
-  return true;
-}
-
-int CPPMPYLanguageModel::GetIndex(CPPMPYnode *pAddr, std::map<CPPMPYnode *, int> *pmapIdx, int *pNextIdx) {
-  std::cout<<"GetIndex gets called?"<<std::endl;
-  int iIndex;
-  if(pAddr == NULL)
-    iIndex = 0;
-  else {
-    std::map<CPPMPYnode *, int>::iterator it(pmapIdx->find(pAddr));
-    
-    if(it == pmapIdx->end()) {
-      iIndex = *pNextIdx;
-      pmapIdx->insert(std::pair<CPPMPYnode *, int>(pAddr, iIndex));
-      ++(*pNextIdx);
-    }
-    else {
-      iIndex = it->second;
-    }
-  }
-  return iIndex;
-}
-
-
-//Mandarin - PY not enabled for these read-write functions
 bool CPPMPYLanguageModel::ReadFromFile(std::string strFilename) {
-  
-  std::ifstream oInputFile(strFilename.c_str());
-  std::map<int, CPPMPYnode*> oMap;
-  BinaryRecord sBR;
-  bool bStarted(false);
-
-  while(!oInputFile.eof()) {
-    oInputFile.read(reinterpret_cast<char *>(&sBR), sizeof(BinaryRecord));
-
-    CPPMPYnode *pCurrent(GetAddress(sBR.m_iIndex, &oMap));
-    //Note future changes here:
-    pCurrent->child[0] = GetAddress(sBR.m_iChild, &oMap);
-    pCurrent->next = GetAddress(sBR.m_iNext, &oMap);
-    pCurrent->vine = GetAddress(sBR.m_iVine, &oMap);
-    pCurrent->count = sBR.m_iCount;
-    pCurrent->sym = sBR.m_iSymbol;
-
-    if(!bStarted) {
-      m_pRoot = pCurrent;
-      bStarted = true;
-    }
-  }
-
-  oInputFile.close();
-
   return false;
 }
-
-CPPMPYLanguageModel::CPPMPYnode *CPPMPYLanguageModel::GetAddress(int iIndex, std::map<int, CPPMPYnode*> *pMap) {
-
-  std::cout<<"Get Address gets called?"<<std::endl;
-  std::map<int, CPPMPYnode*>::iterator it(pMap->find(iIndex));
-
-  if(it == pMap->end()) {
-    CPPMPYnode *pNewNode;
-    pNewNode = m_NodeAlloc.Alloc();
-    pMap->insert(std::pair<int, CPPMPYnode*>(iIndex, pNewNode));
-    return pNewNode;
-  }
-  else {
-    return it->second;
-  }
-}
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
index b135894..312034f 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
@@ -39,21 +39,23 @@ namespace Dasher {
   /// @{
 
   ///
-  /// PPM language model (with PinYin)
+  /// PPM language model (with PinYin). Implements a standard PPM model amongst Chinese characters,
+  /// but with each ppm-node additionally storing counts of possible Pinyin symbols which might be
+  /// entered in that context. GetProbs returns probabilities for the next Pinyin symbol, which (NB!)
+  /// is _not_ entered into the context; new method GetPartProbs is used to compute probabilities
+  /// for the next Chinese symbol (which should be entered into context), by filtering to a set.
   ///
-
   class CPPMPYLanguageModel:public CLanguageModel, private NoClones {
   private:
     class CPPMPYnode {
     public:
       CPPMPYnode * find_symbol(int sym)const;
-      CPPMPYnode * find_pysymbol(int pysym)const;
       //Each PPM node store DIVISION number of addresses for children, so that each node branches out DIVISION times (as compared to binary); this is aimed to give better run-time speed
       CPPMPYnode * child[DIVISION];
       CPPMPYnode *next;
       CPPMPYnode *vine;
-      //Similarly (as last comment) for Pin Yin 
-      CPPMPYnode * pychild[DIVISION];
+      /// map from pinyin-symbol to count: the number of times each pinyin symbol has been seen in this context
+      std::map<symbol,unsigned short int> pychild;
       unsigned short int count;
       symbol sym;
       CPPMPYnode(int sym);
@@ -75,6 +77,10 @@ namespace Dasher {
     };
 	  
   public:
+    ///Construct a new PPMPYLanguageModel.
+    /// \param pAlph alphabet containing the actual symbols we want to write (i.e. Chinese)
+    /// \param pPyAlph alphabet of pinyin phonemes; we will predict probabilities for these
+    /// based (only) on the preceding _Chinese_ symbols.
     CPPMPYLanguageModel(Dasher::CEventHandler * pEventHandler, CSettingsStore * pSettingsStore, const CAlphInfo *pAlph, const CAlphInfo *pPyAlph);
 
     virtual ~ CPPMPYLanguageModel();
@@ -83,16 +89,25 @@ namespace Dasher {
     void ReleaseContext(Context context);
     Context CloneContext(Context context);
 
+    ///Advance the context by entering a Chinese symbol
     virtual void EnterSymbol(Context context, int Symbol);
+    ///Train the LM with the specified Chinese symbol in that context (moves context on)
     virtual void LearnSymbol(Context context, int Symbol);
-    //Learns a pinyin symbol in the current context, but does not move the context on.
+    ///Learns a pinyin symbol in the specified context, but does not move the context on.
     virtual void LearnPYSymbol(Context context, int Symbol);
 
+    ///Predicts probabilities for the next Pinyin symbol (blending as per PPM,
+    /// but using the pychild map rather than child CPPMPYnodes).
+    /// \param Probs vector to fill with predictions for pinyin symbols: will be filled
+    ///  with m_pPyAlphabet->GetNumberTextSymbols() numbers plus an initial 0. 
     virtual void GetProbs(Context context, std::vector < unsigned int >&Probs, int norm, int iUniform) const;
     
-    //ACL renamed, just call GetProbs instead:
-    //void GetPYProbs(Context context, std::vector < unsigned int >&Probs, int norm, int iUniform);
-
+    ///Predicts probabilities for the next Chinese symbol, filtered to only include symbols within a specified set.
+    /// Predictions are made as per PPM, but considering only counts for the specified symbols; this means
+    /// the value of LP_LM_ALPHA is relative to the total counts of _those_ Chinese symbols (in the specified
+    /// context), not to the total count of all Chinese symbols in that context.
+    /// \param vChildren vector of (Chinese symbol, probability) pairs; on entry, the first element of each pair
+    /// indicates a possible Chinese symbol; on exit, the second element will have been filled in.
     void GetPartProbs(Context context, std::vector<std::pair<symbol, unsigned int> > &vChildren, int norm, int iUniform);
 
     void dump();
@@ -104,7 +119,6 @@ namespace Dasher {
     CPPMPYnode *GetAddress(int iIndex, std::map<int, CPPMPYnode*> *pMap);
 
     CPPMPYnode *AddSymbolToNode(CPPMPYnode * pNode, int sym);
-    CPPMPYnode *AddPYSymbolToNode(CPPMPYnode * pNode, int pysym);
 
     void dumpSymbol(int sym);
     void dumpString(char *str, int pos, int len);
@@ -146,7 +160,6 @@ namespace Dasher {
     //Added: Mandarin; Setting initial values
     for(int i =0; i <DIVISION; i++){
       child[i] = NULL;
-      pychild[i] = NULL;
     }
   }
 
@@ -161,7 +174,6 @@ namespace Dasher {
     //Added: Mandarin; Setting initial values
     for(int i =0; i <DIVISION; i++){
       child[i] = NULL;
-      pychild[i] = NULL;
     }
 
   }



[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]