Skip to content

Commit 97ebf4b

Browse files
authored
Added nullability awareness to projections. (#7541)
1 parent b5e9821 commit 97ebf4b

File tree

10 files changed

+586
-106
lines changed

10 files changed

+586
-106
lines changed

src/GreenDonut/src/Core/GreenDonut.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
</PropertyGroup>
99

1010
<ItemGroup>
11+
<InternalsVisibleTo Include="HotChocolate.Execution" />
1112
<InternalsVisibleTo Include="HotChocolate.Pagination.Batching" />
1213
</ItemGroup>
1314

src/HotChocolate/Core/src/Execution/Extensions/HotChocolateExecutionSelectionExtensions.cs

Lines changed: 125 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
using System.Buffers;
12
using System.Buffers.Text;
3+
using System.Diagnostics.CodeAnalysis;
24
using System.Linq.Expressions;
35
using System.Text;
46
using System.Runtime.CompilerServices;
7+
using GreenDonut.Projections;
58
using HotChocolate.Execution.Projections;
9+
using HotChocolate.Types;
10+
using HotChocolate.Types.Descriptors.Definitions;
11+
using HotChocolate.Utilities;
612

713
// ReSharper disable once CheckNamespace
814
namespace HotChocolate.Execution.Processing;
@@ -26,17 +32,132 @@ public static class HotChocolateExecutionSelectionExtensions
2632
/// <returns>
2733
/// Returns a selector expression that can be used for data projections.
2834
/// </returns>
35+
[Experimental(Experiments.Projections)]
2936
public static Expression<Func<TValue, TValue>> AsSelector<TValue>(
3037
this ISelection selection)
31-
=> GetOrCreateExpression<TValue>(selection);
38+
{
39+
// we first check if we already have an expression for this selection,
40+
// this would be the cheapest way to get the expression.
41+
if(TryGetExpression<TValue>(selection, out var expression))
42+
{
43+
return expression;
44+
}
45+
46+
// if we do not have an expression we need to create one.
47+
// we first check what kind of field selection we have,
48+
// connection, collection or single field.
49+
var flags = ((ObjectField)selection.Field).Flags;
50+
51+
if ((flags & FieldFlags.Connection) == FieldFlags.Connection)
52+
{
53+
var builder = new DefaultSelectorBuilder<TValue>();
54+
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
55+
var count = GetConnectionSelections(selection, buffer);
56+
for (var i = 0; i < count; i++)
57+
{
58+
builder.Add(GetOrCreateExpression<TValue>(buffer[i]));
59+
}
60+
ArrayPool<ISelection>.Shared.Return(buffer);
61+
return GetOrCreateExpression<TValue>(selection, builder);
62+
}
63+
64+
if ((flags & FieldFlags.CollectionSegment) == FieldFlags.CollectionSegment)
65+
{
66+
var builder = new DefaultSelectorBuilder<TValue>();
67+
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
68+
var count = GetCollectionSelections(selection, buffer);
69+
for (var i = 0; i < count; i++)
70+
{
71+
builder.Add(GetOrCreateExpression<TValue>(buffer[i]));
72+
}
73+
ArrayPool<ISelection>.Shared.Return(buffer);
74+
return GetOrCreateExpression<TValue>(selection, builder);
75+
}
76+
77+
return GetOrCreateExpression<TValue>(selection);
78+
}
3279

3380
private static Expression<Func<TValue, TValue>> GetOrCreateExpression<TValue>(
3481
ISelection selection)
35-
{
36-
return selection.DeclaringOperation.GetOrAddState(
82+
=> selection.DeclaringOperation.GetOrAddState(
3783
CreateExpressionKey(selection.Id),
3884
static (_, ctx) => ctx._builder.BuildExpression<TValue>(ctx.selection),
3985
(_builder, selection));
86+
87+
[Experimental(Experiments.Projections)]
88+
private static Expression<Func<TValue, TValue>> GetOrCreateExpression<TValue>(
89+
ISelection selection,
90+
ISelectorBuilder builder)
91+
=> selection.DeclaringOperation.GetOrAddState(
92+
CreateExpressionKey(selection.Id),
93+
static (_, ctx) => ctx.builder.TryCompile<TValue>()!,
94+
(builder, selection));
95+
96+
private static bool TryGetExpression<TValue>(
97+
ISelection selection,
98+
[NotNullWhen(true)] out Expression<Func<TValue, TValue>>? expression)
99+
=> selection.DeclaringOperation.TryGetState(CreateExpressionKey(selection.Id), out expression);
100+
101+
private static int GetConnectionSelections(ISelection selection, Span<ISelection> buffer)
102+
{
103+
var pageType = (ObjectType)selection.Field.Type.NamedType();
104+
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
105+
var count = 0;
106+
107+
foreach (var connectionChild in connectionSelections.Selections)
108+
{
109+
if (connectionChild.Field.Name.EqualsOrdinal("nodes"))
110+
{
111+
if (buffer.Length == count)
112+
{
113+
throw new InvalidOperationException("To many alias selections of nodes and edges.");
114+
}
115+
116+
buffer[count++] = connectionChild;
117+
}
118+
else if (connectionChild.Field.Name.EqualsOrdinal("edges"))
119+
{
120+
var edgeType = (ObjectType)connectionChild.Field.Type.NamedType();
121+
var edgeSelections = connectionChild.DeclaringOperation.GetSelectionSet(connectionChild, edgeType);
122+
123+
foreach (var edgeChild in edgeSelections.Selections)
124+
{
125+
if (edgeChild.Field.Name.EqualsOrdinal("node"))
126+
{
127+
if (buffer.Length == count)
128+
{
129+
throw new InvalidOperationException("To many alias selections of nodes and edges.");
130+
}
131+
132+
buffer[count++] = edgeChild;
133+
}
134+
}
135+
}
136+
}
137+
138+
return count;
139+
}
140+
141+
private static int GetCollectionSelections(ISelection selection, Span<ISelection> buffer)
142+
{
143+
var pageType = (ObjectType)selection.Field.Type.NamedType();
144+
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
145+
var count = 0;
146+
147+
foreach (var connectionChild in connectionSelections.Selections)
148+
{
149+
if (connectionChild.Field.Name.EqualsOrdinal("items"))
150+
{
151+
if (buffer.Length == count)
152+
{
153+
throw new InvalidOperationException("To many alias selections of items.");
154+
}
155+
156+
buffer[count++] = connectionChild;
157+
}
158+
}
159+
160+
return count;
40161
}
41162

42163
private static string CreateExpressionKey(int key)
@@ -62,7 +183,7 @@ private static int EstimateIntLength(int value)
62183
}
63184

64185
// if the number is negative we need one more digit for the sign
65-
var length = (value < 0) ? 1 : 0;
186+
var length = value < 0 ? 1 : 0;
66187

67188
// we add the number of digits the number has to the length of the number.
68189
length += (int)Math.Floor(Math.Log10(Math.Abs(value)) + 1);

src/HotChocolate/Core/src/Execution/Processing/Operation.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,25 @@ public long CreateIncludeFlags(IVariableValueCollection variables)
117117
return context;
118118
}
119119

120+
public bool TryGetState<TState>(out TState? state)
121+
{
122+
var key = typeof(TState).FullName ?? throw new InvalidOperationException();
123+
return TryGetState(key, out state);
124+
}
125+
126+
public bool TryGetState<TState>(string key, out TState? state)
127+
{
128+
if(_contextData.TryGetValue(key, out var value)
129+
&& value is TState casted)
130+
{
131+
state = casted;
132+
return true;
133+
}
134+
135+
state = default;
136+
return false;
137+
}
138+
120139
public TState GetOrAddState<TState>(Func<TState> createState)
121140
=> GetOrAddState<TState, object?>(_ => createState(), null);
122141

Lines changed: 17 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
using System.Buffers;
21
using System.Diagnostics.CodeAnalysis;
3-
using HotChocolate.Execution;
42
using HotChocolate.Execution.Processing;
53
using HotChocolate.Pagination;
6-
using HotChocolate.Types;
7-
using HotChocolate.Types.Descriptors.Definitions;
8-
using HotChocolate.Utilities;
94

105
// ReSharper disable once CheckNamespace
116
namespace GreenDonut.Projections;
@@ -39,6 +34,16 @@ public static ISelectionDataLoader<TKey, TValue> Select<TKey, TValue>(
3934
ISelection selection)
4035
where TKey : notnull
4136
{
37+
if (dataLoader == null)
38+
{
39+
throw new ArgumentNullException(nameof(dataLoader));
40+
}
41+
42+
if (selection == null)
43+
{
44+
throw new ArgumentNullException(nameof(selection));
45+
}
46+
4247
var expression = selection.AsSelector<TValue>();
4348
return dataLoader.Select(expression);
4449
}
@@ -66,98 +71,18 @@ public static IPagingDataLoader<TKey, Page<TValue>> Select<TKey, TValue>(
6671
ISelection selection)
6772
where TKey : notnull
6873
{
69-
var flags = ((ObjectField)selection.Field).Flags;
70-
71-
if ((flags & FieldFlags.Connection) == FieldFlags.Connection)
74+
if (dataLoader == null)
7275
{
73-
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
74-
var count = GetConnectionSelections(selection, buffer);
75-
for (var i = 0; i < count; i++)
76-
{
77-
var expression = buffer[i].AsSelector<TValue>();
78-
dataLoader.Select(expression);
79-
}
80-
ArrayPool<ISelection>.Shared.Return(buffer);
81-
}
82-
else if ((flags & FieldFlags.CollectionSegment) == FieldFlags.CollectionSegment)
83-
{
84-
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
85-
var count = GetCollectionSelections(selection, buffer);
86-
for (var i = 0; i < count; i++)
87-
{
88-
var expression = buffer[i].AsSelector<TValue>();
89-
dataLoader.Select(expression);
90-
}
91-
ArrayPool<ISelection>.Shared.Return(buffer);
92-
}
93-
else
94-
{
95-
var expression = selection.AsSelector<TValue>();
96-
dataLoader.Select(expression);
76+
throw new ArgumentNullException(nameof(dataLoader));
9777
}
9878

99-
return dataLoader;
100-
}
101-
102-
private static int GetConnectionSelections(ISelection selection, Span<ISelection> buffer)
103-
{
104-
var pageType = (ObjectType)selection.Field.Type.NamedType();
105-
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
106-
var count = 0;
107-
108-
foreach (var connectionChild in connectionSelections.Selections)
109-
{
110-
if (connectionChild.Field.Name.EqualsOrdinal("nodes"))
111-
{
112-
if (buffer.Length == count)
113-
{
114-
throw new InvalidOperationException("To many alias selections of nodes and edges.");
115-
}
116-
117-
buffer[count++] = connectionChild;
118-
}
119-
else if (connectionChild.Field.Name.EqualsOrdinal("edges"))
120-
{
121-
var edgeType = (ObjectType)selection.Field.Type.NamedType();
122-
var edgeSelections = selection.DeclaringOperation.GetSelectionSet(connectionChild, edgeType);
123-
124-
foreach (var edgeChild in edgeSelections.Selections)
125-
{
126-
if (edgeChild.Field.Name.EqualsOrdinal("node"))
127-
{
128-
if (buffer.Length == count)
129-
{
130-
throw new InvalidOperationException("To many alias selections of nodes and edges.");
131-
}
132-
133-
buffer[count++] = edgeChild;
134-
}
135-
}
136-
}
137-
}
138-
139-
return count;
140-
}
141-
142-
private static int GetCollectionSelections(ISelection selection, Span<ISelection> buffer)
143-
{
144-
var pageType = (ObjectType)selection.Field.Type.NamedType();
145-
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
146-
var count = 0;
147-
148-
foreach (var connectionChild in connectionSelections.Selections)
79+
if (selection == null)
14980
{
150-
if (connectionChild.Field.Name.EqualsOrdinal("items"))
151-
{
152-
if (buffer.Length == count)
153-
{
154-
throw new InvalidOperationException("To many alias selections of items.");
155-
}
156-
157-
buffer[count++] = connectionChild;
158-
}
81+
throw new ArgumentNullException(nameof(selection));
15982
}
16083

161-
return count;
84+
var expression = selection.AsSelector<TValue>();
85+
dataLoader.Select(expression);
86+
return dataLoader;
16287
}
16388
}

