[hyena/sqlite] [Hyena.Data.Sqlite] Add type tests, remove QueryScalar



commit 49c3ac2064fe2ff46120a78cc217cd2a91610595
Author: Gabriel Burt <gabriel burt gmail com>
Date:   Wed Nov 10 23:36:02 2010 -0600

    [Hyena.Data.Sqlite] Add type tests, remove QueryScalar
    
    Instead of QueryScalar, use Query<T>

 .../Hyena.Data.Sqlite/HyenaSqliteCommand.cs        |    2 +-
 Hyena.Data.Sqlite/Hyena.Data.Sqlite/Sqlite.cs      |  118 ++++++++++++++----
 .../Hyena.Data.Sqlite/Tests/SqliteTests.cs         |  134 +++++++++++++++++---
 3 files changed, 211 insertions(+), 43 deletions(-)
---
diff --git a/Hyena.Data.Sqlite/Hyena.Data.Sqlite/HyenaSqliteCommand.cs b/Hyena.Data.Sqlite/Hyena.Data.Sqlite/HyenaSqliteCommand.cs
index c513dc7..1e746a3 100644
--- a/Hyena.Data.Sqlite/Hyena.Data.Sqlite/HyenaSqliteCommand.cs
+++ b/Hyena.Data.Sqlite/Hyena.Data.Sqlite/HyenaSqliteCommand.cs
@@ -104,7 +104,7 @@ namespace Hyena.Data.Sqlite
                         break;
 
                     case HyenaCommandType.Scalar:
-                        result = connection.QueryScalar (command_text);
+                        result = connection.Query<object> (command_text);
                         break;
 
                     case HyenaCommandType.Execute:
diff --git a/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Sqlite.cs b/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Sqlite.cs
index 09b36c0..89d4bf0 100644
--- a/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Sqlite.cs
+++ b/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Sqlite.cs
@@ -82,10 +82,10 @@ namespace Hyena.Data.Sqlite
             return new Statement (this, sql) { ReaderDisposes = true }.Query ();
         }
 
-        public object QueryScalar (string sql)
+        public T Query<T> (string sql)
         {
             using (var stmt = new Statement (this, sql)) {
-                return stmt.QueryScalar ();
+                return stmt.Query<T> ();
             }
         }
 
@@ -150,6 +150,8 @@ namespace Hyena.Data.Sqlite
         bool Read ();
         object this[int i] { get; }
         object this[string columnName] { get; }
+        T Get<T> (int i);
+        T Get<T> (string columnName);
         int FieldCount { get; }
     }
 
@@ -239,9 +241,14 @@ namespace Hyena.Data.Sqlite
             Dispose ();
         }
 
