Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit 85fe109

Browse files
committed
Merge pull request #508 from shift-evgeny/ContainsNullableSupport
Better nullable support in SqlExpression
2 parents 1649afb + 300fbd8 commit 85fe109

File tree

7 files changed

+131
-27
lines changed

7 files changed

+131
-27
lines changed

src/ServiceStack.OrmLite.PostgreSQL.Tests/OrmLiteInsertTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ public void Can_retrieve_LastInsertId_from_inserted_table()
113113
var row2 = new ModelWithIdAndName1() { Name = "B", Id = 5 };
114114

115115
var row1LastInsertId = db.Insert(row1, selectIdentity: true);
116-
Assert.That(db.GetLastSql(), Is.StringMatching("\\) RETURNING \"?id"));
116+
Assert.That(db.GetLastSql(), Is.StringMatching("\\) RETURNING \"?[Ii]d"));
117117

118118
var row2LastInsertId = db.Insert(row2, selectIdentity: true);
119-
Assert.That(db.GetLastSql(), Is.StringMatching("\\) RETURNING \"?id"));
119+
Assert.That(db.GetLastSql(), Is.StringMatching("\\) RETURNING \"?[Ii]d"));
120120

121121
var insertedRow1 = db.SingleById<ModelWithIdAndName1>(row1LastInsertId);
122122
var insertedRow2 = db.SingleById<ModelWithIdAndName1>(row2LastInsertId);

src/ServiceStack.OrmLite.SqlServer.Converters/SqlServerHierarchyIdTypeConverter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ public override DbType DbType
2727
public override void InitDbParam(IDbDataParameter p, Type fieldType)
2828
{
2929
var sqlParam = (SqlParameter)p;
30-
sqlParam.IsNullable = (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>));
30+
sqlParam.IsNullable = fieldType.IsNullableType();
3131
sqlParam.SqlDbType = SqlDbType.Udt;
3232
sqlParam.UdtTypeName = ColumnDefinition;
3333
}
3434

3535
public override object FromDbValue(Type fieldType, object value)
3636
{
37-
if (((SqlHierarchyId)value).IsNull && fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>))
37+
if (((SqlHierarchyId)value).IsNull && fieldType.IsNullableType())
3838
{
3939
return null;
4040
}

src/ServiceStack.OrmLite.SqlServer.Converters/SqlServerTypeConverter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public override void InitDbParam(IDbDataParameter p, Type fieldType)
1515
{
1616
var sqlParam = (SqlParameter)p;
1717
sqlParam.SqlDbType = SqlDbType.Udt;
18-
sqlParam.IsNullable = (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>));
18+
sqlParam.IsNullable = fieldType.IsNullableType();
1919
sqlParam.UdtTypeName = ColumnDefinition;
2020
}
2121
}

src/ServiceStack.OrmLite/Expressions/SqlExpression.cs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,14 @@ protected virtual object VisitBinary(BinaryExpression b)
12951295
}
12961296
}
12971297

1298+
if (left.ToString().Equals("null", StringComparison.OrdinalIgnoreCase))
1299+
{
1300+
// "null is x" will not work, so swap the operands
1301+
var temp = right;
1302+
right = left;
1303+
left = temp;
1304+
}
1305+
12981306
if (operand == "=" && right.ToString().Equals("null", StringComparison.OrdinalIgnoreCase))
12991307
operand = "is";
13001308
else if (operand == "<>" && right.ToString().Equals("null", StringComparison.OrdinalIgnoreCase))
@@ -1335,11 +1343,23 @@ protected virtual void ConvertToPlaceholderAndParameter(ref object right)
13351343

