Modified:
trunk/TestSuite/DataReaderTests.cs
trunk/mysqlclient/datareader.cs
Log:
Added checks on each type safe datareader method to keep them from working on null
columns. You should use .IsDbNull to check for null.
Modified: trunk/TestSuite/DataReaderTests.cs
===================================================================
--- trunk/TestSuite/DataReaderTests.cs 2006-09-18 15:27:39 UTC (rev 342)
+++ trunk/TestSuite/DataReaderTests.cs 2006-09-18 15:32:51 UTC (rev 343)
@@ -368,43 +368,57 @@
}
}
- [Test]
- public void ConsecutiveNulls()
- {
- execSQL("INSERT INTO Test (id, name) VALUES (1, 'Test')");
- execSQL("INSERT INTO Test (id, name) VALUES (2, NULL)");
- execSQL("INSERT INTO Test (id, name) VALUES (3, 'Test2')");
+ [Test]
+ public void ConsecutiveNulls()
+ {
+ execSQL("INSERT INTO Test (id, name, dt) VALUES (1, 'Test', NULL)");
+ execSQL("INSERT INTO Test (id, name, dt) VALUES (2, NULL, now())");
+ execSQL("INSERT INTO Test (id, name, dt) VALUES (3, 'Test2', NULL)");
- MySqlCommand cmd = new MySqlCommand("SELECT * FROM Test", conn);
- MySqlDataReader reader = null;
- try
- {
- reader = cmd.ExecuteReader();
- reader.Read();
- Assert.AreEqual(1, reader.GetValue(0));
- Assert.AreEqual("Test", reader.GetValue(1));
- Assert.AreEqual("Test", reader.GetString(1));
- reader.Read();
- Assert.AreEqual(2, reader.GetValue(0));
- Assert.AreEqual(DBNull.Value, reader.GetValue(1));
- reader.Read();
- Assert.AreEqual(3, reader.GetValue(0));
- Assert.AreEqual("Test2", reader.GetValue(1));
- Assert.AreEqual("Test2", reader.GetString(1));
- Assert.IsFalse(reader.Read());
- Assert.IsFalse(reader.NextResult());
- }
- catch (Exception ex)
- {
- Assert.Fail(ex.Message);
- }
- finally
- {
- if (reader != null) reader.Close();
- }
+ MySqlCommand cmd = new MySqlCommand("SELECT id, name, dt FROM Test", conn);
+ MySqlDataReader reader = null;
+ try
+ {
+ reader = cmd.ExecuteReader();
+ reader.Read();
+ Assert.AreEqual(1, reader.GetValue(0));
+ Assert.AreEqual("Test", reader.GetValue(1));
+ Assert.AreEqual("Test", reader.GetString(1));
+ Assert.AreEqual(DBNull.Value, reader.GetValue(2));
+ reader.Read();
+ Assert.AreEqual(2, reader.GetValue(0));
+ Assert.AreEqual(DBNull.Value, reader.GetValue(1));
+ try
+ {
+ reader.GetString(1);
+ Assert.Fail("Should not get here");
+ }
+ catch (Exception) { }
+ Assert.IsFalse(reader.IsDBNull(2));
+ reader.Read();
+ Assert.AreEqual(3, reader.GetValue(0));
+ Assert.AreEqual("Test2", reader.GetValue(1));
+ Assert.AreEqual("Test2", reader.GetString(1));
+ Assert.AreEqual(DBNull.Value, reader.GetValue(2));
+ try
+ {
+ reader.GetMySqlDateTime(2);
+ Assert.Fail("Should not get here");
+ }
+ catch (Exception) { }
+ Assert.IsFalse(reader.Read());
+ Assert.IsFalse(reader.NextResult());
+ }
+ catch (Exception ex)
+ {
+ Assert.Fail(ex.Message);
+ }
+ finally
+ {
+ if (reader != null) reader.Close();
+ }
+ }
- }
-
[Test]
public void HungDataReader()
{
Modified: trunk/mysqlclient/datareader.cs
===================================================================
--- trunk/mysqlclient/datareader.cs 2006-09-18 15:27:39 UTC (rev 342)
+++ trunk/mysqlclient/datareader.cs 2006-09-18 15:32:51 UTC (rev 343)
@@ -216,7 +216,7 @@
/// <returns></returns>
public override byte GetByte(int i)
{
- IMySqlValue v = GetFieldValue(i);
+ IMySqlValue v = GetFieldValue(i, false);
if (v is MySqlUByte)
return ((MySqlUByte)v).Value;
else
@@ -238,7 +238,7 @@
if (i >= fields.Length)
throw new IndexOutOfRangeException();
- IMySqlValue val = GetFieldValue(i);
+ IMySqlValue val = GetFieldValue(i, false);
if (! (val is MySqlBinary))
throw new MySqlException("GetBytes can only be called on binary columns");
@@ -337,7 +337,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetMySqlDateTime/*'/>
public MySqlDateTime GetMySqlDateTime(int index)
{
- return (MySqlDateTime)GetFieldValue(index);
+ return (MySqlDateTime)GetFieldValue(index, true);
}
public DateTime GetDateTime(string name)
@@ -348,7 +348,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetDateTime/*'/>
public override DateTime GetDateTime(int index)
{
- IMySqlValue val = GetFieldValue(index);
+ IMySqlValue val = GetFieldValue(index, true);
MySqlDateTime dt;
// we need to do this because functions like date_add return string
@@ -374,8 +374,8 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetDecimal/*'/>
public override Decimal GetDecimal(int index)
{
- IMySqlValue v = GetFieldValue(index);
- if (v is MySqlDecimal)
+ IMySqlValue v = GetFieldValue(index, true);
+ if (v is MySqlDecimal)
return ((MySqlDecimal)v).Value;
return Convert.ToDecimal(v.Value);
}
@@ -388,8 +388,8 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetDouble/*'/>
public override double GetDouble(int index)
{
- IMySqlValue v = GetFieldValue(index);
- if (v is MySqlDouble)
+ IMySqlValue v = GetFieldValue(index, true);
+ if (v is MySqlDouble)
return ((MySqlDouble)v).Value;
return Convert.ToDouble(v.Value);
}
@@ -417,8 +417,8 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetFloat/*'/>
public override float GetFloat(int index)
{
- IMySqlValue v = GetFieldValue(index);
- if (v is MySqlSingle)
+ IMySqlValue v = GetFieldValue(index, true);
+ if (v is MySqlSingle)
return ((MySqlSingle)v).Value;
return Convert.ToSingle(v.Value);
}
@@ -431,7 +431,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetGuid/*'/>
public override Guid GetGuid(int index)
{
- return new Guid( GetString(index) );
+ return new Guid(GetString(index));
}
public Int16 GetInt16(string name)
@@ -442,8 +442,8 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetInt16/*'/>
public override Int16 GetInt16(int index)
{
- IMySqlValue v = GetFieldValue(index);
- if (v is MySqlInt16)
+ IMySqlValue v = GetFieldValue(index, true);
+ if (v is MySqlInt16)
return ((MySqlInt16)v).Value;
connection.UsageAdvisor.Converting(command.CommandText,
@@ -459,8 +459,8 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetInt32/*'/>
public override Int32 GetInt32(int index)
{
- IMySqlValue v = GetFieldValue(index);
- if (v is MySqlInt32)
+ IMySqlValue v = GetFieldValue(index, true);
+ if (v is MySqlInt32)
return ((MySqlInt32)v).Value;
connection.UsageAdvisor.Converting(command.CommandText,
@@ -476,7 +476,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetInt64/*'/>
public override Int64 GetInt64(int index)
{
- IMySqlValue v = GetFieldValue(index);
+ IMySqlValue v = GetFieldValue(index, true);
if (v is MySqlInt64)
return ((MySqlInt64)v).Value;
@@ -596,9 +596,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetString/*'/>
public override String GetString(int index)
{
- IMySqlValue val = GetFieldValue(index);
- if (val.IsNull)
- throw new SqlNullValueException();
+ IMySqlValue val = GetFieldValue(index, true);
if (val is MySqlBinary)
{
@@ -617,7 +615,9 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetTimeSpan/*'/>
public TimeSpan GetTimeSpan(int index)
{
- MySqlTimeSpan ts = (MySqlTimeSpan)GetFieldValue(index);
+ IMySqlValue val = GetFieldValue(index, true);
+
+ MySqlTimeSpan ts = (MySqlTimeSpan)val;
return ts.Value;
}
@@ -631,8 +631,9 @@
if (! isOpen) throw new Exception("No current query in data reader");
if (i >= fields.Length) throw new IndexOutOfRangeException();
- IMySqlValue val = GetFieldValue(i);
- if (val.IsNull) return DBNull.Value;
+ IMySqlValue val = GetFieldValue(i, false);
+ if (val.IsNull)
+ return DBNull.Value;
// if the column is a date/time, then we return a MySqlDateTime
// so .ToString() will print '0000-00-00' correctly
@@ -673,7 +674,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetUInt16/*'/>
public UInt16 GetUInt16(int index)
{
- IMySqlValue v = GetFieldValue(index);
+ IMySqlValue v = GetFieldValue(index, true);
if (v is MySqlUInt16)
return ((MySqlUInt16)v).Value;
@@ -690,7 +691,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetUInt32/*'/>
public UInt32 GetUInt32(int index)
{
- IMySqlValue v = GetFieldValue(index);
+ IMySqlValue v = GetFieldValue(index, true);
if (v is MySqlUInt32)
return ((MySqlUInt32)v).Value;
@@ -707,7 +708,7 @@
/// <include file='docs/MySqlDataReader.xml' path='docs/GetUInt64/*'/>
public UInt64 GetUInt64(int index)
{
- IMySqlValue v = GetFieldValue(index);
+ IMySqlValue v = GetFieldValue(index, true);
if (v is MySqlUInt64)
return ((MySqlUInt64)v).Value;
@@ -862,7 +863,7 @@
}
- private IMySqlValue GetFieldValue(int index)
+ private IMySqlValue GetFieldValue(int index, bool checkNull)
{
if (index < 0 || index >= fields.Length)
throw new ArgumentException( "You have specified an invalid column ordinal." );
@@ -884,7 +885,11 @@
seqIndex = index;
}
- return values[index];
+ IMySqlValue v = values[index];
+ if (checkNull && v.IsNull)
+ throw new SqlNullValueException();
+
+ return v;
}
private void ClearCurrentResultset()
| Thread |
|---|
| • Connector/NET commit: r343 - in trunk: TestSuite mysqlclient | rburnett | 18 Sep |