[billreminder: 6/17] Added option to use fake db on tests and some fixes.



commit ae0824e3e09d05587c7a939fce373f7e98f44f0a
Author: Luiz Armesto <luiz armesto gmail com>
Date:   Mon Aug 16 00:44:11 2010 -0300

    Added option to use fake db on tests and some fixes.

 src/lib/dal.py |   40 ++++++++++++++++++++++++++++++++--------
 1 files changed, 32 insertions(+), 8 deletions(-)
---
diff --git a/src/lib/dal.py b/src/lib/dal.py
index 85a3d3d..be6aa50 100644
--- a/src/lib/dal.py
+++ b/src/lib/dal.py
@@ -2,6 +2,7 @@
 
 import os
 import sys
+import warnings
 
 try:
     from sqlalchemy import create_engine
@@ -20,7 +21,7 @@ from xdg.BaseDirectory import *
 
 class DAL(object):
 
-    def __init__(self):
+    def __init__(self, fake=False):
 
         new_setup = False
 
@@ -33,7 +34,12 @@ class DAL(object):
             # Safe to assume that this is a new setup.
             new_setup = True
 
-        self.engine = create_engine('sqlite:///%s' % os.path.join(data_dir, DB_NAME))
+        # Fake mode is used to run tests against a clean database created in memory.
+	if not fake:
+            self.engine = create_engine('sqlite:///%s' % os.path.join(data_dir, DB_NAME))
+        else:
+            self.engine = create_engine('sqlite:///:memory:', echo=False)
+
         self.Session = sessionmaker(bind=self.engine)
 
         # Creates all database tables
@@ -57,20 +63,33 @@ class DAL(object):
                     bill.alarmDate = dbobject.alarmDate
                     bill.notes = dbobject.notes
                     bill.paid = dbobject.paid
+                    
                     if dbobject.category:
                         try:
-                            category = session.query(Category).filter_by(name=dbobject.category.name).one()
+                            category = session.query(Category).filter(Category.name==dbobject.category.name).one()
                             bill.category = category
-                        except Exception, e:
-                            print "Failed to retrieve category \"%s\" for bill \"%s\": %s" \
-                                % (dbobject.payee, dbobject.category.name, str(e))
-
+                        except NoResultFound, e:
+                            warnings.warn("Failed to retrieve category \"%s\" for bill \"%s\". Creating category." \
+	                        % (dbobject.category.name, dbobject.payee), RuntimeWarning)
+                
                 if session.dirty:
                     session.commit()
 
+                dbobject_id = bill.id
+
             except NoResultFound, e:
+                if dbobject.category:
+                    try:
+                        category = session.query(Category).filter(Category.name==dbobject.category.name).one()
+                        del(dbobject.category)
+                        dbobject.category = category
+                    except NoResultFound, e:
+                        warnings.warn("Failed to retrieve category \"%s\" for bill \"%s\". Creating category." \
+	                    % (dbobject.category.name, dbobject.payee), RuntimeWarning)
+                
                 session.add(dbobject)
                 session.commit()
+                dbobject_id = dbobject.id
 
             except Exception, e:
                 session.rollback()
@@ -88,15 +107,20 @@ class DAL(object):
                 if session.dirty:
                     session.commit()
 
+                dbobject_id = category.id
+
             except NoResultFound, e:
                 session.add(dbobject)
-                session.commit()
+                session.commit()              
+                dbobject_id = dbobject.id
             except Exception, e:
                 session.rollback()
                 print str(e)
             finally:
                 session.close()
 
+        return dbobject_id
+
     def edit(self, dbobject):
 
         session = self.Session()



[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]