List:Commits« Previous MessageNext Message »
From:rburnett Date:September 18 2006 5:32pm
Subject:Connector/NET commit: r343 - in trunk: TestSuite mysqlclient
View as plain text  
Modified:
   trunk/TestSuite/DataReaderTests.cs
   trunk/mysqlclient/datareader.cs
Log:
Added checks on each type safe datareader method to keep them from working on null
columns.  You should use .IsDbNull to check for null.

Modified: trunk/TestSuite/DataReaderTests.cs
===================================================================
--- trunk/TestSuite/DataReaderTests.cs	2006-09-18 15:27:39 UTC (rev 342)
+++ trunk/TestSuite/DataReaderTests.cs	2006-09-18 15:32:51 UTC (rev 343)
@@ -368,43 +368,57 @@
 			}
 		}
 
-		[Test]
-		public void ConsecutiveNulls() 
-		{
-			execSQL("INSERT INTO Test (id, name) VALUES (1, 'Test')");
-			execSQL("INSERT INTO Test (id, name) VALUES (2, NULL)");
-			execSQL("INSERT INTO Test (id, name) VALUES (3, 'Test2')");
+        [Test]
+        public void ConsecutiveNulls()
+        {
+            execSQL("INSERT INTO Test (id, name, dt) VALUES (1, 'Test', NULL)");
+            execSQL("INSERT INTO Test (id, name, dt) VALUES (2, NULL, now())");
+            execSQL("INSERT INTO Test (id, name, dt) VALUES (3, 'Test2', NULL)");
 
-			MySqlCommand cmd = new MySqlCommand("SELECT * FROM Test", conn);
-			MySqlDataReader reader = null;
-			try 
-			{
-				reader = cmd.ExecuteReader();
-				reader.Read();
-				Assert.AreEqual(1, reader.GetValue(0));
-				Assert.AreEqual("Test", reader.GetValue(1));
-				Assert.AreEqual("Test", reader.GetString(1));
-				reader.Read();
-				Assert.AreEqual(2, reader.GetValue(0));
-				Assert.AreEqual(DBNull.Value, reader.GetValue(1));
-				reader.Read();
-				Assert.AreEqual(3, reader.GetValue(0));
-				Assert.AreEqual("Test2", reader.GetValue(1));
-				Assert.AreEqual("Test2", reader.GetString(1));
-				Assert.IsFalse(reader.Read());
-				Assert.IsFalse(reader.NextResult());
-			}
-			catch (Exception ex) 
-			{
-				Assert.Fail(ex.Message);
-			}
-			finally 
-			{
-				if (reader != null) reader.Close();
-			}
+            MySqlCommand cmd = new MySqlCommand("SELECT id, name, dt FROM Test", conn);
+            MySqlDataReader reader = null;
+            try
+            {
+                reader = cmd.ExecuteReader();
+                reader.Read();
+                Assert.AreEqual(1, reader.GetValue(0));
+                Assert.AreEqual("Test", reader.GetValue(1));
+                Assert.AreEqual("Test", reader.GetString(1));
+                Assert.AreEqual(DBNull.Value, reader.GetValue(2));
+                reader.Read();
+                Assert.AreEqual(2, reader.GetValue(0));
+                Assert.AreEqual(DBNull.Value, reader.GetValue(1));
+                try
+                {
+                    reader.GetString(1);
+                    Assert.Fail("Should not get here");
+                }
+                catch (Exception) { }
+                Assert.IsFalse(reader.IsDBNull(2));
+                reader.Read();
+                Assert.AreEqual(3, reader.GetValue(0));
+                Assert.AreEqual("Test2", reader.GetValue(1));
+                Assert.AreEqual("Test2", reader.GetString(1));
+                Assert.AreEqual(DBNull.Value, reader.GetValue(2));
+                try
+                {
+                    reader.GetMySqlDateTime(2);
+                    Assert.Fail("Should not get here");
+                }
+                catch (Exception) { }
+                Assert.IsFalse(reader.Read());
+                Assert.IsFalse(reader.NextResult());
+            }
+            catch (Exception ex)
+            {
+                Assert.Fail(ex.Message);
+            }
+            finally
+            {
+                if (reader != null) reader.Close();
+            }
+        }
 
-		}
-
 		[Test]
 		public void HungDataReader() 
 		{

Modified: trunk/mysqlclient/datareader.cs
===================================================================
--- trunk/mysqlclient/datareader.cs	2006-09-18 15:27:39 UTC (rev 342)
+++ trunk/mysqlclient/datareader.cs	2006-09-18 15:32:51 UTC (rev 343)
@@ -216,7 +216,7 @@
 		/// <returns></returns>
 		public override byte GetByte(int i)
 		{
-			IMySqlValue v = GetFieldValue(i);
+			IMySqlValue v = GetFieldValue(i, false);
 			if (v is MySqlUByte)
 				return ((MySqlUByte)v).Value;
 			else
@@ -238,7 +238,7 @@
 			if (i >= fields.Length) 
 				throw new IndexOutOfRangeException();
 
-			IMySqlValue val = GetFieldValue(i);
+			IMySqlValue val = GetFieldValue(i, false);
 
 			if (! (val is MySqlBinary))
 				throw new MySqlException("GetBytes can only be called on binary columns");
@@ -337,7 +337,7 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetMySqlDateTime/*'/>
 		public MySqlDateTime GetMySqlDateTime(int index)
 		{
-			return (MySqlDateTime)GetFieldValue(index);
+			return (MySqlDateTime)GetFieldValue(index, true);
 		}
 
         public DateTime GetDateTime(string name)
@@ -348,7 +348,7 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetDateTime/*'/>
 		public override DateTime GetDateTime(int index)
 		{
-			IMySqlValue val = GetFieldValue(index);
+			IMySqlValue val = GetFieldValue(index, true);
             MySqlDateTime dt;
 
             // we need to do this because functions like date_add return string
@@ -374,8 +374,8 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetDecimal/*'/>
 		public override Decimal GetDecimal(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
-			if (v is MySqlDecimal)
+            IMySqlValue v = GetFieldValue(index, true);
+            if (v is MySqlDecimal)
 				return ((MySqlDecimal)v).Value;
 			return Convert.ToDecimal(v.Value);
 		}
@@ -388,8 +388,8 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetDouble/*'/>
 		public override double GetDouble(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
-			if (v is MySqlDouble)
+            IMySqlValue v = GetFieldValue(index, true);
+            if (v is MySqlDouble)
 				return ((MySqlDouble)v).Value;
 			return Convert.ToDouble(v.Value);
 		}
@@ -417,8 +417,8 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetFloat/*'/>
 		public override float GetFloat(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
-			if (v is MySqlSingle)
+            IMySqlValue v = GetFieldValue(index, true);
+            if (v is MySqlSingle)
 				return ((MySqlSingle)v).Value;
 			return Convert.ToSingle(v.Value);
 		}
@@ -431,7 +431,7 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetGuid/*'/>
 		public override Guid GetGuid(int index)
 		{
-			return new Guid( GetString(index) );
+			return new Guid(GetString(index));
 		}
 
         public Int16 GetInt16(string name)
@@ -442,8 +442,8 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetInt16/*'/>
 		public override Int16 GetInt16(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
-			if (v is MySqlInt16)
+            IMySqlValue v = GetFieldValue(index, true);
+            if (v is MySqlInt16)
 				return ((MySqlInt16)v).Value;
             
             connection.UsageAdvisor.Converting(command.CommandText,
@@ -459,8 +459,8 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetInt32/*'/>
 		public override Int32 GetInt32(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
-			if (v is MySqlInt32)
+            IMySqlValue v = GetFieldValue(index, true);
+            if (v is MySqlInt32)
 				return ((MySqlInt32)v).Value;
             
             connection.UsageAdvisor.Converting(command.CommandText,
@@ -476,7 +476,7 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetInt64/*'/>
 		public override Int64 GetInt64(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
+            IMySqlValue v = GetFieldValue(index, true);
 			if (v is MySqlInt64)
 				return ((MySqlInt64)v).Value;
 
@@ -596,9 +596,7 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetString/*'/>
 		public override String GetString(int index)
 		{
-			IMySqlValue val = GetFieldValue(index);
-            if (val.IsNull)
-                throw new SqlNullValueException();
+			IMySqlValue val = GetFieldValue(index, true);
 
 			if (val is MySqlBinary)
 			{
@@ -617,7 +615,9 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetTimeSpan/*'/>
 		public TimeSpan GetTimeSpan(int index)
 		{
-			MySqlTimeSpan ts = (MySqlTimeSpan)GetFieldValue(index);
+            IMySqlValue val = GetFieldValue(index, true);
+
+            MySqlTimeSpan ts = (MySqlTimeSpan)val;
 			return ts.Value;
 		}
 
@@ -631,8 +631,9 @@
 			if (! isOpen) throw new Exception("No current query in data reader");
 			if (i >= fields.Length) throw new IndexOutOfRangeException();
 
-			IMySqlValue val = GetFieldValue(i);
-			if (val.IsNull) return DBNull.Value;
+			IMySqlValue val = GetFieldValue(i, false);
+			if (val.IsNull) 
+                return DBNull.Value;
 
 			// if the column is a date/time, then we return a MySqlDateTime
 			// so .ToString() will print '0000-00-00' correctly
@@ -673,7 +674,7 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetUInt16/*'/>
 		public UInt16 GetUInt16(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
+			IMySqlValue v = GetFieldValue(index, true);
 			if (v is MySqlUInt16)
 				return ((MySqlUInt16)v).Value;
 
@@ -690,7 +691,7 @@
         /// <include file='docs/MySqlDataReader.xml' path='docs/GetUInt32/*'/>
 		public UInt32 GetUInt32(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
+			IMySqlValue v = GetFieldValue(index, true);
 			if (v is MySqlUInt32)
 				return ((MySqlUInt32)v).Value;
 
@@ -707,7 +708,7 @@
 		/// <include file='docs/MySqlDataReader.xml' path='docs/GetUInt64/*'/>
 		public UInt64 GetUInt64(int index)
 		{
-			IMySqlValue v = GetFieldValue(index);
+			IMySqlValue v = GetFieldValue(index, true);
 			if (v is MySqlUInt64)
 				return ((MySqlUInt64)v).Value;
 
@@ -862,7 +863,7 @@
 		}
 
 
-		private IMySqlValue GetFieldValue(int index) 
+		private IMySqlValue GetFieldValue(int index, bool checkNull) 
 		{
 			if (index < 0 || index >= fields.Length) 
 				throw new ArgumentException( "You have specified an invalid column ordinal." );
@@ -884,7 +885,11 @@
                 seqIndex = index;
             }
 
-            return values[index];
+            IMySqlValue v = values[index];
+            if (checkNull && v.IsNull)
+                throw new SqlNullValueException();
+
+            return v;
 		}
 
 		private void ClearCurrentResultset() 

Thread
Connector/NET commit: r343 - in trunk: TestSuite mysqlclientrburnett18 Sep