[dasher] Extract CAbstractPPM superclass from PPM Lang Model, make PPMPY a subclass of it
- From: Patrick Welche <pwelche src gnome org>
- To: commits-list gnome org
- Cc:
- Subject: [dasher] Extract CAbstractPPM superclass from PPM Lang Model, make PPMPY a subclass of it
- Date: Tue, 15 Mar 2011 17:11:57 +0000 (UTC)
commit 9fa9a842293091ce6682ecf6ca75c9e3a81f48c7
Author: Alan Lawrence <acl33 inf phy cam ac uk>
Date: Tue Feb 8 14:42:29 2011 +0000
Extract CAbstractPPM superclass from PPM Lang Model, make PPMPY a subclass of it
i.e. adds just map<py sym, count> to PPMnode.
=> PPMPY uses same optimized hashmap for (chinese) children, find_symbol, etc.
.../LanguageModelling/PPMLanguageModel.cpp | 61 +++---
.../LanguageModelling/PPMLanguageModel.h | 98 +++++---
.../LanguageModelling/PPMPYLanguageModel.cpp | 244 ++------------------
.../LanguageModelling/PPMPYLanguageModel.h | 162 ++-----------
4 files changed, 131 insertions(+), 434 deletions(-)
---
diff --git a/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp b/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
index 3f35e42..a01c29b 100644
--- a/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
+++ b/Src/DasherCore/LanguageModelling/PPMLanguageModel.cpp
@@ -30,25 +30,15 @@ static char THIS_FILE[] = __FILE__;
/////////////////////////////////////////////////////////////////////
-CPPMLanguageModel::CPPMLanguageModel(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CAlphInfo *pAlph)
-:CLanguageModel(pEventHandler, pSettingsStore, pAlph), m_iMaxOrder(4), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024) {
- m_pRoot = m_NodeAlloc.Alloc();
- m_pRoot->sym = -1;
-
+CAbstractPPM::CAbstractPPM(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CAlphInfo *pAlph, CPPMnode *pRoot, int iMaxOrder)
+: CLanguageModel(pEventHandler, pSettingsStore, pAlph), m_pRoot(pRoot), m_iMaxOrder(iMaxOrder), bUpdateExclusion( GetLongParameter(LP_LM_UPDATE_EXCLUSION)!=0 ), m_ContextAlloc(1024) {
m_pRootContext = m_ContextAlloc.Alloc();
m_pRootContext->head = m_pRoot;
m_pRootContext->order = 0;
-
- // Cache parameters that don't make sense to adjust during the life of a language model...
- bUpdateExclusion = ( GetLongParameter(LP_LM_UPDATE_EXCLUSION) !=0 );
-
- m_iMaxOrder = GetLongParameter( LP_LM_MAX_ORDER );
-
}
-/////////////////////////////////////////////////////////////////////
-
-CPPMLanguageModel::~CPPMLanguageModel() {
+bool CAbstractPPM::isValidContext(const Context context) const {
+ return m_setContexts.count((const CPPMContext *)context) > 0;
}
/////////////////////////////////////////////////////////////////////
@@ -57,7 +47,7 @@ CPPMLanguageModel::~CPPMLanguageModel() {
void CPPMLanguageModel::GetProbs(Context context, std::vector<unsigned int> &probs, int norm, int iUniform) const {
const CPPMContext *ppmcontext = (const CPPMContext *)(context);
- DASHER_ASSERT(m_setContexts.count(ppmcontext) > 0);
+ DASHER_ASSERT(isValidContext(context));
int iNumSymbols = GetSize();
@@ -155,13 +145,13 @@ void CPPMLanguageModel::GetProbs(Context context, std::vector<unsigned int> &pro
/////////////////////////////////////////////////////////////////////
// Update context with symbol 'Symbol'
-void CPPMLanguageModel::EnterSymbol(Context c, int Symbol) {
+void CAbstractPPM::EnterSymbol(Context c, int Symbol) {
if(Symbol==0)
return;
DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
- CPPMLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
+ CPPMContext & context = *(CPPMContext *) (c);
while(context.head) {
@@ -196,14 +186,14 @@ void CPPMLanguageModel::EnterSymbol(Context c, int Symbol) {
// add symbol to the context
// creates new nodes, updates counts
// and leaves 'context' at the new context
-void CPPMLanguageModel::LearnSymbol(Context c, int Symbol) {
+void CAbstractPPM::LearnSymbol(Context c, int Symbol) {
if(Symbol==0)
return;
DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
- CPPMLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
+ CPPMContext & context = *(CPPMContext *) (c);
CPPMnode* n = AddSymbolToNode(context.head, Symbol);
DASHER_ASSERT ( n == context.head->find_symbol(Symbol));
@@ -217,14 +207,14 @@ void CPPMLanguageModel::LearnSymbol(Context c, int Symbol) {
}
-void CPPMLanguageModel::dumpSymbol(symbol sym) {
+void CAbstractPPM::dumpSymbol(symbol sym) {
if((sym <= 32) || (sym >= 127))
printf("<%d>", sym);
else
printf("%c", sym);
}
-void CPPMLanguageModel::dumpString(char *str, int pos, int len)
+void CAbstractPPM::dumpString(char *str, int pos, int len)
// Dump the string STR starting at position POS
{
char cc;
@@ -238,7 +228,7 @@ void CPPMLanguageModel::dumpString(char *str, int pos, int len)
}
}
-void CPPMLanguageModel::dumpTrie(CPPMLanguageModel::CPPMnode *t, int d)
+void CAbstractPPM::dumpTrie(CAbstractPPM::CPPMnode *t, int d)
// diagnostic display of the PPM trie from node t and deeper
{
//TODO
@@ -274,7 +264,7 @@ void CPPMLanguageModel::dumpTrie(CPPMLanguageModel::CPPMnode *t, int d)
*/
}
-void CPPMLanguageModel::dump()
+void CAbstractPPM::dump()
// diagnostic display of the whole PPM trie
{
// TODO:
@@ -299,7 +289,7 @@ void CPPMLanguageModel::dump()
*/
}
-bool CPPMLanguageModel::eq(CPPMLanguageModel *other) {
+bool CAbstractPPM::eq(CAbstractPPM *other) {
std::map<CPPMnode *,CPPMnode *> equivs;
if (!m_pRoot->eq(other->m_pRoot,equivs)) return false;
//have first & second being equivalent, for all entries in map, except vine ptrs not checked.
@@ -320,7 +310,7 @@ bool CPPMLanguageModel::eq(CPPMLanguageModel *other) {
/// PPMnode definitions
////////////////////////////////////////////////////////////////////////
-bool CPPMLanguageModel::CPPMnode::eq(CPPMLanguageModel::CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs) {
+bool CAbstractPPM::CPPMnode::eq(CAbstractPPM::CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs) {
if (sym != other->sym)
return false;
if (count != other->count)
@@ -340,7 +330,7 @@ bool CPPMLanguageModel::CPPMnode::eq(CPPMLanguageModel::CPPMnode *other, std::ma
#define MAX_RUN 4
-CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::find_symbol(symbol sym) const
+CAbstractPPM::CPPMnode * CAbstractPPM::CPPMnode::find_symbol(symbol sym) const
// see if symbol is a child of node
{
if (m_iNumChildSlots < 0) //negative to mean "full alphabet", use direct indexing
@@ -367,7 +357,7 @@ CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::find_symbol(symbol sy
return 0;
}
-void CPPMLanguageModel::CPPMnode::AddChild(CPPMnode *pNewChild, int numSymbols) {
+void CAbstractPPM::CPPMnode::AddChild(CPPMnode *pNewChild, int numSymbols) {
if (m_iNumChildSlots < 0) {
m_ppChildren[pNewChild->sym] = pNewChild;
}
@@ -429,7 +419,7 @@ void CPPMLanguageModel::CPPMnode::AddChild(CPPMnode *pNewChild, int numSymbols)
}
}
-CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode, symbol sym) {
+CAbstractPPM::CPPMnode * CAbstractPPM::AddSymbolToNode(CPPMnode *pNode, symbol sym) {
CPPMnode *pReturn = pNode->find_symbol(sym);
@@ -446,9 +436,7 @@ CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode
}
} else {
//symbol does not exist at this level
- pReturn = m_NodeAlloc.Alloc(); //count initialized to 1 but no symbol or vine pointer
- ++NodesAllocated;
- pReturn->sym = sym;
+ pReturn = makeNode(sym); //count initialized to 1 but no vine pointer
pNode->AddChild(pReturn, GetSize());
pReturn->vine = (pNode==m_pRoot) ? m_pRoot : AddSymbolToNode(pNode->vine,sym);
}
@@ -456,6 +444,17 @@ CPPMLanguageModel::CPPMnode * CPPMLanguageModel::AddSymbolToNode(CPPMnode *pNode
return pReturn;
}
+CPPMLanguageModel::CPPMLanguageModel(CEventHandler *pEvt, CSettingsStore *sets, const CAlphInfo *pAlph)
+: CAbstractPPM(pEvt, sets, pAlph, new CPPMnode(-1), sets->GetLongParameter(LP_LM_MAX_ORDER)), NodesAllocated(0), m_NodeAlloc(8192) {
+}
+
+CAbstractPPM::CPPMnode *CPPMLanguageModel::makeNode(int sym) {
+ CPPMnode *res = m_NodeAlloc.Alloc();
+ res->sym = sym;
+ ++NodesAllocated;
+ return res;
+}
+
struct BinaryRecord {
int m_iIndex;
int m_iChild;
diff --git a/Src/DasherCore/LanguageModelling/PPMLanguageModel.h b/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
index 39bee14..37b034a 100644
--- a/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
+++ b/Src/DasherCore/LanguageModelling/PPMLanguageModel.h
@@ -25,10 +25,16 @@ namespace Dasher {
/// \ingroup LM
/// @{
///
- /// PPM language model
+ /// Common superclass for both PPM and PPMY language models. Implements the PPM tree,
+ /// inc. fast hashing of child nodes by symbol number; and entering and learning symbols
+ /// in a context, i.e. navigating and updating the tree, with update exclusion according
+ /// to LP_LM_UPDATE_EXCLUSION
///
- class CPPMLanguageModel:public CLanguageModel, private NoClones {
- private:
+ /// Subclasses must implement CLanguageModel::GetProbs and a makeNode() method (perhaps
+ /// using a pooled allocator).
+ ///
+ class CAbstractPPM :public CLanguageModel, private NoClones {
+ protected:
class ChildIterator;
class CPPMnode {
private:
@@ -53,8 +59,8 @@ namespace Dasher {
symbol sym;
CPPMnode(symbol sym);
CPPMnode();
- ~CPPMnode();
- bool eq(CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs);
+ virtual ~CPPMnode();
+ virtual bool eq(CPPMnode *other, std::map<CPPMnode *,CPPMnode *> &equivs);
};
class ChildIterator {
private:
@@ -89,10 +95,26 @@ namespace Dasher {
CPPMnode *head;
int order;
};
+
+ ///Makes a new node, of whatever kind (subclass of CPPMnode, perhaps with extra info)
+ /// is required by the subclass, for the specified symbol. (Initial count will be 1.)
+ virtual CPPMnode *makeNode(int sym)=0;
+ CAbstractPPM(Dasher::CEventHandler * pEventHandler, CSettingsStore * pSettingsStore, const CAlphInfo *pAlph, CPPMnode *pRoot, int iMaxOrder);
+
+ void dumpSymbol(symbol sym);
+ void dumpString(char *str, int pos, int len);
+ void dumpTrie(CPPMnode * t, int d);
+
+ CPPMContext *m_pRootContext;
+ CPPMnode *m_pRoot;
+
+ /// Cache parameters that don't make sense to adjust during the life of a language model...
+ const int m_iMaxOrder;
+ const bool bUpdateExclusion;
+
public:
- CPPMLanguageModel(Dasher::CEventHandler * pEventHandler, CSettingsStore * pSettingsStore, const CAlphInfo *pAlph);
- bool eq(CPPMLanguageModel *other);
- virtual ~ CPPMLanguageModel();
+ virtual bool eq(CAbstractPPM *other);
+ virtual ~CAbstractPPM() {};
Context CreateEmptyContext();
void ReleaseContext(Context context);
@@ -101,71 +123,73 @@ namespace Dasher {
virtual void EnterSymbol(Context context, int Symbol);
virtual void LearnSymbol(Context context, int Symbol);
- virtual void GetProbs(Context context, std::vector < unsigned int >&Probs, int norm, int iUniform) const;
-
void dump();
+ bool isValidContext(const Context c) const ;
+ private:
+ CPPMnode *AddSymbolToNode(CPPMnode * pNode, symbol sym);
+ CPooledAlloc < CPPMContext > m_ContextAlloc;
+
+ std::set<const CPPMContext *> m_setContexts;
+ };
+
+ ///"Standard" PPM language model: GetProbs uses counts in PPM child nodes,
+ /// universal alpha+beta values read from LP_LM_ALPHA and LP_LM_BETA,
+ /// max order from LP_LM_MAX_ORDER.
+ class CPPMLanguageModel : public CAbstractPPM {
+ public:
+ CPPMLanguageModel(CEventHandler *pEventHandler, CSettingsStore *pSets, const CAlphInfo *pAlph);
+ virtual void GetProbs(Context context, std::vector < unsigned int >&Probs, int norm, int iUniform) const;
+ protected:
+ /// Makes a standard CPPMnode, but using a pooled allocator (m_NodeAlloc) - faster!
+ virtual CPPMnode *makeNode(int sym);
+
virtual bool WriteToFile(std::string strFilename);
virtual bool ReadFromFile(std::string strFilename);
+ private:
+ int NodesAllocated;
+
bool RecursiveWrite(CPPMnode *pNode, CPPMnode *pNextSibling, std::map<CPPMnode *, int> *pmapIdx, int *pNextIdx, std::ofstream *pOutputFile);
int GetIndex(CPPMnode *pAddr, std::map<CPPMnode *, int> *pmapIdx, int *pNextIdx);
CPPMnode *GetAddress(int iIndex, std::map<int, CPPMnode*> *pMap);
- CPPMnode *AddSymbolToNode(CPPMnode * pNode, symbol sym);
-
- void dumpSymbol(symbol sym);
- void dumpString(char *str, int pos, int len);
- void dumpTrie(CPPMnode * t, int d);
-
- CPPMContext *m_pRootContext;
- CPPMnode *m_pRoot;
-
- int m_iMaxOrder;
-
- int NodesAllocated;
-
- bool bUpdateExclusion;
-
mutable CSimplePooledAlloc < CPPMnode > m_NodeAlloc;
- CPooledAlloc < CPPMContext > m_ContextAlloc;
-
- std::set<const CPPMContext *> m_setContexts;
};
/// @}
- inline CPPMLanguageModel::ChildIterator CPPMLanguageModel::CPPMnode::children() const {
+ inline CAbstractPPM::ChildIterator CPPMLanguageModel::CPPMnode::children() const {
//if m_iNumChildSlots = 0 / 1, m_ppChildren is direct pointer, else ptr to array (of pointers)
CPPMnode *const *ppChild = (m_iNumChildSlots == 0 || m_iNumChildSlots == 1) ? &m_pChild : m_ppChildren;
return ChildIterator(ppChild + abs(m_iNumChildSlots), ppChild - 1);
}
- inline const CPPMLanguageModel::ChildIterator CPPMLanguageModel::CPPMnode::end() const {
+ inline const CAbstractPPM::ChildIterator CPPMLanguageModel::CPPMnode::end() const {
//if m_iNumChildSlots = 0 / 1, m_ppChildren is direct pointer, else ptr to array (of pointers)
CPPMnode *const *ppChild = (m_iNumChildSlots == 0 || m_iNumChildSlots == 1) ? &m_pChild : m_ppChildren;
return ChildIterator(ppChild, ppChild - 1);
}
- inline Dasher::CPPMLanguageModel::CPPMnode::CPPMnode(symbol _sym): sym(_sym) {
+ inline Dasher::CAbstractPPM::CPPMnode::CPPMnode(symbol _sym): sym(_sym) {
vine = 0;
m_iNumChildSlots = 0;
m_ppChildren = NULL;
count = 1;
}
- inline CPPMLanguageModel::CPPMnode::CPPMnode() {
+ inline CAbstractPPM::CPPMnode::CPPMnode() {
vine = 0;
m_iNumChildSlots = 0;
m_ppChildren = NULL;
count = 1;
}
- inline CPPMLanguageModel::CPPMnode::~CPPMnode() {
+ inline CAbstractPPM::CPPMnode::~CPPMnode() {
//single child = is direct pointer to node, not array...
if (m_iNumChildSlots != 1)
- delete m_ppChildren;
+ delete[] m_ppChildren;
}
- inline CLanguageModel::Context CPPMLanguageModel::CreateEmptyContext() {
+ inline CLanguageModel::Context CAbstractPPM::CreateEmptyContext() {
CPPMContext *pCont = m_ContextAlloc.Alloc();
*pCont = *m_pRootContext;
@@ -174,7 +198,7 @@ namespace Dasher {
return (Context) pCont;
}
- inline CLanguageModel::Context CPPMLanguageModel::CloneContext(Context Copy) {
+ inline CLanguageModel::Context CAbstractPPM::CloneContext(Context Copy) {
CPPMContext *pCont = m_ContextAlloc.Alloc();
CPPMContext *pCopy = (CPPMContext *) Copy;
*pCont = *pCopy;
@@ -184,7 +208,7 @@ namespace Dasher {
return (Context) pCont;
}
- inline void CPPMLanguageModel::ReleaseContext(Context release) {
+ inline void CAbstractPPM::ReleaseContext(Context release) {
m_setContexts.erase(m_setContexts.find((CPPMContext *) release));
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
index 23247d7..cf12a66 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.cpp
@@ -34,34 +34,9 @@ static char THIS_FILE[] = __FILE__;
/////////////////////////////////////////////////////////////////////
CPPMPYLanguageModel::CPPMPYLanguageModel(Dasher::CEventHandler *pEventHandler, CSettingsStore *pSettingsStore, const CAlphInfo *pAlph, const CAlphInfo *pPyAlphabet)
- :CLanguageModel(pEventHandler, pSettingsStore, pAlph), m_iMaxOrder(2), NodesAllocated(0), m_NodeAlloc(8192), m_ContextAlloc(1024), m_pPyAlphabet(pPyAlphabet){
- m_pRoot = m_NodeAlloc.Alloc();
- m_pRoot->sym = -1;
- // m_pRoot->child.resize(DIVISION, NULL);
- // m_pRoot->pychild.resize(DIVISION, NULL);
-
-
-
- m_pRootContext = m_ContextAlloc.Alloc();
- m_pRootContext->head = m_pRoot;
- m_pRootContext->order = 0;
+ :CAbstractPPM(pEventHandler, pSettingsStore, pAlph, new CPPMPYnode(-1), 2), NodesAllocated(0), m_NodeAlloc(8192), m_pPyAlphabet(pPyAlphabet){
m_iAlphSize = GetSize();
- // std::cout<<"Alphaunit: "<<UNITALPH<<std::endl;
-
- // Cache the result of update exclusion - otherwise we have to look up a lot when training, which is slow
-
- // FIXME - this should be a boolean parameter
-
- bUpdateExclusion = ( GetLongParameter(LP_LM_UPDATE_EXCLUSION) !=0 );
-
- m_iMaxOrder = 2;//GetLongParameter( LP_LM_MAX_ORDER );
- //std::cout<<"Max Order: "<<m_iMaxOrder<<std::endl;
-}
-
-/////////////////////////////////////////////////////////////////////
-
-CPPMPYLanguageModel::~CPPMPYLanguageModel() {
}
/////////////////////////////////////////////////////////////////////
@@ -200,7 +175,7 @@ void CPPMPYLanguageModel::GetPartProbs(Context context, std::vector<pair<symbol,
// std::cout<<"Norms is "<<norm<<std::endl;
// std::cout<<"iUniform is "<<iUniform<<std::endl;
- const CPPMPYContext *ppmcontext = (const CPPMPYContext *)(context);
+ const CPPMContext *ppmcontext = (const CPPMContext *)(context);
// DASHER_ASSERT(m_setContexts.count(ppmcontext) > 0);
@@ -235,10 +210,10 @@ void CPPMPYLanguageModel::GetPartProbs(Context context, std::vector<pair<symbol,
int *vCounts=new int[vChildren.size()]; //num occurrences of symbol at same index in vChildren
//new code
- for (CPPMPYnode *pTemp = ppmcontext->head; pTemp; pTemp=pTemp->vine) {
+ for (CPPMnode *pTemp = ppmcontext->head; pTemp; pTemp=pTemp->vine) {
int iTotal=0, i=0;
for (std::vector<pair<symbol, unsigned int> >::const_iterator it = vChildren.begin(); it!=vChildren.end(); it++,i++) {
- if (CPPMPYnode *pFound = pTemp->find_symbol(it->first)) {
+ if (CPPMnode *pFound = pTemp->find_symbol(it->first)) {
iTotal += vCounts[i] = pFound->count; //double assignment
} else
vCounts[i] = 0;
@@ -310,8 +285,7 @@ void CPPMPYLanguageModel::GetPartProbs(Context context, std::vector<pair<symbol,
// by an explicit cast to PPMPYLanguageModel whenever MandarinDasher was activated. Renaming
// to GetProbs causes the normal (virtual) call to come straight here without any special-casing...
void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &probs, int norm, int iUniform) const {
- const CPPMPYContext *ppmcontext = (const CPPMPYContext *)(context);
-
+ const CPPMContext *ppmcontext = (const CPPMContext *)(context);
// std::cout<<"PPMCONTEXT symbol: "<<ppmcontext->head->symbol<<std::endl;
/*
@@ -354,10 +328,11 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
int alpha = GetLongParameter( LP_LM_ALPHA );
int beta = GetLongParameter( LP_LM_BETA );
- for (CPPMPYnode *pTemp = ppmcontext->head; pTemp; pTemp = pTemp->vine) {
+ for (CPPMnode *pTemp = ppmcontext->head; pTemp; pTemp = pTemp->vine) {
int iTotal = 0;
+ const map<symbol, unsigned short int> &pychild( static_cast<CPPMPYnode *>(pTemp)->pychild);
- for (map<symbol, unsigned short int>::iterator it=pTemp->pychild.begin(); it!=pTemp->pychild.end(); it++) {
+ for (map<symbol, unsigned short int>::const_iterator it=pychild.begin(); it!=pychild.end(); it++) {
if(!(exclusions[it->first] && doExclusion))
iTotal += it->second;
}
@@ -365,7 +340,7 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
if(iTotal) {
unsigned int size_of_slice = iToSpend;
- for (map<symbol, unsigned short int>::iterator it = pTemp->pychild.begin(); it!=pTemp->pychild.end(); it++) {
+ for (map<symbol, unsigned short int>::const_iterator it = pychild.begin(); it!=pychild.end(); it++) {
if(!(exclusions[it->first] && doExclusion)) {
exclusions[it->first] = 1;
@@ -421,65 +396,6 @@ void CPPMPYLanguageModel::GetProbs(Context context, std::vector<unsigned int> &p
}
/////////////////////////////////////////////////////////////////////
-// Update context with symbol 'Symbol'
-
-void CPPMPYLanguageModel::EnterSymbol(Context c, int Symbol) {
- if(Symbol<0)
- return;
-
- DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
-
- CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
-
- while(context.head) {
-
- //std::cout<<"Max Order: "<<m_iMaxOrder<<std::endl;
- if(context.order < m_iMaxOrder) { // Only try to extend the context if it's not going to make it too long
- if (CPPMPYnode *find = context.head->find_symbol(Symbol)) {
- // std::cout<<"FOund PPM Node for update!"<<std::endl;
- context.order++;
- context.head = find;
- // Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
- // DebugOutput(debug);
-
- // std::cout << context.order << std::endl;
- return;
- }
- }
-
- // If we can't extend the current context, follow vine pointer to shorten it and try again
-
- context.order--;
- context.head = context.head->vine;
- }
-
- if(context.head == 0) {
- context.head = m_pRoot;
- context.order = 0;
- }
-
- // std::cout << context.order << std::endl;
-
-}
-
-/////////////////////////////////////////////////////////////////////
-
-void CPPMPYLanguageModel::LearnSymbol(Context c, int Symbol) {
- if(Symbol==0)
- return;
-
- DASHER_ASSERT(Symbol >= 0 && Symbol < GetSize());
- CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
- CPPMPYnode *n = AddSymbolToNode(context.head, Symbol);
- DASHER_ASSERT(n == context.head->find_symbol(Symbol));
- context.head = n;
- context.order++;
-
- while(context.order > m_iMaxOrder) {
- context.head = context.head->vine;
- context.order--;
- }
-}
//Do _not_ move on the context...
void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
@@ -488,7 +404,7 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
return;
DASHER_ASSERT(pysym > 0 && pysym <= m_pPyAlphabet->GetNumberTextSymbols());
- CPPMPYLanguageModel::CPPMPYContext & context = *(CPPMPYContext *) (c);
+ CPPMPYLanguageModel::CPPMContext & context = *(CPPMContext *) (c);
// std::cout<<"py learn context : "<<context.head->symbol<<std::endl;
/* CPPMPYnode * pNode = m_pRoot->child;
@@ -500,9 +416,9 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
std::cout<<" "<<std::endl;
*/
- for (CPPMPYnode *pNode = context.head; pNode; pNode=pNode->vine) {
- if (++pNode->pychild[pysym]>1) {
- //sym was already present
+ for (CPPMnode *pNode = context.head; pNode; pNode=pNode->vine) {
+ if (static_cast<CPPMPYnode *>(pNode)->pychild[pysym]++) {
+ //count non-zero before increment, i.e. sym already present
if (bUpdateExclusion) break;
}
}
@@ -511,135 +427,11 @@ void CPPMPYLanguageModel::LearnPYSymbol(Context c, int pysym) {
//context.order++;
}
-void CPPMPYLanguageModel::dumpSymbol(int sym) {
- if((sym <= 32) || (sym >= 127))
- printf("<%d>", sym);
- else
- printf("%c", sym);
-}
-
-void CPPMPYLanguageModel::dumpString(char *str, int pos, int len)
- // Dump the string STR starting at position POS
-{
- char cc;
- int p;
- for(p = pos; p < pos + len; p++) {
- cc = str[p];
- if((cc <= 31) || (cc >= 127))
- printf("<%d>", cc);
- else
- printf("%c", cc);
- }
-}
-
-void CPPMPYLanguageModel::dumpTrie(CPPMPYLanguageModel::CPPMPYnode *t, int d)
- // diagnostic display of the PPM trie from node t and deeper
-{
-//TODO
-/*
- dchar debug[256];
- int sym;
- CPPMPYnode *s;
- Usprintf( debug,TEXT("%5d %7x "), d, t );
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
- if (t < 0) // pointer to input
- printf( " <" );
- else {
- Usprintf(debug,TEXT( " %3d %5d %7x %7x %7x <"), t->symbol,t->count, t->vine, t->child, t->next );
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
- }
-
- dumpString( dumpTrieStr, 0, d );
- Usprintf( debug,TEXT(">\n") );
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
- if (t != 0) {
- s = t->child;
- while (s != 0) {
- sym =s->symbol;
-
- dumpTrieStr [d] = sym;
- dumpTrie( s, d+1 );
- s = s->next;
- }
- }
-*/
-}
-
-void CPPMPYLanguageModel::dump()
- // diagnostic display of the whole PPM trie
-{
-// TODO:
-/*
- dchar debug[256];
- Usprintf(debug,TEXT( "Dump of Trie : \n" ));
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
- Usprintf(debug,TEXT( "---------------\n" ));
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
- Usprintf( debug,TEXT( "depth node symbol count vine child next context\n") );
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
- dumpTrie( root, 0 );
- Usprintf( debug,TEXT( "---------------\n" ));
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
- Usprintf(debug,TEXT( "\n" ));
- //TODO: Uncomment this when headers sort out
- //DebugOutput(debug);
-*/
-}
-
-////////////////////////////////////////////////////////////////////////
-/// PPMPYnode definitions
-////////////////////////////////////////////////////////////////////////
-
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::CPPMPYnode::find_symbol(int sym) const
-// see if symbol is a child of node
-{
- // printf("finding symbol %d at node %d\n",sym,node->id);
-
- //Potentially replace with large scale find algorithm, necessary?
- for (CPPMPYnode * found = child[ min(DIVISION-1, sym/UNITALPH) ]; found; found=found->next) {
- if(found->sym == sym) {
- return found;
- }
- }
-
- return 0;
-}
-
-CPPMPYLanguageModel::CPPMPYnode * CPPMPYLanguageModel::AddSymbolToNode(CPPMPYnode *pNode, int sym) {
- // std::cout<<"Addnode sym "<<sym<<std::endl;
- CPPMPYnode *pReturn = pNode->find_symbol(sym);
-
- // std::cout << sym << ",";
-
- if(pReturn != NULL) {
- // std::cout << "Using existing node" << std::endl;
- pReturn->count++;
- if (!bUpdateExclusion) {
- //update vine contexts too. Must exist if higher-order context does!
- for (CPPMPYnode *v = pReturn->vine; v; v=v->vine) {
- DASHER_ASSERT(v==m_pRoot || v->sym == sym);
- v->count++;
- }
- }
- } else {
- //symbol does not exist at this level
- pReturn = m_NodeAlloc.Alloc(); // count is initialized to 1 but no symbol or vine ptr
- ++NodesAllocated;
- pReturn->sym = sym;
- const int childIdx( min(DIVISION-1, sym/UNITALPH) );
- pReturn->next = pNode->child[childIdx];
- pNode->child[childIdx] = pReturn;
- pReturn->vine = (pNode == m_pRoot) ? m_pRoot : AddSymbolToNode(pNode->vine, sym);
- }
-
- return pReturn;
+CPPMPYLanguageModel::CPPMPYnode *CPPMPYLanguageModel::makeNode(int sym) {
+ CPPMPYnode *res = m_NodeAlloc.Alloc();
+ res->sym=sym;
+ ++NodesAllocated;
+ return res;
}
//Mandarin - PY not enabled for these read-write functions
diff --git a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
index 312034f..707ef23 100644
--- a/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
+++ b/Src/DasherCore/LanguageModelling/PPMPYLanguageModel.h
@@ -16,21 +16,9 @@
#include "../../Common/NoClones.h"
#include "../../Common/Allocators/PooledAlloc.h"
-#include "LanguageModel.h"
-#include "PPMPYLanguageModel.h"
-#include "../SCENode.h"
+#include "PPMLanguageModel.h"
#include <vector>
-#include <fstream>
-#include <set>
-//Define alphabet sizes
-#define ALPHSIZE 7610
-#define PYALPHSIZE 1300
-//Implement a multi-branch tree, instead of a binary tree to gain speed: a trade-off between speed and memory; the choice of branch is implied by ranking of the symbol being searched/added
-#define DIVISION 5
-#define UNITALPH (ALPHSIZE/DIVISION)
-#define UNITPY (PYALPHSIZE/DIVISION)
-
namespace Dasher {
@@ -45,56 +33,21 @@ namespace Dasher {
/// is _not_ entered into the context; new method GetPartProbs is used to compute probabilities
/// for the next chinese symbol (which should be entered into context), by filtering to a set.
///
- class CPPMPYLanguageModel:public CLanguageModel, private NoClones {
- private:
- class CPPMPYnode {
- public:
- CPPMPYnode * find_symbol(int sym)const;
- //Each PPM node store DIVISION number of addresses for children, so that each node branches out DIVISION times (as compared to binary); this is aimed to give better run-time speed
- CPPMPYnode * child[DIVISION];
- CPPMPYnode *next;
- CPPMPYnode *vine;
- /// map from pinyin-symbol to count: the number of times each pinyin symbol has been seen in this context
- std::map<symbol,unsigned short int> pychild;
- unsigned short int count;
- symbol sym;
- CPPMPYnode(int sym);
- CPPMPYnode();
- };
-
- class CPPMPYContext {
- public:
- CPPMPYContext(CPPMPYContext const &input) {
- head = input.head;
- order = input.order;
- } CPPMPYContext(CPPMPYnode * _head = 0, int _order = 0):head(_head), order(_order) {
- };
- ~CPPMPYContext() {
- };
- void dump();
- CPPMPYnode *head;
- int order;
- };
-
+ /// That is: from the superclass (CAbstractPPM) perspective, the alphabet is the chinese one;
+ /// hence, contexts store chinese symbols only, and EnterSymbol+LearnSymbol should be called
+ /// with _chinese_ symbol numbers. All PY-alph details are handled in this subclass, with extra
+ /// LearnPYSymbol method for updating the LM's pinyin predictions.
+ class CPPMPYLanguageModel : public CAbstractPPM {
public:
- ///Construct a new PPMPYLanguageModel.
- /// \param pAlph alphabet containing the actual symbols we want to write (i.e. Chinese)
+ ///Construct a new PPMPYLanguageModel.
+ /// \param pAlph alphabet containing the actual symbols we want to write (i.e. Chinese); this
+ /// is the only alphabet passed to the CAbstractPPM superclass.
/// \param pPyAlph alphabet of pinyin phonemes; we will predict probabilities for these
/// based (only) on the preceding _Chinese_ symbols.
CPPMPYLanguageModel(Dasher::CEventHandler * pEventHandler, CSettingsStore * pSettingsStore, const CAlphInfo *pAlph, const CAlphInfo *pPyAlph);
- virtual ~ CPPMPYLanguageModel();
-
- Context CreateEmptyContext();
- void ReleaseContext(Context context);
- Context CloneContext(Context context);
-
- ///Advance the context by entering a chinese symbol
- virtual void EnterSymbol(Context context, int Symbol);
- ///Train the LM with the specified Chinese symbol in that context (moves context on)
- virtual void LearnSymbol(Context context, int Symbol);
///Learns a pinyin symbol in the specified context, but does not move the context on.
- virtual void LearnPYSymbol(Context context, int Symbol);
+ void LearnPYSymbol(Context context, int Symbol);
///Predicts probabilities for the next Pinyin symbol (blending as per PPM,
/// but using the pychild map rather than child CPPMPYnodes).
@@ -110,99 +63,28 @@ namespace Dasher {
/// indicates a possible chinese symbol; on exit, the second element will have been filled in.
void GetPartProbs(Context context, std::vector<std::pair<symbol, unsigned int> > &vChildren, int norm, int iUniform);
- void dump();
-
virtual bool WriteToFile(std::string strFilename);
virtual bool ReadFromFile(std::string strFilename);
- bool RecursiveWrite(CPPMPYnode *pNode, std::map<CPPMPYnode *, int> *pmapIdx, int *pNextIdx, std::ofstream *pOutputFile);
- int GetIndex(CPPMPYnode *pAddr, std::map<CPPMPYnode *, int> *pmapIdx, int *pNextIdx);
- CPPMPYnode *GetAddress(int iIndex, std::map<int, CPPMPYnode*> *pMap);
-
- CPPMPYnode *AddSymbolToNode(CPPMPYnode * pNode, int sym);
-
- void dumpSymbol(int sym);
- void dumpString(char *str, int pos, int len);
- void dumpTrie(CPPMPYnode * t, int d);
-
- CPPMPYContext *m_pRootContext;
- CPPMPYnode *m_pRoot;
-
- int m_iMaxOrder;
- double m_dBackOffConstat;
-
- int NodesAllocated; //inclusive of both Character and PY nodes
-
- bool bUpdateExclusion;
-
+ protected:
+ class CPPMPYnode : public CPPMnode {
+ public:
+ /// map from pinyin-symbol to count: the number of times each pinyin symbol has been seen in this context
+ std::map<symbol,unsigned short int> pychild;
+ inline CPPMPYnode(int sym) : CPPMnode(sym) {}
+ inline CPPMPYnode() : CPPMnode() {}
+ };
+ CPPMPYnode *makeNode(int sym);
- mutable CSimplePooledAlloc < CPPMPYnode > m_NodeAlloc;
- CPooledAlloc < CPPMPYContext > m_ContextAlloc;
-
- std::set<const CPPMPYContext *> m_setContexts;
-
private:
+ int NodesAllocated;
+ mutable CSimplePooledAlloc < CPPMPYnode > m_NodeAlloc;
const CAlphInfo *m_pPyAlphabet;
int m_iAlphSize;
-
};
- /// @}
-
- inline Dasher::CPPMPYLanguageModel::CPPMPYnode::CPPMPYnode(int _sym):sym(_sym) {
- // child.clear();
- // pychild.clear();
-
- next = vine = 0;
- count = 1;
-
- //Added: Mandarin; Setting initial values
- for(int i =0; i <DIVISION; i++){
- child[i] = NULL;
- }
- }
-
- inline CPPMPYLanguageModel::CPPMPYnode::CPPMPYnode() {
- // child.clear();
- // pychild.clear();
-
- next = vine = 0;
-
- count = 1;
-
- //Added: Mandarin; Setting initial values
- for(int i =0; i <DIVISION; i++){
- child[i] = NULL;
- }
-
- }
-
- inline CLanguageModel::Context CPPMPYLanguageModel::CreateEmptyContext() {
- CPPMPYContext *pCont = m_ContextAlloc.Alloc();
- *pCont = *m_pRootContext;
-
- // m_setContexts.insert(pCont);
-
- return (Context) pCont;
- }
-
- inline CLanguageModel::Context CPPMPYLanguageModel::CloneContext(Context Copy) {
- CPPMPYContext *pCont = m_ContextAlloc.Alloc();
- CPPMPYContext *pCopy = (CPPMPYContext *) Copy;
- *pCont = *pCopy;
-
- // m_setContexts.insert(pCont);
-
- return (Context) pCont;
- }
-
- inline void CPPMPYLanguageModel::ReleaseContext(Context release) {
-
- // m_setContexts.erase(m_setContexts.find((CPPMPYContext *) release));
-
- m_ContextAlloc.Free((CPPMPYContext *) release);
- }
+ /// @}
} // end namespace Dasher
#endif // __LanguageModelling__PPMPYLanguageModel_h__
[
Date Prev][
Date Next] [
Thread Prev][
Thread Next]
[
Thread Index]
[
Date Index]
[
Author Index]