13361344
protected virtual object VisitMemberAccess(MemberExpression m)
13371345
{
1338-
if (m.Expression != null &&
1339-
(m.Expression.NodeType == ExpressionType.Parameter ||
1340-
m.Expression.NodeType == ExpressionType.Convert))
1346+
if (m.Expression != null)
13411347
{
1342-
return GetMemberExpression(m);
1348+
if (m.Member.DeclaringType.IsNullableType())
1349+
{
1350+
if (m.Member.Name == nameof(Nullable<bool>.Value))
1351+
return Visit(m.Expression);
1352+
if (m.Member.Name == nameof(Nullable<bool>.HasValue))
1353+
{
1354+
var doesNotEqualNull = Expression.MakeBinary(ExpressionType.NotEqual, m.Expression, Expression.Constant(null));
1355+
return Visit(doesNotEqualNull); // Nullable<T>.HasValue is equivalent to "!= null"
1356+
}
1357+
1358+
throw new ArgumentException(string.Format("Expression '{0}' accesses unsupported property '{1}' of Nullable<T>", m, m.Member));
1359+
}
1360+
1361+
if (m.Expression.NodeType == ExpressionType.Parameter || m.Expression.NodeType == ExpressionType.Convert)
1362+
return GetMemberExpression(m);
13431363
}
13441364

13451365
return CachedExpressionCompiler.Evaluate(m);
@@ -1666,7 +1686,7 @@ protected virtual object VisitStaticArrayMethodCall(MethodCallExpression m)
16661686
{
16671687
case "Contains":
16681688
List<Object> args = this.VisitExpressionList(m.Arguments);
1669-
object quotedColName = args[1];
1689+
object quotedColName = args.Last();
16701690

16711691
Expression memberExpr = m.Arguments[0];
16721692
if (memberExpr.NodeType == ExpressionType.MemberAccess)

src/ServiceStack.OrmLite/OrmLiteConfigExtensions.cs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@ internal static class OrmLiteConfigExtensions
2222
{
2323
private static Dictionary<Type, ModelDefinition> typeModelDefinitionMap = new Dictionary<Type, ModelDefinition>();
2424

25-
private static bool IsNullableType(Type theType)
26-
{
27-
return (theType.IsGenericType
28-
&& theType.GetGenericTypeDefinition() == typeof(Nullable<>));
29-
}
30-
3125
internal static bool CheckForIdField(IEnumerable<PropertyInfo> objProperties)
3226
{
3327
// Not using Linq.Where() and manually iterating through objProperties just to avoid dependencies on System.Xml??
@@ -104,7 +98,7 @@ internal static ModelDefinition GetModelDefinition(this Type modelType)
10498
var isRowVersion = propertyInfo.Name == ModelDefinition.RowVersionName
10599
&& propertyInfo.PropertyType == typeof(ulong);
106100

107-
var isNullableType = IsNullableType(propertyInfo.PropertyType);
101+
var isNullableType = propertyInfo.PropertyType.IsNullableType();
108102

109103
var isNullable = (!propertyInfo.PropertyType.IsValueType
110104
&& !propertyInfo.HasAttributeNamed(typeof(RequiredAttribute).Name))

tests/ServiceStack.OrmLite.Tests/Expression/MethodExpressionTests.cs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,22 @@ public class MethodExpressionTests : ExpressionsTestBase
1212
[Test]
1313
public void Can_select_ints_using_array_contains()
1414
{
15-
var ints = new[] { 1, 2, 3 };
15+
var ints = new[] { 1, 20, 30 };
16+
var nullableInts = new int?[] { 5, 30, null, 20 };
1617

1718
using (var db = OpenDbConnection())
1819
{
19-
db.Select<TestType>(x => ints.Contains(x.Id));
20+
var int10 = new TestType { IntColumn = 10 };
21+
var int20 = new TestType { IntColumn = 20 };
22+
var int30 = new TestType { IntColumn = 30 };
23+
24+
EstablishContext(0, int10, int20, int30);
25+
26+
var results = db.Select<TestType>(x => ints.Contains(x.IntColumn));
27+
var resultsNullable = db.Select<TestType>(x => nullableInts.Contains(x.IntColumn));
28+
29+
CollectionAssert.AreEquivalent(new[] { int20, int30 }, results);
30+
CollectionAssert.AreEquivalent(new[] { int20, int30 }, resultsNullable);
2031

2132
Assert.That(db.GetLastSql(), Is.StringContaining("(@0,@1,@2)").
2233
Or.StringContaining("(:0,:1,:2)"));
@@ -26,11 +37,22 @@ public void Can_select_ints_using_array_contains()
2637
[Test]
2738
public void Can_select_ints_using_list_contains()
2839
{
29-
var ints = new[] { 1, 2, 3 }.ToList();
40+
var ints = new[] { 1, 20, 30 }.ToList();
41+
var nullableInts = new int?[] { 5, 30, null, 20 }.ToList();
3042

3143
using (var db = OpenDbConnection())
3244
{
33-
db.Select<TestType>(x => ints.Contains(x.Id));
45+
var int10 = new TestType { IntColumn = 10 };
46+
var int20 = new TestType { IntColumn = 20 };
47+
var int30 = new TestType { IntColumn = 30 };
48+
49+
EstablishContext(0, int10, int20, int30);
50+
51+
var results = db.Select<TestType>(x => ints.Contains(x.IntColumn));
52+
var resultsNullable = db.Select<TestType>(x => nullableInts.Contains(x.IntColumn));
53+
54+
CollectionAssert.AreEquivalent(new[] { int20, int30 }, results);
55+
CollectionAssert.AreEquivalent(new[] { int20, int30 }, resultsNullable);
3456

3557
Assert.That(db.GetLastSql(), Is.StringContaining("(@0,@1,@2)").
3658
Or.StringContaining("(:0,:1,:2)"));
@@ -44,8 +66,11 @@ public void Can_select_ints_using_empty_array_contains()
4466

4567
using (var db = OpenDbConnection())
4668
{
47-
db.Select<TestType>(x => ints.Contains(x.Id));
48-
69+
EstablishContext(5);
70+
71+
var results = db.Select<TestType>(x => ints.Contains(x.Id));
72+
73+
CollectionAssert.IsEmpty(results);
4974
Assert.That(db.GetLastSql(), Is.StringContaining("(NULL)"));
5075
}
5176
}

tests/ServiceStack.OrmLite.Tests/ExpressionVisitorTests.cs

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Data;
34
using System.Linq;
45
using NUnit.Framework;
@@ -16,10 +17,10 @@ public void Setup()
1617
using (var db = OpenDbConnection())
1718
{
1819
db.DropAndCreateTable<TestType>();
19-
db.Insert(new TestType { Id = 1, BoolCol = true, DateCol = new DateTime(2012, 1, 1), TextCol = "asdf", EnumCol = TestEnum.Val0 });
20-
db.Insert(new TestType { Id = 2, BoolCol = true, DateCol = new DateTime(2012, 2, 1), TextCol = "asdf123", EnumCol = TestEnum.Val1 });
21-
db.Insert(new TestType { Id = 3, BoolCol = false, DateCol = new DateTime(2012, 3, 1), TextCol = "qwer", EnumCol = TestEnum.Val2 });
22-
db.Insert(new TestType { Id = 4, BoolCol = false, DateCol = new DateTime(2012, 4, 1), TextCol = "qwer123", EnumCol = TestEnum.Val3 });
20+
db.Insert(new TestType { Id = 1, BoolCol = true, DateCol = new DateTime(2012, 1, 1), TextCol = "asdf", EnumCol = TestEnum.Val0, NullableIntCol = 10 });
21+
db.Insert(new TestType { Id = 2, BoolCol = true, DateCol = new DateTime(2012, 2, 1), TextCol = "asdf123", EnumCol = TestEnum.Val1, NullableIntCol = null });
22+
db.Insert(new TestType { Id = 3, BoolCol = false, DateCol = new DateTime(2012, 3, 1), TextCol = "qwer", EnumCol = TestEnum.Val2, NullableIntCol = 30 });
23+
db.Insert(new TestType { Id = 4, BoolCol = false, DateCol = new DateTime(2012, 4, 1), TextCol = "qwer123", EnumCol = TestEnum.Val3, NullableIntCol = 40 });
2324
}
2425
Db = OpenDbConnection();
2526
}
@@ -145,6 +146,69 @@ public void Can_Select_using_IN_using_object_array()
145146
Assert.AreEqual(3, target.Count);
146147
}
147148

149+
[Test]
150+
public void Can_Select_using_int_array_Contains()
151+
{
152+
var ids = new[] { 1, 2 };
153+
var q = Db.From<TestType>().Where(x => ids.Contains(x.Id));
154+
var target = Db.Select(q);
155+
CollectionAssert.AreEquivalent(ids, target.Select(t => t.Id).ToArray());
156+
}
157+
158+
[Test]
159+
public void Can_Select_using_int_list_Contains()
160+
{
161+
var ids = new List<int> { 1, 2 };
162+
var q = Db.From<TestType>().Where(x => ids.Contains(x.Id));
163+
var target = Db.Select(q);
164+
CollectionAssert.AreEquivalent(ids, target.Select(t => t.Id).ToArray());
165+
}
166+
167+
[Test]
168+
public void Can_Select_using_int_array_Contains_Value()
169+
{
170+
var ints = new[] { 10, 40 };
171+
var q = Db.From<TestType>().Where(x => ints.Contains(x.NullableIntCol.Value)); // Doesn't compile without ".Value" here - "ints" is not nullable
172+
var target = Db.Select(q);
173+
CollectionAssert.AreEquivalent(new[] { 1, 4 }, target.Select(t => t.Id).ToArray());
174+
}
175+
176+
[Test]
177+
public void Can_Select_using_Nullable_HasValue()
178+
{
179+
var q = Db.From<TestType>().Where(x => x.NullableIntCol.HasValue); // WHERE NullableIntCol IS NOT NULL
180+
var target = Db.Select(q);
181+
CollectionAssert.AreEquivalent(new[] { 1, 3, 4 }, target.Select(t => t.Id).ToArray());
182+
183+
q = Db.From<TestType>().Where(x => !x.NullableIntCol.HasValue); // WHERE NOT (NullableIntCol IS NOT NULL)
184+
target = Db.Select(q);
185+
CollectionAssert.AreEquivalent(new[] { 2 }, target.Select(t => t.Id).ToArray());
186+
}
187+
188+
[Test]
189+
public void Can_Select_using_constant_Yoda_condition()
190+
{
191+
var q = Db.From<TestType>().Where(x => null != x.NullableIntCol); // "null != x.NullableIntCol" should be the same as "x.NullableIntCol != null"
192+
var target = Db.Select(q);
193+
CollectionAssert.AreEquivalent(new[] { 1, 3, 4 }, target.Select(t => t.Id).ToArray());
194+
}
195+
196+
[Test]
197+
public void Can_Select_using_int_array_constructed_inside_Contains()
198+
{
199+
var q = Db.From<TestType>().Where(x => new int?[] { 10, 30 }.Contains(x.NullableIntCol));
200+
var target = Db.Select(q);
201+
CollectionAssert.AreEquivalent(new[] { 1, 3 }, target.Select(t => t.Id).ToArray());
202+
}
203+
204+
[Test]
205+
public void Can_Select_using_int_list_constructed_inside_Contains()
206+
{
207+
var q = Db.From<TestType>().Where(x => new List<int?> { 10, 30 }.Contains(x.NullableIntCol));
208+
var target = Db.Select(q);
209+
CollectionAssert.AreEquivalent(new[] { 1, 3 }, target.Select(t => t.Id).ToArray());
210+
}
211+
148212
[Test]
149213
public void Can_Select_using_Startswith()
150214
{
@@ -240,5 +304,6 @@ public class TestType
240304
public DateTime DateCol { get; set; }
241305
public TestEnum EnumCol { get; set; }
242306
public TestType ComplexObjCol { get; set; }
307+
public int? NullableIntCol { get; set; }
243308
}
244309
}

0 commit comments

Comments
 (0)