+        object [] null_val = new object [] { null };
         public Statement Bind (params object [] vals)
         {
             CheckDisposed ();
+
+            if (vals == null && ParameterCount == 1)
+                vals = null_val;
+
             if (vals == null || vals.Length != ParameterCount || ParameterCount == 0)
                 throw new ArgumentException ("vals", String.Format ("Statement has {0} parameters", ParameterCount));
 
@@ -253,19 +260,25 @@ namespace Hyena.Data.Sqlite
 
                 if (o == null)
                     code = Native.sqlite3_bind_null (Ptr, i);
-                else if (o is double || o is float)
+                else if (o is double)
                     code = Native.sqlite3_bind_double (Ptr, i, (double)o);
-                else if (o is int || o is uint)
+                else if (o is float)
+                    code = Native.sqlite3_bind_double (Ptr, i, (double)(float)o);
+                else if (o is int)
                     code = Native.sqlite3_bind_int (Ptr, i, (int)o);
-                else if (o is long || o is ulong)
+                else if (o is uint)
+                    code = Native.sqlite3_bind_int (Ptr, i, (int)(uint)o);
+                else if (o is long)
                     code = Native.sqlite3_bind_int64 (Ptr, i, (long)o);
+                else if (o is ulong)
+                    code = Native.sqlite3_bind_int64 (Ptr, i, (long)(ulong)o);
                 else if (o is byte[]) {
                     byte [] bytes = o as byte[];
                     code = Native.sqlite3_bind_blob (Ptr, i, bytes, bytes.Length, (IntPtr)(-1));
                 } else {
                     // C# strings are UTF-16, so 2 bytes per char
                     // -1 for the last arg is the TRANSIENT destructor type so that sqlite will make its own copy of the string
-                    string str = o.ToString ();
+                    string str = (o as string) ?? o.ToString ();
                     code = Native.sqlite3_bind_text16 (Ptr, i, str, str.Length * 2, (IntPtr)(-1));
                 }
 
@@ -283,13 +296,13 @@ namespace Hyena.Data.Sqlite
 
         private void Reset ()
         {
+            CheckDisposed ();
             CheckError (Native.sqlite3_reset (ptr));
             Reading = false;
         }
 
         public IEnumerator<IDataReader> GetEnumerator ()
         {
-            CheckDisposed ();
             Reset ();
             while (reader.Read ()) {
                 yield return reader;
@@ -303,22 +316,20 @@ namespace Hyena.Data.Sqlite
 
         public Statement Execute ()
         {
-            CheckDisposed ();
             Reset ();
             reader.Read ();
             return this;
         }
 
-        public object QueryScalar ()
+        public T Query<T> ()
         {
-            CheckDisposed ();
             Reset ();
-            return reader.Read () ? reader[0] : null;
+            return reader.Read () ? reader.Get<T> (0) : (T) SqliteUtils.FromDbFormat <T> (null);
         }
 
         public QueryReader Query ()
         {
-            CheckDisposed ();
+            Reset ();
             return reader;
         }
     }
@@ -377,7 +388,13 @@ namespace Hyena.Data.Sqlite
                     case SQLITE3_TEXT:
                         return Native.sqlite3_column_text16 (Ptr, i).PtrToString ();
                     case SQLITE_BLOB:
-                        return Native.sqlite3_column_blob (Ptr, i);
+                        int num_bytes = Native.sqlite3_column_bytes (Ptr, i);
+                        if (num_bytes == 0)
+                            return null;
+
+                        byte [] bytes = new byte[num_bytes];
+                        Marshal.Copy (Native.sqlite3_column_blob (Ptr, i), bytes, 0, num_bytes);
+                        return bytes;
                     case SQLITE_NULL:
                         return null;
                     default:
@@ -387,21 +404,70 @@ namespace Hyena.Data.Sqlite
         }
 
         public object this[string columnName] {
-            get {
-                Statement.CheckReading ();
-                if (columns == null) {
-                    columns = new Dictionary<string, int> ();
-                    for (int i = 0; i < FieldCount; i++) {
-                        columns[Native.sqlite3_column_name16 (Ptr, i).PtrToString ()] = i;
-                    }
-                }
+            get { return this[GetColumnIndex (columnName)]; }
+        }
 
-                int col = 0;
-                if (!columns.TryGetValue (columnName, out col))
-                    throw new ArgumentException ("columnName");
+        public T Get<T> (int i)
+        {
+            var type = typeof (T);
+            var o = GetAs (this[i], type);
+
+            if (o is T)
+                return (T) o;
+
+            return (T) SqliteUtils.FromDbFormat (type, o);
+        }
+
+        private object GetAs (object o, Type type)
+        {
+            if (o == null)
+                return null;
+            else if (type == typeof(uint))
+                return (uint)(long)o;
+            else if (type == typeof(ulong))
+                return (ulong)(long)o;
+            else if (type == typeof(float))
+                return (float)(double)o;
+            return o;
+        }
+
+        static Type long_type = typeof(long);
+        static Type double_type = typeof(double);
+        static Type [] long_types = { typeof(int), typeof(uint), typeof(ulong) };
+        static Type [] double_types = { typeof(float) };
+        static Type [] self_types = { typeof(string), typeof(byte[]), long_type, double_type};
+
+        static Type DbTypeFor (Type type)
+        {
+            if (long_types.Contains (type))
+                return typeof(int);
+            else if (double_types.Contains (type))
+                return double_type;
+            else if (self_types.Contains (type))
+                return type;
+            else
+                return null;
+        }
+
+        public T Get<T> (string columnName)
+        {
+            return Get<T> (GetColumnIndex (columnName));
+        }
 
-                return this[col];
+        private int GetColumnIndex (string columnName)
+        {
+            Statement.CheckReading ();
+            if (columns == null) {
+                columns = new Dictionary<string, int> ();
+                for (int i = 0; i < FieldCount; i++) {
+                    columns[Native.sqlite3_column_name16 (Ptr, i).PtrToString ()] = i;
+                }
             }
+
+            int col = 0;
+            if (!columns.TryGetValue (columnName, out col))
+                throw new ArgumentException ("columnName");
+            return col;
         }
 
         const int SQLITE_INTEGER = 1;
@@ -466,7 +532,7 @@ namespace Hyena.Data.Sqlite
         internal static extern int sqlite3_column_type(IntPtr stmt, int iCol);
 
         [DllImport(SQLITE_DLL)]
-        internal static extern byte [] sqlite3_column_blob(IntPtr stmt, int iCol);
+        internal static extern IntPtr sqlite3_column_blob(IntPtr stmt, int iCol);
 
         [DllImport(SQLITE_DLL)]
         internal static extern int sqlite3_column_bytes(IntPtr stmt, int iCol);
diff --git a/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Tests/SqliteTests.cs b/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Tests/SqliteTests.cs
index 604cba8..8246344 100644
--- a/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Tests/SqliteTests.cs
+++ b/Hyena.Data.Sqlite/Hyena.Data.Sqlite/Tests/SqliteTests.cs
@@ -40,17 +40,20 @@ namespace Hyena.Data.Sqlite
     public class SqliteTests
     {
         Connection con;
+        Statement select_literal;
         string dbfile = "hyena-data-sqlite-test.db";
 
         [SetUp]
         public void Setup ()
         {
             con = new Connection (dbfile);
+            select_literal = con.CreateStatement ("SELECT ?");
         }
 
         [TearDown]
         public void TearDown ()
         {
+            select_literal.Dispose ();
             Assert.AreEqual (0, con.Statements.Count);
             con.Dispose ();
             System.IO.File.Delete (dbfile);
@@ -60,13 +63,13 @@ namespace Hyena.Data.Sqlite
         public void Test ()
         {
             using (var stmt = con.CreateStatement ("SELECT 'foobar' as version")) {
-                Assert.AreEqual ("foobar", stmt.QueryScalar ());
+                Assert.AreEqual ("foobar", stmt.Query<string> ());
                 Assert.AreEqual ("foobar", stmt.First ()[0]);
                 Assert.AreEqual ("foobar", stmt.First ()["version"]);
             }
 
             using (var stmt = con.CreateStatement ("SELECT 2 + 5 as res")) {
-                Assert.AreEqual (7, stmt.QueryScalar ());
+                Assert.AreEqual (7, stmt.Query<int> ());
                 Assert.AreEqual (7, stmt.First ()[0]);
                 Assert.AreEqual (7, stmt.First ()["res"]);
 
@@ -107,7 +110,7 @@ namespace Hyena.Data.Sqlite
 
             using (var stmt = con.CreateStatement ("SELECT ? as a, ? as b, ?")) {
                 stmt.Bind (1, "two", 3.3);
-                Assert.AreEqual (1, stmt.QueryScalar ());
+                Assert.AreEqual (1, stmt.Query<int> ());
                 Assert.AreEqual ("two", stmt.First ()["b"]);
                 Assert.AreEqual (3.3, stmt.First ()[2]);
             }
@@ -119,7 +122,7 @@ namespace Hyena.Data.Sqlite
             CreateUsers (con);
 
             using (var stmt = con.CreateStatement ("SELECT COUNT(*) FROM Users")) {
-                Assert.AreEqual (2, stmt.QueryScalar ());
+                Assert.AreEqual (2, stmt.Query<int> ());
             }
 
             using (var stmt = con.CreateStatement ("SELECT ID, Name FROM Users ORDER BY NAME")) {
@@ -171,7 +174,7 @@ namespace Hyena.Data.Sqlite
             Assert.AreEqual ("Gabriel", q2["Name"]);
 
             con.Execute ("INSERT INTO Users (Name) VALUES ('Zeus')");
-            Assert.AreEqual (3, con.QueryScalar ("SELECT COUNT(*) FROM Users"));
+            Assert.AreEqual (3, con.Query<int> ("SELECT COUNT(*) FROM Users"));
 
             Assert.IsTrue (q2.Read ());
             Assert.AreEqual ("Aaron", q2[1]);
@@ -187,11 +190,11 @@ namespace Hyena.Data.Sqlite
             // Insert a value, see that q2 can see it, then delete it and try to
             // get the now-deleted value from q2
             con.Execute ("INSERT INTO Users (Name) VALUES ('Apollo')");
-            Assert.AreEqual (4, con.QueryScalar ("SELECT COUNT(*) FROM Users"));
+            Assert.AreEqual (4, con.Query<int> ("SELECT COUNT(*) FROM Users"));
             Assert.IsTrue (q2.Read ());
 
             con.Execute ("DELETE FROM Users WHERE Name='Apollo'");
-            Assert.AreEqual (3, con.QueryScalar ("SELECT COUNT(*) FROM Users"));
+            Assert.AreEqual (3, con.Query<int> ("SELECT COUNT(*) FROM Users"));
             Assert.AreEqual ("Apollo", q2[1]);
             Assert.IsFalse (q2.Read ());
 
@@ -240,18 +243,18 @@ namespace Hyena.Data.Sqlite
         [Test]
         public void QueryScalar ()
         {
-            Assert.AreEqual (7, con.QueryScalar ("SELECT 7"));
+            Assert.AreEqual (7, con.Query<int> ("SELECT 7"));
         }
 
         [Test]
         public void Execute ()
         {
             try {
-                con.QueryScalar ("SELECT COUNT(*) FROM Users");
+                con.Query<int> ("SELECT COUNT(*) FROM Users");
                 Assert.Fail ("Should have thrown an exception");
             } catch {}
             con.Execute ("CREATE TABLE Users (ID INTEGER PRIMARY KEY, Name TEXT)");
-            Assert.AreEqual (0, con.QueryScalar ("SELECT COUNT(*) FROM Users"));
+            Assert.AreEqual (0, con.Query<int> ("SELECT COUNT(*) FROM Users"));
         }
 
         [Test]
@@ -260,17 +263,17 @@ namespace Hyena.Data.Sqlite
             con.AddFunction<Md5Function> ();
 
             using (var stmt = con.CreateStatement ("SELECT HYENA_MD5(?, ?)")) {
-                Assert.AreEqual ("ae2b1fca515949e5d54fb22b8ed95575", stmt.Bind (1, "testing").QueryScalar ());
-                Assert.AreEqual (null, stmt.Bind (1, null).QueryScalar ());
+                Assert.AreEqual ("ae2b1fca515949e5d54fb22b8ed95575", stmt.Bind (1, "testing").Query<string> ());
+                Assert.AreEqual (null, stmt.Bind (1, null).Query<string> ());
             }
 
             using (var stmt = con.CreateStatement ("SELECT HYENA_MD5(?, ?, ?)")) {
-                Assert.AreEqual ("ae2b1fca515949e5d54fb22b8ed95575", stmt.Bind (2, "test", "ing").QueryScalar ());
-                Assert.AreEqual (null, stmt.Bind (2, null, null).QueryScalar ());
+                Assert.AreEqual ("ae2b1fca515949e5d54fb22b8ed95575", stmt.Bind (2, "test", "ing").Query<string> ());
+                Assert.AreEqual (null, stmt.Bind (2, null, null).Query<string> ());
             }
 
             using (var stmt = con.CreateStatement ("SELECT HYENA_MD5(?, ?, ?, ?)")) {
-                Assert.AreEqual (null, stmt.Bind (3, null, "", null).QueryScalar ());
+                Assert.AreEqual (null, stmt.Bind (3, null, "", null).Query<string> ());
 
                 try {
                     con.RemoveFunction<Md5Function> ();
@@ -282,11 +285,110 @@ namespace Hyena.Data.Sqlite
 
             try {
                 using (var stmt = con.CreateStatement ("SELECT HYENA_MD5(?, ?, ?, ?)")) {
-                    Assert.AreEqual ("ae2b1fca515949e5d54fb22b8ed95575", stmt.QueryScalar ());
+                    Assert.AreEqual ("ae2b1fca515949e5d54fb22b8ed95575", stmt.Query<string> ());
                     Assert.Fail ("Function HYENA_MD5 should no longer exist");
                 }
             } catch {}
         }
+
+        [Test]
+        public void DataTypes ()
+        {
+            AssertGetNull<int> (0);
+            AssertRoundTrip<int> (0);
+            AssertRoundTrip<int> (1);
+            AssertRoundTrip<int> (-1);
+            AssertRoundTrip<int> (42);
+            AssertRoundTrip<int> (int.MaxValue);
+            AssertRoundTrip<int> (int.MinValue);
+
+            AssertGetNull<uint> (0);
+            AssertRoundTrip<uint> (0);
+            AssertRoundTrip<uint> (1);
+            AssertRoundTrip<uint> (42);
+            AssertRoundTrip<uint> (uint.MaxValue);
+            AssertRoundTrip<uint> (uint.MinValue);
+
+            AssertGetNull<long> (0);
+            AssertRoundTrip<long> (0);
+            AssertRoundTrip<long> (1);
+            AssertRoundTrip<long> (-1);
+            AssertRoundTrip<long> (42);
+            AssertRoundTrip<long> (long.MaxValue);
+            AssertRoundTrip<long> (long.MinValue);
+
+            AssertGetNull<ulong> (0);
+            AssertRoundTrip<ulong> (0);
+            AssertRoundTrip<ulong> (1);
+            AssertRoundTrip<ulong> (42);
+            AssertRoundTrip<ulong> (ulong.MaxValue);
+            AssertRoundTrip<ulong> (ulong.MinValue);
+
+            AssertGetNull<float> (0f);
+            AssertRoundTrip<float> (0f);
+            AssertRoundTrip<float> (1f);
+            AssertRoundTrip<float> (-1f);
+            AssertRoundTrip<float> (42.222f);
+            AssertRoundTrip<float> (float.MaxValue);
+            AssertRoundTrip<float> (float.MinValue);
+
+            AssertGetNull<double> (0);
+            AssertRoundTrip<double> (0);
+            AssertRoundTrip<double> (1);
+            AssertRoundTrip<double> (-1);
+            AssertRoundTrip<double> (42.222);
+            AssertRoundTrip<double> (double.MaxValue);
+            AssertRoundTrip<double> (double.MinValue);
+
+            AssertGetNull<string> (null);
+            AssertRoundTrip<string> ("a");
+            AssertRoundTrip<string> ("üb�r;&#co¯ol!~`\n\r\t");
+
+            AssertGetNull<byte[]> (null);
+            AssertRoundTrip<byte[]> (new byte [] { 0 });
+            AssertRoundTrip<byte[]> (new byte [] { 0, 1});
+            AssertRoundTrip<byte[]> (System.Text.Encoding.UTF8.GetBytes ("üb�r;&#co¯ol!~`\n\r\t"));
+
+            var ignore_ms = new Func<DateTime, DateTime, bool> ((a, b) => (a - b).TotalSeconds < 1);
+            AssertGetNull<DateTime> (DateTime.MinValue);
+            AssertRoundTrip<DateTime> (new DateTime (1970, 1, 1).ToLocalTime ());
+            AssertRoundTrip<DateTime> (DateTime.Now, ignore_ms);
+            AssertRoundTrip<DateTime> (DateTime.MinValue);
+            // FIXME
+            //AssertRoundTrip<DateTime> (DateTime.MaxValue);
+            Assert.AreEqual (new DateTime (1970, 1, 1).ToLocalTime (), con.Query<DateTime> ("SELECT 0"));
+
+            AssertGetNull<TimeSpan> (TimeSpan.MinValue);
+            AssertRoundTrip<TimeSpan> (TimeSpan.MinValue);
+            AssertRoundTrip<TimeSpan> (TimeSpan.FromSeconds (0));
+            AssertRoundTrip<TimeSpan> (TimeSpan.FromSeconds (0.001));
+            AssertRoundTrip<TimeSpan> (TimeSpan.FromSeconds (0.002));
+            AssertRoundTrip<TimeSpan> (TimeSpan.FromSeconds (0.503));
+            AssertRoundTrip<TimeSpan> (TimeSpan.FromSeconds (1.01));
+            AssertRoundTrip<TimeSpan> (TimeSpan.FromHours (999.00193));
+            // FIXME
+            //AssertRoundTrip<TimeSpan> (TimeSpan.MaxValue);
+        }
+
+        private void AssertRoundTrip<T> (T val)
+        {
+            AssertRoundTrip (val, null);
+        }
+
+        private void AssertRoundTrip<T> (T val, Func<T, T, bool> func)
+        {
+            var o = select_literal.Bind (val).First ().Get<T> (0);
+            if (func == null) {
+                Assert.AreEqual (val, o);
+            } else {
+                Assert.IsTrue (func (val, o));
+            }
+        }
+
+        private void AssertGetNull<T> (T val)
+        {
+            Assert.AreEqual (val, select_literal.Bind (null).First ().Get<T> (0));
+        }
     }
 }
 



[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]