[chronojump/chronojump-importer] Refactors chronojump_importer and creates the appropriate classes for the operations.



commit dc4c99b7b52b608e3b85a0f2f80ec8f56b2f8e7c
Author: Carles Pina i Estany <carles pina cat>
Date:   Fri Sep 9 13:01:29 2016 +0100

    Refactors chronojump_importer and creates the appropriate classes for the operations.
    
    Now there is a Table class, Row class and Database class.
    
    Some more work is needed to avoid accessing protected data from other
    classes but at the moment all the unit tests pass and the code is
    easier to understand.

 src/chronojump-importer/chronojump_importer.py     |  484 +++++++++++---------
 .../chronojump_importer_test.py                    |   79 +++-
 2 files changed, 319 insertions(+), 244 deletions(-)
---
diff --git a/src/chronojump-importer/chronojump_importer.py b/src/chronojump-importer/chronojump_importer.py
index 5904e42..794e636 100755
--- a/src/chronojump-importer/chronojump_importer.py
+++ b/src/chronojump-importer/chronojump_importer.py
@@ -32,180 +32,270 @@ logging.basicConfig(level=logging.INFO)
 """
 
 
-def get_column_names(cursor, table, skip_columns = []):
-    """ Returns the column names of table. Doesn't return any columns
-    indicated by skip_columns. """
-    cursor.execute("PRAGMA table_info({})".format(table))
-    result = cursor.fetchall()
+class Row:
+    """ A row represents a row in a table: it has column-names and their values."""
+    def __init__(self):
+        self._row = {}
+
+    def add(self, column_name, value):
+        self._row[column_name] = value
 
-    names = []
 
-    for row in result:
-        column_name = row[1]
-        if column_name not in skip_columns:
-            names.append(column_name)
+class Table:
+    """ This class has Table operations. Rows should be inserted and then can
+    be manipulated."""
+    def __init__(self, table_name):
+        self._table_data = []
+        self._table_name = table_name
+
+    def insert_row(self, row):
+        self._table_data.append(row)
+
+    def set_table_name(self, table_name):
+        self._table_data = table_name
+
+    def get_table_name(self):
+        return self._table_name
+
+    def update_session_ids(self, new_session_id):
+        """ table argument is a list of dictionaries. It returns a copy of it
+         replacing each sessionID by new_session_id.
+         """
+        changed = False
+
+        for row in self._table_data:
+            row._row["sessionID"] = new_session_id
+            changed = True
+
+        if len(self._table_data) > 0:
+            assert changed
+
+    def update_ids_from_table(self, column_to_update, referenced_table, old_referenced_column, 
new_referenced_column):
+        """From table_to_update: updates column_to_update if there is referenced_table old_referenced_column 
with the same
+        value and assigned new_new_referenced_column value."""
+        for row_to_update in self._table_data:
+            old_id = row_to_update._row[column_to_update]
+            for row_referenced in referenced_table._table_data:
+                old_column_name = old_referenced_column
+
+                if old_column_name in row_referenced._row and row_referenced._row[old_referenced_column] == 
old_id:
+                    row_to_update._row[column_to_update] = row_referenced._row[new_referenced_column]
+
+
+class Database:
+    """ A database represents the database and read/writes tables. """
+    def __init__(self, source_path, read_only):
+        self._is_opened = False
+        self._cursor = None
+        self._conn = None
+
+        self.open_database(source_path, read_only)
+        self._is_opened = True
+
+    def __del__(self):
+        self.close()
+
+    def close(self):
+        if self._is_opened:
+            self._conn.commit()
+            self._conn.close()
+            self._is_opened = False
+
+    def column_names(self, table, skip_columns = []):
+        """ Returns the column names of table. Doesn't return any columns
+        indicated by skip_columns. """
+        self._cursor.execute("PRAGMA table_info({})".format(table))
+        result = self._cursor.fetchall()
+
+        names = []
+
+        for row in result:
+            column_name = row[1]
+            if column_name not in skip_columns:
+                names.append(column_name)
+
+        assert len(names) > 0
+        return names
+
+    @staticmethod
+    def add_prefix(list_of_elements, prefix):
+        """  Returns a copy of list_of_elements prefixing each element with prefix. """
+        result = []
+
+        for element in list_of_elements:
+            result.append("{}{}".format(prefix, element))
+
+        return result
+
+    def insert_table(self, table, matches_columns, avoids_duplicate_column=None):
+        """ Data is a list of dictionaries and the keys should match the columns
+        of table_name.
+
+        Inserts the data and returns a copy of data with a new key per each
+        dictionary: new_unique_id. This is the new uniqueID for this row if it
+        didn't exist or the old one. The matching is based on matches_columns.
+
+        For example, if matches_columns = ["Name"] it will insert a new row
+        in the table if the name didn't exist and will add new_unique_id
+        with this unique id.
+        If name already existed it will NOT insert anything in the table
+        but will add a new_unique_id with the ID of this person.
+
+        If matches_columns is None it means that will insert the data
+        regardless of any column.
+        """
+
+        for row in table._table_data:
+            if type(matches_columns) == list:
+                where = ""
+                if len(matches_columns) == 0:
+                    where = "1=1"
+                else:
+                    where_values = []
+                    for column in matches_columns:
+                        if where != "":
+                            where += " AND "
+                        where += "{} = ?".format(column)
+                        where_values.append(row._row[column])
+
+                format_data = {}
+                format_data['table_name'] = table.get_table_name()
+                format_data['where_clause'] = " WHERE {}".format(where)
+                sql = "SELECT uniqueID FROM {table_name} {where_clause}".format(**format_data)
+                execute_query_and_log(self._cursor, sql, where_values)
+
+                results = self._cursor.fetchall()
+
+            if matches_columns is None or len(results) == 0:
+                # Needs to insert it
+
+                self.avoids_column_duplicate(table_name=table.get_table_name(), 
column_name=avoids_duplicate_column, data_row=row)
+
+                new_id = self.insert_dictionary_into_table(table.get_table_name(), row)
+                row.add('importer_action', 'inserted')
 
-    assert len(names) > 0
-    return names
-
-
-def remove_elements(list_of_elements, elements_to_remove):
-    """Returns a new list with list_of_elements without elements_to_remove"""
-    result = []
-
-    for element in list_of_elements:
-        if element not in elements_to_remove:
-            result.append(element)
-
-    return result
-
-
-def add_prefix(list_of_elements, prefix):
-    """  Returns a copy of list_of_elements prefixing each element with prefix. """
-    result = []
-
-    for element in list_of_elements:
-        result.append("{}{}".format(prefix, element))
-
-    return result
-
-
-def insert_data_into_table(cursor, table_name, data, matches_columns, avoids_duplicate_column=None):
-    """ Data is a list of dictionaries and the keys should match the columns
-    of table_name.
+            else:
+                # Uses the existing id as new_unique_id
+                new_id = results[0][0]
+                row.add('importer_action', 'reused')
 
