[chronojump/chronojump-importer] Fixes --information command line argument and minor style fixes.



commit e3a76b9f8965debfc3f742d48b0fb5763738c488
Author: Carles Pina i Estany <carles pina cat>
Date:   Fri Sep 9 16:33:47 2016 +0100

    Fixes --information command line argument and minor style fixes.

 src/chronojump-importer/chronojump_importer.py     |   56 ++++++++++++--------
 .../chronojump_importer_test.py                    |    9 ++-
 2 files changed, 40 insertions(+), 25 deletions(-)
---
diff --git a/src/chronojump-importer/chronojump_importer.py b/src/chronojump-importer/chronojump_importer.py
index 94d7d98..67bd319 100755
--- a/src/chronojump-importer/chronojump_importer.py
+++ b/src/chronojump-importer/chronojump_importer.py
@@ -63,7 +63,7 @@ class Table:
         self._table_data.append(row)
 
     def add_table(self, table):
-        self._table_data += table._table_data
+        self._table_data += table
 
     def remove_duplicates(self):
         """ Returns a new list without duplicate elements. """
@@ -97,12 +97,21 @@ class Table:
         value and assigned new_new_referenced_column value."""
         for row_to_update in self._table_data:
             old_id = row_to_update.get(column_to_update)
-            for row_referenced in referenced_table._table_data:
+            for row_referenced in referenced_table:
                 old_column_name = old_referenced_column
 
                 if row_referenced.has_column(old_column_name) and row_referenced.get(old_referenced_column) == old_id:
                     row_to_update.set(column_to_update, row_referenced.get(new_referenced_column))
 
+    def __iter__(self):
+        return iter(self._table_data)
+
+    def __len__(self):
+        return len(self._table_data)
+
+    def __getitem__(self, index):
+        return self._table_data[index]
+
 
 class Database:
     """ A database represents the database and read/writes tables. """
@@ -123,9 +132,10 @@ class Database:
             self._conn.close()
             self._is_opened = False
 
-    def column_names(self, table, skip_columns=[]):
+    def column_names(self, table, skip_columns=None):
         """ Returns the column names of table. Doesn't return any columns
         indicated by skip_columns. """
+
         self._cursor.execute("PRAGMA table_info({})".format(table))
         result = self._cursor.fetchall()
 
@@ -133,7 +143,7 @@ class Database:
 
         for row in result:
             column_name = row[1]
-            if column_name not in skip_columns:
+            if skip_columns is None or column_name not in skip_columns:
                 names.append(column_name)
 
         assert len(names) > 0
@@ -167,7 +177,7 @@ class Database:
         regardless of any column.
         """
 
-        for row in table._table_data:
+        for row in table:
             if type(matches_columns) == list:
                 where = ""
                 if len(matches_columns) == 0:
@@ -180,11 +190,12 @@ class Database:
                         where += "{} = ?".format(column)
                         where_values.append(row.get(column))
 
-                format_data = {}
-                format_data['table_name'] = table.name
-                format_data['where_clause'] = " WHERE {}".format(where)
+                format_data = {'table_name': table.name,
+                               'where_clause': " WHERE {}".format(where)
+                              }
+
                 sql = "SELECT uniqueID FROM {table_name} {where_clause}".format(**format_data)
-                self.execute_query_and_log(sql, where_values)
+                self._execute_query_and_log(sql, where_values)
 
                 results = self._cursor.fetchall()
 
@@ -223,7 +234,7 @@ class Database:
         format_data = {"column_names": ",".join(column_names_with_prefixes), "table_name": table_name, "join_clause": join_clause, "where": where_condition, "group_by": group_by}
 
         sql = "SELECT {column_names} FROM {table_name} {join_clause} {where} {group_by}".format(**format_data)
-        self.execute_query_and_log(sql, [])
+        self._execute_query_and_log(sql, [])
 
         results = self._cursor.fetchall()
 
