[dasher] Extract CAbstractPPM superclass from PPM Lang Model, make PPMPY a subclass of it



commit 9fa9a842293091ce6682ecf6ca75c9e3a81f48c7
Author: Alan Lawrence <acl33 inf phy cam ac uk>
Date:   Tue Feb 8 14:42:29 2011 +0000

    Extract CAbstractPPM superclass from PPM Lang Model, make PPMPY a subclass of it
    
    i.e. a PPMPY node is now just a PPMnode plus a map<py sym, count>.
    => PPMPY uses the same optimized hashmap for (Chinese) children, find_symbol, etc.

 .../LanguageModelling/PPMLanguageModel.cpp         |   61 +++---
 .../LanguageModelling/PPMLanguageModel.h           |   98 +++++---
 .../LanguageModelling/PPMPYLanguageModel.cpp       |  244 ++------------------
 .../LanguageModelling/PPMPYLanguageModel.h         |  162 ++-----------
 4 files changed, 131 insertions(+), 434 deletions(-)
---
diff --git a/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp b/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
index 3f35e42..a01c29b 100644
--- a/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
+++ b/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
@@ -30,25 +30,15 @@ static char THIS_FILE[] = __FILE__;
 
 /////////////////////////////////////////////////////////////////////
 
-CPPMLanguageModel::CPPMLanguageModel(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CAlphInfo *pAlph)
-:CLanguageModel(pEventHandler, pSettingsStore, pAlph), m_iMaxOrder(4), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024) {
-  m_pRoot = m_NodeAlloc.Alloc();
-  m_pRoot->sym = -1;
-
+CAbstractPPM::CAbstractPPM(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CAlphInfo *pAlph, CPPMnode *pRoot, int iMaxOrder)
+: CLanguageModel(pEventHandler, pSettingsStore, pAlph), m_pRoot(pRoot), m_iMaxOrder(iMaxOrder), bUpdateExclusion( GetLongParameter(LP_LM_UPDATE_EXCLUSION)!=0 ), m_ContextAlloc(1024) {
   m_pRootContext = m_ContextAlloc.Alloc();
   m_pRootContext->head = m_pRoot;
   m_pRootContext->order = 0;
-
-  // Cache parameters that don't make sense to adjust during the life of a language model...
-  bUpdateExclusion = ( GetLongParameter(LP_LM_UPDATE_EXCLUSION) !=0 );
-  
-  m_iMaxOrder = GetLongParameter( LP_LM_MAX_ORDER );
-  
 }
 
-/////////////////////////////////////////////////////////////////////
-
-CPPMLanguageModel::~CPPMLanguageModel() {
+bool CAbstractPPM::isValidContext(const Context context) const {
+  return m_setContexts.count((const CPPMContext *)context) > 0;
 }
 
 /////////////////////////////////////////////////////////////////////
@@ -57,7 +47,7 @@ CPPMLanguageModel::~CPPMLanguageModel() {
 void CPPMLanguageModel::GetProbs(Context context, std::vector<unsigned int> &probs, int norm, int iUniform) const {
   const CPPMContext *ppmcontext = (const CPPMContext *)(context);
 
-  DASHER_ASSERT(m_setContexts.count(ppmcontext) > 0);
+  DASHER_ASSERT(isValidContext(context));
 
   int iNumSymbols = GetSize();
   
@@ -155,13 +145,13 @@ void CPPMLanguageModel::GetProbs(Context context, std::vector<unsigned int> &pro
 /////////////////////////////////////////////////////////////////////
 // Update context with symbol 'Symbol'
 
-void CPPMLanguageModel::EnterSymbol(Context c, int Symbol) {
+void CAbstractPPM::EnterSymbol(Context c, int Symbol) {
   if(Symbol==0)
     return;
 
   DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
 
-  CPPMLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
+  CPPMContext & context = *(CPPMContext *) (c);
 
   while(context.head) {
 
@@ -196,14 +186,14 @@ void CPPMLanguageModel::EnterSymbol(Context c, int Symbol) {
 // add symbol to the context
 // creates new nodes, updates counts
 // and leaves 'context' at the new context
-void CPPMLanguageModel::LearnSymbol(Context c, int Symbol) {
+void CAbstractPPM::LearnSymbol(Context c, int Symbol) {
   
   if(Symbol==0)
     return;
   
 
   DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
-  CPPMLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
+  CPPMContext & context = *(CPPMContext *) (c);
   
   CPPMnode* n = AddSymbolToNode(context.head, Symbol);
   DASHER_ASSERT ( n == context.head->find_symbol(Symbol));
@@ -217,14 +207,14 @@ void CPPMLanguageModel::LearnSymbol(Context c, int Symbol) {
   
 }
 
-void CPPMLanguageModel::dumpSymbol(symbol sym) {
+void CAbstractPPM::dumpSymbol(symbol sym) {
   if((sym <= 32) || (sym >= 127))
     printf("<%d>", sym);
   else
     printf("%c", sym);
 }
 
-void CPPMLanguageModel::dumpString(char *str, int pos, int len)
+void CAbstractPPM::dumpString(char *str, int pos, int len)
         // Dump the string STR starting at position POS
 {
   char cc;
@@ -238,7 +228,7 @@ void CPPMLanguageModel::dumpString(char *str, int pos, int len)
   }
 }
 
-void CPPMLanguageModel::dumpTrie(CPPMLanguageModel::CPPMnode *t, int d)
+void CAbstractPPM::dumpTrie(CAbstractPPM::CPPMnode *t, int d)
         // diagnostic display of the PPM trie from node t and deeper
 {
 //TODO
@@ -274,7 +264,7 @@ void CPPMLanguageModel::dumpTrie(CPPMLanguageModel::CPPMnode *t, int d)
 */
 }
 
-void CPPMLanguageModel::dump()
+void CAbstractPPM::dump()
         // diagnostic display of the whole PPM trie
 {
 // TODO:
@@ -299,7 +289,7 @@ void CPPMLanguageModel::dump()
 */
 }
 
-bool CPPMLanguageModel::eq(CPPMLanguageModel *other) {
+bool CAbstractPPM::eq(CAbstractPPM *other) {
   std::map<CPPMnode *,CPPMnode *> equivs;
   if (!m_pRoot->eq(other->m_pRoot,equivs)) return false;
   //have first & second being equivalent, for all entries in map, except vine ptrs not checked.
@@ -320,7 +310,7 @@ bool CPPMLanguageModel::eq(CPPMLanguageModel *other) {
 /// PPMnode definitions 
 ////////////////////////////////////////////////////////////////////////
 
-bool CPPMLanguageModel::CPPMnode::eq(CPPMLanguageModel::CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs) {
+bool CAbstractPPM::CPPMnode::eq(CAbstractPPM::CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs) {
   if (sym != other->sym)
     return false;
   if (count != other->count)
@@ -340,7 +330,7 @@ bool CPPMLanguageModel::CPPMnode::eq(CPPMLanguageModel::CPPMnode *other, std::ma
 
 #define MAX_RUN 4
 
-CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::find_symbol(symbol sym) const
+CAbstractPPM::CPPMnode * CAbstractPPM::CPPMnode::find_symbol(symbol sym) const
 // see if symbol is a child of node
 {
   if (m_iNumChildSlots < 0) //negative to mean "full alphabet", use direct indexing
@@ -367,7 +357,7 @@ CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::find_symbol(symbol sy
   return 0;
 }
 
-void CPPMLanguageModel::CPPMnode::AddChild(CPPMnode *pNewChild, int numSymbols) {
+void CAbstractPPM::CPPMnode::AddChild(CPPMnode *pNewChild, int numSymbols) {
   if (m_iNumChildSlots < 0) {
     m_ppChildren[pNewChild->sym] = pNewChild;
   }
@@ -429,7 +419,7 @@ void CPPMLanguageModel::CPPMnode::AddChild(CPPMnode *pNewChild, int numSymbols)
   }
 }
 
-CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode, symbol sym) {
+CAbstractPPM::CPPMnode * CAbstractPPM::AddSymbolToNode(CPPMnode *pNode, symbol sym) {
 
   CPPMnode *pReturn = pNode->find_symbol(sym);
 
@@ -446,9 +436,7 @@ CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode
     }
   } else {
     //symbol does not exist at this level
-    pReturn = m_NodeAlloc.Alloc(); //count initialized to 1 but no symbol or vine pointer
-    ++NodesAllocated;
-    pReturn->sym = sym;
+    pReturn = makeNode(sym); //count initialized to 1 but no vine pointer
     pNode->AddChild(pReturn, GetSize());
     pReturn->vine = (pNode==m_pRoot) ? m_pRoot : AddSymbolToNode(pNode->vine,sym);
   }
@@ -456,6 +444,17 @@ CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode
   return pReturn;
 }
 
+CPPMLanguageModel::CPPMLanguageModel(CEventHandler *pEvt, CSettingsStore *sets, const CAlphInfo *pAlph)
+: CAbstractPPM(pEvt, sets, pAlph, new CPPMnode(-1), sets->GetLongParameter(LP_LM_MAX_ORDER)), NodesAllocated(0), m_NodeAlloc(8192) {
+}
+
+CAbstractPPM::CPPMnode *CPPMLanguageModel::makeNode(int sym) {
+  CPPMnode *res = m_NodeAlloc.Alloc();
+  res->sym = sym;
+  ++NodesAllocated;
+  return res;
+}
+
 struct BinaryRecord {
   int m_iIndex;
   int m_iChild;
diff --git a/Src/DasherCore/LanguageModelling/PPMLanguageModel.h b/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
index 39bee14..37b034a 100644
--- a/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
+++ b/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
@@ -25,10 +25,16 @@ namespace Dasher {
   /// \ingroup LM
   /// @{
   ///
-  /// PPM language model
+  /// Common superclass for both the PPM and PPMPY language models. Implements the PPM tree,
+  /// including fast hashing of child nodes by symbol number, plus entering and learning
+  /// symbols in a context (i.e. navigating and updating the tree), with update exclusion
+  /// controlled by LP_LM_UPDATE_EXCLUSION.
   ///
-  class CPPMLanguageModel:public CLanguageModel, private NoClones {
-  private:
+  /// Subclasses must implement CLanguageModel::GetProbs and a makeNode() method (perhaps
+  /// using a pooled allocator).
+  ///
+  class CAbstractPPM :public CLanguageModel, private NoClones {
+  protected:
     class ChildIterator;
     class CPPMnode {
     private:
@@ -53,8 +59,8 @@ namespace Dasher {
       symbol sym;
       CPPMnode(symbol sym);
       CPPMnode();
-      ~CPPMnode();
-      bool eq(CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs);
+      virtual ~CPPMnode();
+      virtual bool eq(CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs);
 	  };
     class ChildIterator {
     private:
@@ -89,10 +95,26 @@ namespace Dasher {
       CPPMnode *head;
       int order;
     };
+    
+    ///Makes a new node, of whatever kind (subclass of CPPMnode, perhaps with extra info)
+    /// is required by the subclass, for the specified symbol. (Initial count will be 1.)
+    virtual CPPMnode *makeNode(int sym)=0;
+    CAbstractPPM(Dasher::CEventHandler * pEventHandler, CSettingsStore * pSettingsStore, const CAlphInfo *pAlph, CPPMnode *pRoot, int iMaxOrder);
+    
+    void dumpSymbol(symbol sym);
+    void dumpString(char *str, int pos, int len);
+    void dumpTrie(CPPMnode * t, int d);
+    
+    CPPMContext *m_pRootContext;
+    CPPMnode *m_pRoot;
+    
+    /// Cache parameters that don't make sense to adjust during the life of a language model...
+    const int m_iMaxOrder; 
+    const bool bUpdateExclusion;
+    
   public:
-    CPPMLanguageModel(Dasher::CEventHandler * pEventHandler, CSettingsStore * pSettingsStore, const CAlphInfo *pAlph);
-    bool eq(CPPMLanguageModel *other);
-    virtual ~ CPPMLanguageModel();
+    virtual bool eq(CAbstractPPM *other);
+    virtual ~CAbstractPPM() {};
 
     Context CreateEmptyContext();
     void ReleaseContext(Context context);
@@ -101,71 +123,73 @@ namespace Dasher {
     virtual void EnterSymbol(Context context, int Symbol);
     virtual void LearnSymbol(Context context, int Symbol);
 
-    virtual void GetProbs(Context context, std::vector < unsigned int >&Probs, int norm, int iUniform) const;
-
     void dump();
+    bool isValidContext(const Context c) const ;
+  private:
+    CPPMnode *AddSymbolToNode(CPPMnode * pNode, symbol sym);
 
+    CPooledAlloc < CPPMContext > m_ContextAlloc;
+    
+    std::set<const CPPMContext *> m_setContexts;
+  };
+
+  ///"Standard" PPM language model: GetProbs uses counts in PPM child nodes,
+  /// universal alpha+beta values read from LP_LM_ALPHA and LP_LM_BETA,
+  /// max order from LP_LM_MAX_ORDER.
+  class CPPMLanguageModel : public CAbstractPPM {
+  public:
+    CPPMLanguageModel(CEventHandler *pEventHandler, CSettingsStore *pSets, const CAlphInfo *pAlph);
+    virtual void GetProbs(Context context, std::vector < unsigned int >&Probs, int norm, int iUniform) const;
+  protected:
+    /// Makes a standard CPPMnode, but using a pooled allocator (m_NodeAlloc) - faster!
+    virtual CPPMnode *makeNode(int sym);
+    
     virtual bool WriteToFile(std::string strFilename);
     virtual bool ReadFromFile(std::string strFilename);
+  private:
+    int NodesAllocated;
+
     bool RecursiveWrite(CPPMnode *pNode, CPPMnode *pNextSibling, std::map<CPPMnode *, int> *pmapIdx, int *pNextIdx, std::ofstream *pOutputFile);
     int GetIndex(CPPMnode *pAddr, std::map<CPPMnode *, int> *pmapIdx, int *pNextIdx);
     CPPMnode *GetAddress(int iIndex, std::map<int, CPPMnode*> *pMap);
 
-    CPPMnode *AddSymbolToNode(CPPMnode * pNode, symbol sym);
-    
-    void dumpSymbol(symbol sym);
-    void dumpString(char *str, int pos, int len);
-    void dumpTrie(CPPMnode * t, int d);
-
-    CPPMContext *m_pRootContext;
-    CPPMnode *m_pRoot;
-
-    int m_iMaxOrder;
-
-    int NodesAllocated;
-
-    bool bUpdateExclusion;
-
     mutable CSimplePooledAlloc < CPPMnode > m_NodeAlloc;
-    CPooledAlloc < CPPMContext > m_ContextAlloc;
-
-    std::set<const CPPMContext *> m_setContexts;
   };
 
   /// @}
-  inline CPPMLanguageModel::ChildIterator CPPMLanguageModel::CPPMnode::children() const {
+  inline CAbstractPPM::ChildIterator CPPMLanguageModel::CPPMnode::children() const {
     //if m_iNumChildSlots = 0 / 1, m_ppChildren is direct pointer, else ptr to array (of pointers)
     CPPMnode *const *ppChild = (m_iNumChildSlots == 0 || m_iNumChildSlots == 1) ? &m_pChild : m_ppChildren;
     return ChildIterator(ppChild + abs(m_iNumChildSlots), ppChild - 1);
   }
   
-  inline const CPPMLanguageModel::ChildIterator CPPMLanguageModel::CPPMnode::end() const {
+  inline const CAbstractPPM::ChildIterator CPPMLanguageModel::CPPMnode::end() const {
     //if m_iNumChildSlots = 0 / 1, m_ppChildren is direct pointer, else ptr to array (of pointers)
     CPPMnode *const *ppChild = (m_iNumChildSlots == 0 || m_iNumChildSlots == 1) ? &m_pChild : m_ppChildren;
     return ChildIterator(ppChild, ppChild - 1);
   }
 
-  inline Dasher::CPPMLanguageModel::CPPMnode::CPPMnode(symbol _sym): sym(_sym) {
+  inline Dasher::CAbstractPPM::CPPMnode::CPPMnode(symbol _sym): sym(_sym) {
     vine = 0;
     m_iNumChildSlots = 0;
     m_ppChildren = NULL;
     count = 1;
   }
 
-  inline CPPMLanguageModel::CPPMnode::CPPMnode() {
+  inline CAbstractPPM::CPPMnode::CPPMnode() {
     vine = 0;
     m_iNumChildSlots = 0;
     m_ppChildren = NULL;
     count = 1;
   }
   
-  inline CPPMLanguageModel::CPPMnode::~CPPMnode() {
+  inline CAbstractPPM::CPPMnode::~CPPMnode() {
     //single child = is direct pointer to node, not array...
     if (m_iNumChildSlots != 1)
-      delete m_ppChildren;
+      delete[] m_ppChildren;
   }
 
-  inline CLanguageModel::Context CPPMLanguageModel::CreateEmptyContext() {
+  inline CLanguageModel::Context CAbstractPPM::CreateEmptyContext() {
     CPPMContext *pCont = m_ContextAlloc.Alloc();
     *pCont = *m_pRootContext;
 
@@ -174,7 +198,7 @@ namespace Dasher {
     return (Context) pCont;
   }
 
-  inline CLanguageModel::Context CPPMLanguageModel::CloneContext(Context Copy) {
+  inline CLanguageModel::Context CAbstractPPM::CloneContext(Context Copy) {
     CPPMContext *pCont = m_ContextAlloc.Alloc();
     CPPMContext *pCopy = (CPPMContext *) Copy;
     *pCont = *pCopy;
@@ -184,7 +208,7 @@ namespace Dasher {
     return (Context) pCont;
   }
 
-  inline void CPPMLanguageModel::ReleaseContext(Context release) {
+  inline void CAbstractPPM::ReleaseContext(Context release) {
 
     m_setContexts.erase(m_setContexts.find((CPPMContext *) release));
 
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
index 23247d7..cf12a66 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
@@ -34,34 +34,9 @@ static char THIS_FILE[] = __FILE__;
 /////////////////////////////////////////////////////////////////////
 
 CPPMPYLanguageModel::CPPMPYLanguageModel(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CAlphInfo *pAlph, const CAlphInfo *pPyAlphabet)
-  :CLanguageModel(pEventHandler, pSettingsStore, pAlph), m_iMaxOrder(2), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024), m_pPyAlphabet(pPyAlphabet){
-  m_pRoot = m_NodeAlloc.Alloc();
-  m_pRoot->sym = -1;
-  //  m_pRoot->child.resize(DIVISION, NULL);
-  //  m_pRoot->pychild.resize(DIVISION, NULL);
-
-
-
-  m_pRootContext = m_ContextAlloc.Alloc();
-  m_pRootContext->head = m_pRoot;
-  m_pRootContext->order = 0;
+  :CAbstractPPM(pEventHandler, pSettingsStore, pAlph, new CPPMPYnode(-1), 2), NodesAllocated(0), m_NodeAlloc(8192), m_pPyAlphabet(pPyAlphabet){
 
   m_iAlphSize = GetSize();
-  //  std::cout<<"Alphaunit: "<<UNITALPH<<std::endl;
-  
-  // Cache the result of update exclusion - otherwise we have to look up a lot when training, which is slow
-
-  // FIXME - this should be a boolean parameter
-
-  bUpdateExclusion = ( GetLongParameter(LP_LM_UPDATE_EXCLUSION) !=0 );
-    
-  m_iMaxOrder = 2;//GetLongParameter( LP_LM_MAX_ORDER );
-  //std::cout<<"Max Order: "<<m_iMaxOrder<<std::endl;
-}
-
-/////////////////////////////////////////////////////////////////////
-
-CPPMPYLanguageModel::~CPPMPYLanguageModel() {
 }
 
 /////////////////////////////////////////////////////////////////////
@@ -200,7 +175,7 @@ void CPPMPYLanguageModel::GetPartProbs(Context context, std::vector<pair<symbol,
   //  std::cout<<"Norms is "<<norm<<std::endl;
   //  std::cout<<"iUniform is "<<iUniform<<std::endl;
 
-  const CPPMPYContext *ppmcontext = (const CPPMPYContext *)(context);
+  const CPPMContext *ppmcontext = (const CPPMContext *)(context);
 
   //  DASHER_ASSERT(m_setContexts.count(ppmcontext) > 0);
 
@@ -235,10 +210,10 @@ void CPPMPYLanguageModel::GetPartProbs(Context context, std::vector<pair<symbol,
   int *vCounts=new int[vChildren.size()]; //num occurrences of symbol at same index in vChildren
 
   //new code
-  for (CPPMPYnode *pTemp = ppmcontext->head; pTemp; pTemp=pTemp->vine) {
+  for (CPPMnode *pTemp = ppmcontext->head; pTemp; pTemp=pTemp->vine) {
     int iTotal=0, i=0;
     for (std::vector<pair<symbol, unsigned int> >::const_iterator it = vChildren.begin(); it!=vChildren.end(); it++,i++) {
-      if (CPPMPYnode *pFound = pTemp->find_symbol(it->first)) {
+      if (CPPMnode *pFound = pTemp->find_symbol(it->first)) {
         iTotal += vCounts[i] = pFound->count; //double assignment
       } else
         vCounts[i] = 0;
@@ -310,8 +285,7 @@ void CPPMPYLanguageModel::GetPartProbs(Context context, std::vector<pair<symbol,
 // by an explicit cast to PPMPYLanguageModel whenever MandarinDasher was activated. Renaming
 // to GetProbs causes the normal (virtual) call to come straight here without any special-casing...
 void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &probs, int norm, int iUniform) const {
-  const CPPMPYContext *ppmcontext = (const CPPMPYContext *)(context);
-
+  const CPPMContext *ppmcontext = (const CPPMContext *)(context);
 
   //  std::cout<<"PPMCONTEXT symbol: "<<ppmcontext->head->symbol<<std::endl;
   /*
@@ -354,10 +328,11 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
   int alpha = GetLongParameter( LP_LM_ALPHA );
   int beta = GetLongParameter( LP_LM_BETA );
 
-  for (CPPMPYnode *pTemp = ppmcontext->head; pTemp; pTemp = pTemp->vine) {
+  for (CPPMnode *pTemp = ppmcontext->head; pTemp; pTemp = pTemp->vine) {
     int iTotal = 0;
+    const map<symbol, unsigned short int> &pychild( static_cast<CPPMPYnode *>(pTemp)->pychild);
 
-    for (map<symbol, unsigned short int>::iterator it=pTemp->pychild.begin(); it!=pTemp->pychild.end(); it++) {
+    for (map<symbol, unsigned short int>::const_iterator it=pychild.begin(); it!=pychild.end(); it++) {
       if(!(exclusions[it->first] && doExclusion))
         iTotal += it->second;
     }
@@ -365,7 +340,7 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
     if(iTotal) {
       unsigned int size_of_slice = iToSpend;
       
-      for (map<symbol, unsigned short int>::iterator it = pTemp->pychild.begin(); it!=pTemp->pychild.end(); it++) {
+      for (map<symbol, unsigned short int>::const_iterator it = pychild.begin(); it!=pychild.end(); it++) {
         if(!(exclusions[it->first] && doExclusion)) {
           exclusions[it->first] = 1;
 	    
@@ -421,65 +396,6 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
 }
 
 /////////////////////////////////////////////////////////////////////
-// Update context with symbol 'Symbol'
-
-void CPPMPYLanguageModel::EnterSymbol(Context c, int Symbol) {
-  if(Symbol<0)
-    return;
-
-  DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
-
-  CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
-
-  while(context.head) {
-
-    //std::cout<<"Max Order: "<<m_iMaxOrder<<std::endl;
-    if(context.order < m_iMaxOrder) {   // Only try to extend the context if it's not going to make it too long
-      if (CPPMPYnode *find = context.head->find_symbol(Symbol)) {
-	//	std::cout<<"FOund PPM Node for update!"<<std::endl;
-        context.order++;
-        context.head = find;
-        //      Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
-        //      DebugOutput(debug);
-
-        //      std::cout << context.order << std::endl;
-        return;
-      }
-    }
-
-    // If we can't extend the current context, follow vine pointer to shorten it and try again
-
-    context.order--;
-    context.head = context.head->vine;
-  }
-
-  if(context.head == 0) {
-    context.head = m_pRoot;
-    context.order = 0;
-  }
-
-  //      std::cout << context.order << std::endl;
-
-}
-
-/////////////////////////////////////////////////////////////////////
-
-void CPPMPYLanguageModel::LearnSymbol(Context c, int Symbol) {
-  if(Symbol==0)
-    return;
-  
-  DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
-  CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
-  CPPMPYnode *n = AddSymbolToNode(context.head, Symbol);
-  DASHER_ASSERT(n == context.head->find_symbol(Symbol));
-  context.head = n;
-  context.order++;
-
-  while(context.order > m_iMaxOrder) {
-    context.head = context.head->vine;
-    context.order--;
-  }
-}
 
 //Do _not_ move on the context...
 void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
@@ -488,7 +404,7 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
     return;
 
   DASHER_ASSERT(pysym > 0 && pysym <= m_pPyAlphabet->GetNumberTextSymbols());
-  CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
+  CPPMPYLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
  
   //  std::cout<<"py learn context : "<<context.head->symbol<<std::endl;
   /*   CPPMPYnode * pNode = m_pRoot->child;
@@ -500,9 +416,9 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
      std::cout<<" "<<std::endl;
   */
 
-  for (CPPMPYnode *pNode = context.head; pNode; pNode=pNode->vine) {
-    if (++pNode->pychild[pysym]>1) {
-      //sym was already present
+  for (CPPMnode *pNode = context.head; pNode; pNode=pNode->vine) {
+    if (static_cast<CPPMPYnode *>(pNode)->pychild[pysym]++) {
+      //count non-zero before increment, i.e. sym already present
       if (bUpdateExclusion) break;
     }
   }
@@ -511,135 +427,11 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
   //context.order++;
 }
 
-void CPPMPYLanguageModel::dumpSymbol(int sym) {
-  if((sym <= 32) || (sym >= 127))
-    printf("<%d>", sym);
-  else
-    printf("%c", sym);
-}
-
-void CPPMPYLanguageModel::dumpString(char *str, int pos, int len)
-        // Dump the string STR starting at position POS
-{
-  char cc;
-  int p;
-  for(p = pos; p < pos + len; p++) {
-    cc = str[p];
-    if((cc <= 31) || (cc >= 127))
-      printf("<%d>", cc);
-    else
-      printf("%c", cc);
-  }
-}
-
-void CPPMPYLanguageModel::dumpTrie(CPPMPYLanguageModel::CPPMPYnode *t, int d)
-        // diagnostic display of the PPM trie from node t and deeper
-{
-//TODO
-/*
-	dchar debug[256];
-	int sym;
-	CPPMPYnode *s;
-	Usprintf( debug,TEXT("%5d %7x "), d, t );
-	//TODO: Uncomment this when headers sort out
-	//DebugOutput(debug);
-	if (t < 0) // pointer to input
-		printf( "                     <" );
-	else {
-		Usprintf(debug,TEXT( " %3d %5d %7x  %7x  %7x    <"), t->symbol,t->count, t->vine, t->child, t->next );
-		//TODO: Uncomment this when headers sort out
-		//DebugOutput(debug);
-	}
-	
-	dumpString( dumpTrieStr, 0, d );
-	Usprintf( debug,TEXT(">\n") );
-	//TODO: Uncomment this when headers sort out
-	//DebugOutput(debug);
-	if (t != 0) {
-		s = t->child;
-		while (s != 0) {
-			sym =s->symbol;
-			
-			dumpTrieStr [d] = sym;
-			dumpTrie( s, d+1 );
-			s = s->next;
-		}
-	}
-*/
-}
-
-void CPPMPYLanguageModel::dump()
-        // diagnostic display of the whole PPM trie
-{
-// TODO:
-/*
-	dchar debug[256];
-	Usprintf(debug,TEXT(  "Dump of Trie : \n" ));
-	//TODO: Uncomment this when headers sort out
-	//DebugOutput(debug);
-	Usprintf(debug,TEXT(   "---------------\n" ));
-	//TODO: Uncomment this when headers sort out
-	//DebugOutput(debug);
-	Usprintf( debug,TEXT(  "depth node     symbol count  vine   child      next   context\n") );
-	//TODO: Uncomment this when headers sort out
-	//DebugOutput(debug);
-	dumpTrie( root, 0 );
-	Usprintf( debug,TEXT(  "---------------\n" ));
-	//TODO: Uncomment this when headers sort out
-	//DebugOutput(debug);
-	Usprintf(debug,TEXT( "\n" ));
-	//TODO: Uncomment this when headers sort out
-	//DebugOutput(debug);
-*/
-}
-
-////////////////////////////////////////////////////////////////////////
-/// PPMPYnode definitions 
-////////////////////////////////////////////////////////////////////////
-
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::CPPMPYnode::find_symbol(int sym) const
-// see if symbol is a child of node
-{
-  //  printf("finding symbol %d at node %d\n",sym,node->id);
-
-  //Potentially replace with large scale find algorithm, necessary?
-  for (CPPMPYnode * found = child[ min(DIVISION-1, sym/UNITALPH) ]; found; found=found->next) {
-    if(found->sym == sym) {
-      return found;
-    }
-  }
-
-  return 0;
-}
-
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddSymbolToNode(CPPMPYnode *pNode, int sym) {
-  //  std::cout<<"Addnode sym "<<sym<<std::endl;
-  CPPMPYnode *pReturn = pNode->find_symbol(sym);
-
-  //  std::cout << sym << ",";
-
-  if(pReturn != NULL) {
-    //      std::cout << "Using existing node" << std::endl;
-    pReturn->count++;
-    if (!bUpdateExclusion) {
-      //update vine contexts too. Must exist if higher-order context does!
-      for (CPPMPYnode *v = pReturn->vine; v; v=v->vine) {
-        DASHER_ASSERT(v==m_pRoot || v->sym == sym);
-        v->count++;
-      }
-    }
-  } else {
-    //symbol does not exist at this level
-    pReturn = m_NodeAlloc.Alloc();        // count is initialized to 1 but no symbol or vine ptr
-    ++NodesAllocated;
-    pReturn->sym = sym;
-    const int childIdx( min(DIVISION-1, sym/UNITALPH) );
-    pReturn->next = pNode->child[childIdx];
-    pNode->child[childIdx] = pReturn;
-    pReturn->vine = (pNode == m_pRoot) ? m_pRoot : AddSymbolToNode(pNode->vine, sym);
-  }
-
-  return pReturn;
+CPPMPYLanguageModel::CPPMPYnode *CPPMPYLanguageModel::makeNode(int sym) {
+  CPPMPYnode *res = m_NodeAlloc.Alloc();
+  res->sym=sym;
+  ++NodesAllocated;
+  return res;
 }
 
 //Mandarin - PY not enabled for these read-write functions
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
index 312034f..707ef23 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
@@ -16,21 +16,9 @@
 #include "../../Common/NoClones.h"
 #include "../../Common/Allocators/PooledAlloc.h"
 
-#include "LanguageModel.h"
-#include "PPMPYLanguageModel.h"
-#include "../SCENode.h"
+#include "PPMLanguageModel.h"
 
 #include <vector>
-#include <fstream>
-#include <set>
-//Define alphabet sizes
-#define ALPHSIZE 7610
-#define PYALPHSIZE 1300
-//Implement a multi-branch tree, instead of a binary tree to gain speed: a trade-off between speed and memory; the choice of branch is implied by ranking of the symbol being searched/added 
-#define DIVISION 5
-#define UNITALPH (ALPHSIZE/DIVISION)
-#define UNITPY (PYALPHSIZE/DIVISION)
-
 
 namespace Dasher {
 
@@ -45,56 +33,21 @@ namespace Dasher {
   /// is _not_ entered into the context; new method GetPartProbs is used to compute probabilities
   /// for the next chinese symbol (which should be entered into context), by filtering to a set.
   ///
-  class CPPMPYLanguageModel:public CLanguageModel, private NoClones {
-  private:
-    class CPPMPYnode {
-    public:
-      CPPMPYnode * find_symbol(int sym)const;
-      //Each PPM node store DIVISION number of addresses for children, so that each node branches out DIVISION times (as compared to binary); this is aimed to give better run-time speed
-      CPPMPYnode * child[DIVISION];
-      CPPMPYnode *next;
-      CPPMPYnode *vine;
-      /// map from pinyin-symbol to count: the number of times each pinyin symbol has been seen in this context
-      std::map<symbol,unsigned short int> pychild;
-      unsigned short int count;
-      symbol sym;
-      CPPMPYnode(int sym);
-      CPPMPYnode();
-    };
-	  
-    class CPPMPYContext {
-    public:
-      CPPMPYContext(CPPMPYContext const &input) {
-        head = input.head;
-        order = input.order;
-      } CPPMPYContext(CPPMPYnode * _head = 0, int _order = 0):head(_head), order(_order) {
-      };
-      ~CPPMPYContext() {
-      };
-      void dump();
-      CPPMPYnode *head;
-      int order;
-    };
-	  
+  /// That is: from the superclass (CAbstractPPM) perspective, the alphabet is the Chinese one;
+  /// hence, contexts store Chinese symbols only, and EnterSymbol+LearnSymbol should be called
+  /// with _Chinese_ symbol numbers. All pinyin-alphabet details are handled in this subclass,
+  /// with an extra LearnPYSymbol method for updating the LM's pinyin predictions.
+  class CPPMPYLanguageModel : public CAbstractPPM {
   public:
-    ///Construct a new PPMPYLanguageModel.
-    /// \param pAlph alphabet containing the actual symbols we want to write (i.e. Chinese)
+    ///Construct a new PPMPYLanguageModel. 
+    /// \param pAlph alphabet containing the actual symbols we want to write (i.e. Chinese); this
+    /// is the only alphabet passed to the CAbstractPPM superclass.
     /// \param pPyAlph alphabet of pinyin phonemes; we will predict probabilities for these
     /// based (only) on the preceding _Chinese_ symbols.
     CPPMPYLanguageModel(Dasher::CEventHandler * pEventHandler, CSettingsStore * pSettingsStore, const CAlphInfo *pAlph, const CAlphInfo *pPyAlph);
 
-    virtual ~ CPPMPYLanguageModel();
-
-    Context CreateEmptyContext();
-    void ReleaseContext(Context context);
-    Context CloneContext(Context context);
-
-    ///Advance the context by entering a chinese symbol
-    virtual void EnterSymbol(Context context, int Symbol);
-    ///Train the LM with the specified Chinese symbol in that context (moves context on)
-    virtual void LearnSymbol(Context context, int Symbol);
     ///Learns a pinyin symbol in the specified context, but does not move the context on.
-    virtual void LearnPYSymbol(Context context, int Symbol);
+    void LearnPYSymbol(Context context, int Symbol);
 
     ///Predicts probabilities for the next Pinyin symbol (blending as per PPM,
     /// but using the pychild map rather than child CPPMPYnodes).
@@ -110,99 +63,28 @@ namespace Dasher {
     /// indicates a possible chinese symbol; on exit, the second element will have been filled in.
     void GetPartProbs(Context context, std::vector<std::pair<symbol, unsigned int> > &vChildren, int norm, int iUniform);
 
-    void dump();
-
     virtual bool WriteToFile(std::string strFilename);
     virtual bool ReadFromFile(std::string strFilename);
-    bool RecursiveWrite(CPPMPYnode *pNode, std::map<CPPMPYnode *, int> *pmapIdx, int *pNextIdx, std::ofstream *pOutputFile);
-    int GetIndex(CPPMPYnode *pAddr, std::map<CPPMPYnode *, int> *pmapIdx, int *pNextIdx);
-    CPPMPYnode *GetAddress(int iIndex, std::map<int, CPPMPYnode*> *pMap);
-
-    CPPMPYnode *AddSymbolToNode(CPPMPYnode * pNode, int sym);
-
-    void dumpSymbol(int sym);
-    void dumpString(char *str, int pos, int len);
-    void dumpTrie(CPPMPYnode * t, int d);
-
-    CPPMPYContext *m_pRootContext;
-    CPPMPYnode *m_pRoot;
-
-    int m_iMaxOrder;
-    double m_dBackOffConstat;
-
-    int NodesAllocated; //inclusive of both Character and PY nodes
-
-    bool bUpdateExclusion;
-
 
+  protected:
+    class CPPMPYnode : public CPPMnode {
+    public:
+      /// map from pinyin-symbol to count: the number of times each pinyin symbol has been seen in this context
+      std::map<symbol,unsigned short int> pychild;
+      inline CPPMPYnode(int sym) : CPPMnode(sym) {}
+      inline CPPMPYnode() : CPPMnode() {}
+    };
+    CPPMPYnode *makeNode(int sym);
     
-    mutable CSimplePooledAlloc < CPPMPYnode > m_NodeAlloc;
-    CPooledAlloc < CPPMPYContext > m_ContextAlloc;
-
-    std::set<const CPPMPYContext *> m_setContexts;
-
   private:
+    int NodesAllocated;
+    mutable CSimplePooledAlloc < CPPMPYnode > m_NodeAlloc;
 
     const CAlphInfo *m_pPyAlphabet;
     int m_iAlphSize;
-
   };
 
-  /// @}
-
-  inline Dasher::CPPMPYLanguageModel::CPPMPYnode::CPPMPYnode(int _sym):sym(_sym) {
-    //    child.clear();
-    //    pychild.clear();
-
-    next = vine = 0;
-    count = 1;
-
-    //Added: Mandarin; Setting initial values
-    for(int i =0; i <DIVISION; i++){
-      child[i] = NULL;
-    }
-  }
-
-  inline CPPMPYLanguageModel::CPPMPYnode::CPPMPYnode() {
-    //   child.clear();
-    //   pychild.clear();    
-
-    next = vine = 0;
-
-    count = 1;
-
-    //Added: Mandarin; Setting initial values
-    for(int i =0; i <DIVISION; i++){
-      child[i] = NULL;
-    }
-
-  }
-
-  inline CLanguageModel::Context CPPMPYLanguageModel::CreateEmptyContext() {
-    CPPMPYContext *pCont = m_ContextAlloc.Alloc();
-    *pCont = *m_pRootContext;
-
-    //    m_setContexts.insert(pCont);
-
-    return (Context) pCont;
-  }
-
-  inline CLanguageModel::Context CPPMPYLanguageModel::CloneContext(Context Copy) {
-    CPPMPYContext *pCont = m_ContextAlloc.Alloc();
-    CPPMPYContext *pCopy = (CPPMPYContext *) Copy;
-    *pCont = *pCopy;
-
-    //    m_setContexts.insert(pCont);
-
-    return (Context) pCont;
-  }
-
-  inline void CPPMPYLanguageModel::ReleaseContext(Context release) {
-
-    //    m_setContexts.erase(m_setContexts.find((CPPMPYContext *) release));
-
-    m_ContextAlloc.Free((CPPMPYContext *) release);
-  }
+  /// @}  
 }                               // end namespace Dasher
 
 #endif // __LanguageModelling__PPMPYLanguageModel_h__



[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]