-    Inserts the data and returns a copy of data with a new key per each
-    dictionary: new_unique_id. This is the new uniqueID for this row if it
-    didn't exist or the old one. The matching is based on matches_columns.
+            row.add('new_uniqueID', new_id)
 
-    For example, if matches_columns = ["Name"] it will insert a new row
-    in the table if the name didn't exist and will add new_unique_id
-    with this unique id.
-    If name already existed it will NOT insert anything in the table
-    but will add a new_unique_id with the ID of this person.
+        # TODO print_summary(table, data_result)
 
-    If matches_columns is None it means that will insert the data
-    regardless of any column.
-    """
+    def get_data_from_table(self, table_name, where_condition, join_clause ="", group_by_clause=""):
+        """ Returns a list of dictionaries of the table table_name applying the where_condition, join_clause 
and group_by_clause. """
+        column_names = self.column_names(table_name)
 
-    data_result = copy.deepcopy(data)
+        column_names_with_prefixes = self.add_prefix(column_names, "{}.".format(table_name))
 
-    for row in data_result:
-        if type(matches_columns) == list:
-            where = ""
-            if len(matches_columns) == 0:
-                where = "1=1"
-            else:
-                where_values = []
-                for column in matches_columns:
-                    if where != "":
-                        where += " AND "
-                    where += "{} = ?".format(column)
-                    where_values.append(row[column])
+        where_condition = " WHERE {} ".format(where_condition)
+        assert '"' not in where_condition   # Easy way to avoid problems - where_condition is only used by 
us (programmers) and
+                                            # it doesn't depend on user data.
 
-            format_data = {}
-            format_data['table_name'] = table_name
-            format_data['where_clause'] = " WHERE {}".format(where)
-            sql = "SELECT uniqueID FROM {table_name} {where_clause}".format(**format_data)
-            execute_query_and_log(cursor, sql, where_values)
+        if group_by_clause != "":
+            group_by = " GROUP BY {}".format(group_by_clause)
+        else:
+            group_by = ""
 
-            results = cursor.fetchall()
+        format_data = {"column_names": ",".join(column_names_with_prefixes), "table_name": table_name, 
"join_clause": join_clause, "where": where_condition, "group_by": group_by}
 
-        if matches_columns is None or len(results) == 0:
-            # Needs to insert it
+        sql = "SELECT {column_names} FROM {table_name} {join_clause} {where} 
{group_by}".format(**format_data)
+        execute_query_and_log(self._cursor, sql, [])
 
-            avoids_column_duplicate(cursor=cursor, table_name=table_name, 
column_name=avoids_duplicate_column, data_row=row)
+        results = self._cursor.fetchall()
 
-            new_id = insert_dictionary_into_table(cursor, table_name, row)
-            row['importer_action'] = "inserted"
+        table = Table(table_name)
 
-        else:
-            # Uses the existing id as new_unique_id
-            new_id = results[0][0]
-            row['importer_action'] = "reused"
+        for row in results:
+            table_row = Row()
+            for i, col in enumerate(row):
+                table_row.add(column_names[i], col)
 
-        row['new_uniqueID'] = new_id
+            table.insert_row(table_row)
 
-    print_summary(table_name, data_result)
-    return data_result
+        return table
 
+    def insert_dictionary_into_table(self, table_name, row, skip_columns=["uniqueID"]):
+        """ Inserts the row (it's a dictionary) into table_name and skips skip_column.
+        Returns the new Id of the inserted row.
+        """
+        values = []
+        column_names = []
+        place_holders = []
+        table_column_names = self.column_names(table_name)
 
-def get_data_from_table(cursor, table_name, where_condition, join_clause ="", group_by_clause=""):
-    """ Returns a list of dictionaries of the table table_name applying the where_condition, join_clause and 
group_by_clause. """
-    column_names = get_column_names(cursor, table_name)
+        for column_name in row._row.keys():
+            if column_name in skip_columns or column_name not in table_column_names:
+                continue
 
-    column_names_with_prefixes = add_prefix(column_names, "{}.".format(table_name))
+            values.append(row._row[column_name])
+            column_names.append(column_name)
+            place_holders.append("?")
 
-    where_condition = " WHERE {} ".format(where_condition)
-    assert '"' not in where_condition   # Easy way to avoid problems - where_condition is only used by us 
(programmers) and
-                                        # it doesn't depend on user data.
+        sql = "INSERT INTO {table_name} ({column_names}) VALUES 
({place_holders})".format(table_name=table_name,
+                                                                                        
column_names=",".join(column_names),
+                                                                                        
place_holders=",".join(place_holders))
+        execute_query_and_log(self._cursor, sql, values)
 
-    if group_by_clause != "":
-        group_by = " GROUP BY {}".format(group_by_clause)
-    else:
-        group_by = ""
+        new_id = self._cursor.lastrowid
 
-    format_data = {"column_names": ",".join(column_names_with_prefixes), "table_name": table_name, 
"join_clause": join_clause, "where": where_condition, "group_by": group_by}
+        return new_id
 
-    sql = "SELECT {column_names} FROM {table_name} {join_clause} {where} {group_by}".format(**format_data)
-    execute_query_and_log(cursor, sql, [])
+    def avoids_column_duplicate(self, table_name, column_name, data_row):
+        """ Makes sure that data_row[column_name] doesn't exist in table_name. If it exists
+        it changes data_row[column_name] to the same with (1) or (2)"""
+        if column_name is None:
+            return
 
-    results = cursor.fetchall()
+        data_row._row['old_' + column_name] = data_row._row[column_name]
 
-    data = []
-    for row in results:
-        r = {}
-        for i, col in enumerate(row):
-            r[column_names[i]] = col
+        while True:
+            sql = "SELECT count(*) FROM {table_name} WHERE {column}=?".format(table_name=table_name, 
column=column_name)
+            binding_values = []
+            binding_values.append(data_row._row[column_name])
+            execute_query_and_log(self._cursor, sql, binding_values)
 
-        data.append(r)
+            results = self._cursor.fetchall()
 
-    return data
+            if results[0][0] == 0:
+                break
+            else:
+                data_row._row[column_name] = increment_suffix(data_row._row[column_name])
+                data_row._row['new_' + column_name] = data_row._row[column_name]
+
+    def open_database(self, filename, read_only):
+        """Opens the database specified by filename. If read_only is True
+        the database cannot be changed.
+        """
+        if read_only:
+            mode = "ro"
+        else:
+            mode = "rw"
 
+        uri = "file:{}?mode={}".format(filename,mode)
+        self._conn = sqlite3.connect(uri, uri=True)
 
-def insert_dictionary_into_table(cursor, table_name, row, skip_columns=["uniqueID"]):
-    """ Inserts the row (it's a dictionary) into table_name and skips skip_column.
-    Returns the new Id of the inserted row.
-    """
-    values = []
-    column_names = []
-    place_holders = []
-    table_column_names = get_column_names(cursor, table_name)
+        self._conn.execute("pragma foreign_keys=ON")
+        self._cursor = self._conn.cursor()
 
-    for column_name in row.keys():
-        if column_name in skip_columns or column_name not in table_column_names:
-            continue
+def remove_elements(list_of_elements, elements_to_remove):
+    """Returns a new list with list_of_elements without elements_to_remove"""
+    result = []
 
-        values.append(row[column_name])
-        column_names.append(column_name)
-        place_holders.append("?")
+    for element in list_of_elements:
+        if element not in elements_to_remove:
+            result.append(element)
 
-    sql = "INSERT INTO {table_name} ({column_names}) VALUES ({place_holders})".format(table_name=table_name,
-                                                                                    
column_names=",".join(column_names),
-                                                                                    
place_holders=",".join(place_holders))
-    execute_query_and_log(cursor, sql, values)
+    return result
 
-    new_id = cursor.lastrowid
 
-    return new_id
 
 
-def update_session_ids(table, new_session_id):
-    """ table argument is a list of dictionaries. It returns a copy of it
-     replacing each sessionID by new_session_id.
-     """
-    result = copy.deepcopy(table)
 
-    changed = False
 
-    for row in result:
-        row["sessionID"] = new_session_id
-        changed = True
 
-    if len(table) > 0:
-        assert changed
 
-    return result
 
 
 def print_summary(table_name, table_data):
@@ -248,43 +338,6 @@ def increment_suffix(value):
         return "{} ({})".format(base_name, counter)
 
 
-def avoids_column_duplicate(cursor, table_name, column_name, data_row):
-    """ Makes sure that data_row[column_name] doesn't exist in table_name. If it exists
-    it changes data_row[column_name] to the same with (1) or (2)"""
-    if column_name is None:
-        return
-
-    data_row['old_' + column_name] = data_row[column_name]
-
-    while True:
-        sql = "SELECT count(*) FROM {table_name} WHERE {column}=?".format(table_name=table_name, 
column=column_name)
-        binding_values = []
-        binding_values.append(data_row[column_name])
-        execute_query_and_log(cursor, sql, binding_values)
-
-        results = cursor.fetchall()
-
-        if results[0][0] == 0:
-            break
-        else:
-            data_row[column_name] = increment_suffix(data_row[column_name])
-            data_row['new_' + column_name] = data_row[column_name]
-
-
-def update_ids_from_table(table_to_update, column_to_update, referenced_table, old_referenced_column, 
new_referenced_column):
-    """From table_to_update: updates column_to_update if there is referenced_table old_referenced_column 
with the same
-    value and assigned new_new_referenced_column value."""
-    result = copy.deepcopy(table_to_update)
-
-    for row_to_update in result:
-        old_id = row_to_update[column_to_update]
-        for row_referenced in referenced_table:
-            old_column_name = old_referenced_column
-
-            if old_column_name in row_referenced and row_referenced[old_referenced_column] == old_id:
-                row_to_update[column_to_update] = row_referenced[new_referenced_column]
-
-    return result
 
 
 def import_database(source_path, destination_path, source_session):
@@ -293,17 +346,14 @@ def import_database(source_path, destination_path, source_session):
     logging.debug("source path:" + source_path)
     logging.debug("destination path:" + destination_path)
 
-    source_db = open_database(source_path, read_only=True)
-    destination_db = open_database(destination_path, read_only=False)
-
-    source_cursor = source_db.cursor()
-    destination_cursor = destination_db.cursor()
+    source_db = Database(source_path, read_only=True)
+    destination_db = Database(destination_path, read_only=False)
 
     # Imports the session
-    session = get_data_from_table(cursor=source_cursor, table_name="Session",
-                                  where_condition="Session.uniqueID={}".format(source_session))
+    session = source_db.get_data_from_table(table_name="Session",
+                                            where_condition="Session.uniqueID={}".format(source_session))
 
-    number_of_matching_sessions = len(session)
+    number_of_matching_sessions = len(session._table_data)
 
     if number_of_matching_sessions == 0:
         print("Trying to import {session} from {source_file} and it doesn't exist. 
Cancelling...".format(session=source_session,
@@ -314,96 +364,74 @@ def import_database(source_path, destination_path, source_session):
                                                                                                         
source_file=source_path))
         sys.exit(1)
 
-    session = insert_data_into_table(cursor=destination_cursor, table_name="Session", data=session,
-                                     matches_columns=get_column_names(destination_cursor, "Session", 
["uniqueID"]))
+    destination_db.insert_table(table=session, matches_columns=destination_db.column_names("Session", 
["uniqueID"]))
 
-    new_session_id = session[0]['new_uniqueID']
+    new_session_id = session._table_data[0]._row['new_uniqueID']
 
     # Imports JumpType table
-    jump_types = get_data_from_table(cursor=source_cursor, table_name="JumpType",
+    jump_types = source_db.get_data_from_table(table_name="JumpType",
                                      where_condition="Session.uniqueID={}".format(source_session),
                                      join_clause="LEFT JOIN Jump ON JumpType.name=Jump.type LEFT JOIN 
Session ON Jump.sessionID=Session.uniqueID",
                                      group_by_clause="JumpType.uniqueID")
 
-    jump_types = insert_data_into_table(cursor=destination_cursor, table_name="JumpType", data=jump_types,
-                           matches_columns=get_column_names(destination_cursor, "JumpType", ["uniqueID"]),
+    destination_db.insert_table(table=jump_types,
+                           matches_columns=destination_db.column_names("JumpType", ["uniqueID"]),
                            avoids_duplicate_column="name")
 
     # Imports JumpRjType table
-    jump_rj_types = get_data_from_table(cursor=source_cursor, table_name="JumpRjType",
+    jump_rj_types = source_db.get_data_from_table(table_name="JumpRjType",
                                         where_condition="Session.uniqueID={}".format(source_session),
                                         join_clause="LEFT JOIN JumpRj ON JumpRjType.name=JumpRj.type LEFT 
JOIN Session on JumpRj.sessionID=Session.uniqueID",
                                         group_by_clause="JumpRjType.uniqueID")
 
-    jump_rj_types = insert_data_into_table(cursor=destination_cursor, table_name="JumpRjType", 
data=jump_rj_types,
-                           matches_columns=get_column_names(destination_cursor, "JumpRjType", ["uniqueID"]),
+    jump_rj_types = destination_db.insert_table(table=jump_rj_types,
+                           matches_columns=destination_db.column_names("JumpRjType", ["uniqueID"]),
                            avoids_duplicate_column="name")
 
     # Imports Persons77 used by JumpRj table
-    persons77_jump_rj = get_data_from_table(cursor=source_cursor, table_name="Person77",
+    persons77_jump_rj = source_db.get_data_from_table(table_name="Person77",
                                             where_condition="JumpRj.sessionID={}".format(source_session),
                                             join_clause="LEFT JOIN JumpRj ON 
Person77.uniqueID=JumpRj.personID",
                                             group_by_clause="Person77.uniqueID")
 
     # Imports Person77 used by Jump table
-    persons77_jump = get_data_from_table(cursor=source_cursor, table_name="Person77",
+    persons77_jump = source_db.get_data_from_table(table_name="Person77",
                                          where_condition="Jump.sessionID={}".format(source_session),
                                          join_clause="LEFT JOIN Jump ON Person77.uniqueID=Jump.personID",
                                          group_by_clause="Person77.uniqueID")
 
-    persons77 = remove_duplicates_list(persons77_jump + persons77_jump_rj)
+    persons77 = Table("person77")
+    persons77._table_data = remove_duplicates_list(persons77._table_data + persons77_jump_rj._table_data)
 
-    persons77 = insert_data_into_table(cursor=destination_cursor, table_name="Person77", data=persons77,
+    destination_db.insert_table(table=persons77,
                                             matches_columns=["name"])
 
     # Imports JumpRj table (with the new Person77's uniqueIDs)
-    jump_rj = get_data_from_table(cursor=source_cursor, table_name="JumpRj",
+    jump_rj = source_db.get_data_from_table(table_name="JumpRj",
                                   where_condition="JumpRj.sessionID={}".format(source_session))
 
-    jump_rj = update_ids_from_table(jump_rj, "personID", persons77, "uniqueID", "new_uniqueID")
-    jump_rj = update_session_ids(jump_rj, new_session_id)
-    jump_rj = update_ids_from_table(jump_rj, "type", persons77, "old_name", "new_name")
+    jump_rj.update_ids_from_table("personID", persons77, "uniqueID", "new_uniqueID")
+    jump_rj.update_session_ids(new_session_id)
+    jump_rj.update_ids_from_table("type", persons77, "old_name", "new_name")
 
-    insert_data_into_table(cursor=destination_cursor, table_name="JumpRj", data=jump_rj, 
matches_columns=None)
+    destination_db.insert_table(table=jump_rj, matches_columns=None)
 
     # Imports Jump table (with the new Person77's uniqueIDs)
-    jump = get_data_from_table(cursor=source_cursor, table_name="Jump",
+    jump = source_db.get_data_from_table(table_name="Jump",
                                where_condition="Jump.sessionID={}".format(source_session))
 
-    jump = update_ids_from_table(jump, "personID", persons77, "uniqueID", "new_uniqueID")
-    jump = update_session_ids(jump, new_session_id)
-    jump = update_ids_from_table(jump, "type", jump_types, "old_name", "new_name")
+    jump.update_ids_from_table("personID", persons77, "uniqueID", "new_uniqueID")
+    jump.update_session_ids(new_session_id)
+    jump.update_ids_from_table("type", jump_types, "old_name", "new_name")
 
-    insert_data_into_table(cursor=destination_cursor, table_name="Jump", data=jump, matches_columns=None)
+    destination_db.insert_table(table=jump, matches_columns=None)
 
     # Imports PersonSession77
-    person_session_77 = get_data_from_table(cursor=source_cursor, table_name="PersonSession77",
+    person_session_77 = source_db.get_data_from_table(table_name="PersonSession77",
                                             
where_condition="PersonSession77.sessionID={}".format(source_session))
-    person_session_77 = update_ids_from_table(person_session_77, "personID", persons77, "uniqueID", 
"new_uniqueID")
-    person_session_77 = update_session_ids(person_session_77, new_session_id)
-    insert_data_into_table(cursor=destination_cursor, table_name="PersonSession77", data=person_session_77, 
matches_columns=None)
-
-    destination_db.commit()
-    destination_db.close()
-
-    source_db.close()
-
-
-def open_database(filename, read_only):
-    """Opens the database specified by filename. If read_only is True
-    the database cannot be changed.
-    """
-    if read_only:
-        mode = "ro"
-    else:
-        mode = "rw"
-
-    uri = "file:{}?mode={}".format(filename,mode)
-    conn = sqlite3.connect(uri, uri=True)
-
-    conn.execute("pragma foreign_keys=ON")
-
-    return conn
+    person_session_77.update_ids_from_table("personID", persons77, "uniqueID", "new_uniqueID")
+    person_session_77.update_session_ids(new_session_id)
+    destination_db.insert_table(table=person_session_77, matches_columns=None)
 
 
 def execute_query_and_log(cursor, sql, where_values):
diff --git a/src/chronojump-importer/chronojump_importer_test.py 
b/src/chronojump-importer/chronojump_importer_test.py
index 34dc42c..2aaa47f 100755
--- a/src/chronojump-importer/chronojump_importer_test.py
+++ b/src/chronojump-importer/chronojump_importer_test.py
@@ -64,7 +64,7 @@ class TestImporter(unittest.TestCase):
         self.maxDiff = None
         self.assertEqual(diff, "")
 
-        # shutil.rmtree(self.temporary_directory_path)
+        shutil.rmtree(temporary_directory_path)
 
     def test_increment_suffix(self):
         self.assertEqual(chronojump_importer.increment_suffix("Free Jump"), "Free Jump (1)")
@@ -78,15 +78,25 @@ class TestImporter(unittest.TestCase):
 
     def test_add_prefix(self):
         l=['hello', 'chronojump']
-        actual = chronojump_importer.add_prefix(l, "test_")
+        actual = chronojump_importer.Database.add_prefix(l, "test_")
         self.assertEqual(actual, ["test_hello", "test_chronojump"])
 
     def test_update_session_ids(self):
-        table=[{'sessionID': 2, 'name': 'hello'}, {'sessionID':3, 'name':'bye'}]
+        table = chronojump_importer.Table("test")
+        row1 = chronojump_importer.Row()
+        row1.add("sessionID", 2)
+        row1.add("name", "john")
 
-        actual = chronojump_importer.update_session_ids(table, 4)
-        for row in actual:
-            self.assertEqual(row['sessionID'], 4)
+        row2 = chronojump_importer.Row()
+        row2.add("sessionID", 3)
+        row2.add("name", "mark")
+
+        table.insert_row(row1)
+        table.insert_row(row2)
+
+        table.update_session_ids(4)
+        for row in table._table_data:
+            self.assertEqual(row._row['sessionID'], 4)
 
     def test_remove_duplicates_list(self):
         l = [1,1,2,3,2]
@@ -96,29 +106,66 @@ class TestImporter(unittest.TestCase):
         self.assertEqual(sorted(actual), sorted([1,2,3]))
 
     def test_update_ids_from_table(self):
-        table_to_update = [{'name': 'john', 'personId': 1}, {'name': 'mark', 'personId': 4}, {'name': 
'alex', 'personId': 5}]
+        table_to_update = chronojump_importer.Table("table_to_update")
+        row1 = chronojump_importer.Row()
+        row1.add("name", "john")
+        row1.add("personId", 1)
+
+        row2 = chronojump_importer.Row()
+        row2.add("name", "mark")
+        row2.add("personId", 4)
+
+        row3 = chronojump_importer.Row()
+        row3.add("name", "alex")
+        row3.add("personId", 5)
+
+        table_to_update.insert_row(row1)
+        table_to_update.insert_row(row2)
+        table_to_update.insert_row(row3)
+
+
         column_to_update = 'personId'
-        referenced_table = [{'personId': 11, 'old_personId': 1}, {'personId': 12, 'old_personId': 4}]
+
+        referenced_table = chronojump_importer.Table("referenced_table")
+        row4 = chronojump_importer.Row()
+        row4.add("personId", 11)
+        row4.add("old_personId", 1)
+
+        row5 = chronojump_importer.Row()
+        row5.add("personId", 12)
+        row5.add("old_personId", 4)
+
+        referenced_table.insert_row(row4)
+        referenced_table.insert_row(row5)
+
         old_reference_column = 'old_personId'
         new_reference_column = 'personId'
 
-        actual = chronojump_importer.update_ids_from_table(table_to_update, column_to_update, 
referenced_table, old_reference_column, new_reference_column)
+        table_to_update.update_ids_from_table(column_to_update, referenced_table, old_reference_column, 
new_reference_column)
+
+        self.assertEqual(len(table_to_update._table_data), 3)
+
+        def verify_exists(table, name, personId):
+            for row in table._table_data:
+                if row._row['name'] == name and row._row['personId'] == personId:
+                    return True
+
+            return False
 
-        self.assertEqual(len(actual), 3)
-        self.assertTrue({'name': 'john', 'personId': 11} in actual)
-        self.assertTrue({'name': 'mark', 'personId': 12} in actual)
-        self.assertTrue({'name': 'alex', 'personId': 5} in actual)
+        self.assertTrue(verify_exists(table_to_update, "john", 11))
+        self.assertTrue(verify_exists(table_to_update, "mark", 12))
+        self.assertTrue(verify_exists(table_to_update, "alex", 5))
 
     def test_get_column_names(self):
         filename = tempfile.mktemp(prefix="chronojump_importer_test_get_column_", suffix=".sqlite")
         open(filename, 'a').close()
 
-        database = chronojump_importer.open_database(filename, read_only=False)
-        cursor = database.cursor()
+        database = chronojump_importer.Database(filename, read_only=False)
+        cursor = database._cursor
 
         cursor.execute("CREATE TABLE test (uniqueID INTEGER, name TEXT, surname1 TEXT, surname2 TEXT, age 
INTEGER)")
 
-        columns = chronojump_importer.get_column_names(cursor=cursor, table="test", 
skip_columns=["surname1", "surname2"])
+        columns = database.column_names(table="test", skip_columns=["surname1", "surname2"])
 
         self.assertEqual(columns, ["uniqueID", "name", "age"])
 


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]