@@ -258,7 +269,7 @@ class Database:
         sql = "INSERT INTO {table_name} ({column_names}) VALUES ({place_holders})".format(table_name=table_name,
                                                                                         column_names=",".join(column_names),
                                                                                         place_holders=",".join(place_holders))
-        self.execute_query_and_log(sql, values)
+        self._execute_query_and_log(sql, values)
 
         new_id = self._cursor.lastrowid
 
@@ -274,9 +285,8 @@ class Database:
 
         while True:
             sql = "SELECT count(*) FROM {table_name} WHERE {column}=?".format(table_name=table_name, column=column_name)
-            binding_values = []
-            binding_values.append(data_row.get(column_name))
-            self.execute_query_and_log(sql, binding_values)
+            binding_values = [data_row.get(column_name)]
+            self._execute_query_and_log(sql, binding_values)
 
             results = self._cursor.fetchall()
 
@@ -301,7 +311,7 @@ class Database:
         self._conn.execute("pragma foreign_keys=ON")
         self._cursor = self._conn.cursor()
 
-    def execute_query_and_log(self, sql, where_values):
+    def _execute_query_and_log(self, sql, where_values):
         logging.debug("SQL: {} - values: {}".format(sql, where_values))
         self._cursor.execute(sql, where_values)
 
@@ -361,7 +371,7 @@ def import_database(source_path, destination_path, source_session):
 
     destination_db.write(table=session, matches_columns=destination_db.column_names("Session", ["uniqueID"]))
 
-    new_session_id = session._table_data[0].get('new_uniqueID')
+    new_session_id = session[0].get('new_uniqueID')
 
     # Imports JumpType table
     jump_types = source_db.read(table_name="JumpType",
@@ -399,7 +409,6 @@ def import_database(source_path, destination_path, source_session):
     persons77.add_table(persons77_jump)
     persons77.add_table(persons77_jump_rj)
     persons77.remove_duplicates()
-    # persons77._table_data = remove_duplicates_list(persons77_jump._table_data + persons77_jump_rj._table_data)
 
     destination_db.write(table=persons77,
                          matches_columns=["name"])
@@ -432,16 +441,19 @@ def import_database(source_path, destination_path, source_session):
     destination_db.write(table=person_session_77, matches_columns=None)
 
 
-
 def show_information(database_path):
-    database = open_database(database_path, read_only=True)
-    cursor = database.cursor()
+    database = Database(database_path, read_only=True)
 
-    sessions = get_data_from_table(cursor=cursor, table_name="Session", where_condition="1=1")
+    sessions = database.read(table_name="Session", where_condition="1=1")
 
     print("sessionID, date, place, comments")
     for session in sessions:
-        print("{uniqueID}, {date}, {place}, {comments}".format(**session))
+        data = {'uniqueID': session.get('uniqueID'),
+                'date': session.get('date'),
+                'place': session.get('place'),
+                'comments': session.get('comments')
+                }
+        print("{uniqueID}, {date}, {place}, {comments}".format(**data))
 
 
 def process_command_line():
diff --git a/src/chronojump-importer/chronojump_importer_test.py b/src/chronojump-importer/chronojump_importer_test.py
index ed5847c..4f78572 100755
--- a/src/chronojump-importer/chronojump_importer_test.py
+++ b/src/chronojump-importer/chronojump_importer_test.py
@@ -104,13 +104,16 @@ class TestImporter(unittest.TestCase):
         table.insert_row(row2)
         table.insert_row(row3)
 
-        self.assertEqual(len(table._table_data), 3)
+        self.assertEqual(len(table), 3)
         table.remove_duplicates()
 
-        self.assertEqual(len(table._table_data), 2)
+        self.assertEqual(len(table), 2)
 
-        # TODO: verify that the contents is right
+        expected = [row1, row3]
+        for row in table:
+            expected.remove(row)
 
+        self.assertEqual(len(expected), 0)
 
     def test_update_ids_from_table(self):
         table_to_update = chronojump_importer.Table("table_to_update")


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]