src/HotChocolate/Core/src/Execution/Projections/SelectionExpressionBuilder.cs

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections.Immutable;
22
using System.Linq.Expressions;
33
using System.Reflection;
4+
using System.Runtime.CompilerServices;
45
using HotChocolate.Execution.Processing;
56
using HotChocolate.Features;
67
using HotChocolate.Types;
@@ -158,7 +159,19 @@ private void CollectSelections(
158159

159160
if (node.Nodes.Count == 0)
160161
{
161-
return Expression.Bind(node.Property, propertyAccessor);
162+
if (IsNullableType(node.Property))
163+
{
164+
var nullCheck = Expression.Condition(
165+
Expression.Equal(propertyAccessor, Expression.Constant(null)),
166+
Expression.Constant(null, node.Property.PropertyType),
167+
propertyAccessor);
168+
169+
return Expression.Bind(node.Property, nullCheck);
170+
}
171+
else
172+
{
173+
return Expression.Bind(node.Property, propertyAccessor);
174+
}
162175
}
163176

164177
if(node.IsArrayOrCollection)
@@ -167,8 +180,36 @@ private void CollectSelections(
167180
}
168181

169182
var newContext = context with { Parent = propertyAccessor, ParentType = node.Property.PropertyType };
170-
var requirementsExpression = BuildExpression(node.Nodes, newContext);
171-
return requirementsExpression is null ? null : Expression.Bind(node.Property, requirementsExpression);
183+
var nestedExpression = BuildExpression(node.Nodes, newContext);
184+
185+
if (IsNullableType(node.Property))
186+
{
187+
var nullCheck = Expression.Condition(
188+
Expression.Equal(propertyAccessor, Expression.Constant(null)),
189+
Expression.Constant(null, node.Property.PropertyType),
190+
nestedExpression ?? (Expression)Expression.Constant(null, node.Property.PropertyType));
191+
192+
return Expression.Bind(node.Property, nullCheck);
193+
}
194+
195+
return nestedExpression is null ? null : Expression.Bind(node.Property, nestedExpression);
196+
}
197+
198+
private static bool IsNullableType(PropertyInfo propertyInfo)
199+
{
200+
if (propertyInfo.PropertyType.IsValueType)
201+
{
202+
return Nullable.GetUnderlyingType(propertyInfo.PropertyType) != null;
203+
}
204+
205+
var nullableAttribute = propertyInfo.GetCustomAttribute<NullableAttribute>();
206+
207+
if (nullableAttribute != null)
208+
{
209+
return nullableAttribute.NullableFlags[0] == 2;
210+
}
211+
212+
return false;
172213
}
173214

174215
private MemberInitExpression? BuildExpression(

0 commit comments

Comments
 (0)