@@ -4,12 +4,19 @@ use crate::parser_util::*;
4
4
use crate :: sql_types:: * ;
5
5
use graphql_parser:: query:: * ;
6
6
use serde:: Serialize ;
7
- use std:: collections:: HashMap ;
7
+ use std:: collections:: { HashMap , HashSet } ;
8
8
use std:: hash:: Hash ;
9
9
use std:: ops:: Deref ;
10
10
use std:: str:: FromStr ;
11
11
use std:: sync:: Arc ;
12
12
13
+ #[ derive( Clone , Debug ) ]
14
+ pub struct OnConflictBuilder {
15
+ pub constraint : Index , // Could probably get away with a name ref
16
+ pub update_fields : HashSet < Arc < Column > > , // Could probably get away with a name ref
17
+ pub filter : FilterBuilder ,
18
+ }
19
+
13
20
#[ derive( Clone , Debug ) ]
14
21
pub struct InsertBuilder {
15
22
pub alias : String ,
@@ -22,6 +29,8 @@ pub struct InsertBuilder {
22
29
23
30
//fields
24
31
pub selections : Vec < InsertSelection > ,
32
+
33
+ pub on_conflict : Option < OnConflictBuilder > ,
25
34
}
26
35
27
36
#[ derive( Clone , Debug ) ]
@@ -176,6 +185,90 @@ where
176
185
parse_node_id ( node_id_base64_encoded_json_string)
177
186
}
178
187
188
+ fn read_argument_on_conflict < ' a , T > (
189
+ field : & __Field ,
190
+ query_field : & graphql_parser:: query:: Field < ' a , T > ,
191
+ variables : & serde_json:: Value ,
192
+ variable_definitions : & Vec < VariableDefinition < ' a , T > > ,
193
+ ) -> Result < Option < OnConflictBuilder > , String >
194
+ where
195
+ T : Text < ' a > + Eq + AsRef < str > ,
196
+ {
197
+ let validated: gson:: Value = read_argument (
198
+ "onConflict" ,
199
+ field,
200
+ query_field,
201
+ variables,
202
+ variable_definitions,
203
+ ) ?;
204
+
205
+ let insert_type: InsertOnConflictType = match field. get_arg ( "onConflict" ) {
206
+ None => return Ok ( None ) ,
207
+ Some ( x) => match x. type_ ( ) . unmodified_type ( ) {
208
+ __Type:: InsertOnConflictInput ( insert_on_conflict) => insert_on_conflict,
209
+ _ => return Err ( "Could not locate Insert Entity type" . to_string ( ) ) ,
210
+ } ,
211
+ } ;
212
+
213
+ let filter: FilterBuilder =
214
+ read_argument_filter ( field, query_field, variables, variable_definitions) ?;
215
+
216
+ let on_conflict_builder = match validated {
217
+ gson:: Value :: Absent | gson:: Value :: Null => None ,
218
+ gson:: Value :: Object ( contents) => {
219
+ let constraint = match contents
220
+ . get ( "constraint" )
221
+ . expect ( "OnConflict revalidation error. Expected constraint" )
222
+ {
223
+ gson:: Value :: String ( ix_name) => insert_type
224
+ . table
225
+ . indexes
226
+ . iter ( )
227
+ . find ( |ix| & ix. name == ix_name)
228
+ . expect ( "OnConflict revalidation error. constraint: unknown constraint name" ) ,
229
+ _ => {
230
+ return Err (
231
+ "OnConflict revalidation error. Expected constraint as String" . to_string ( ) ,
232
+ )
233
+ }
234
+ } ;
235
+
236
+ let update_fields = match contents
237
+ . get ( "updateFields" )
238
+ . expect ( "OnConflict revalidation error. Expected updateFields" )
239
+ {
240
+ gson:: Value :: Array ( col_names) => {
241
+ let mut update_columns: HashSet < Arc < Column > > = HashSet :: new ( ) ;
242
+ for col_name in col_names {
243
+ match col_name {
244
+ gson:: Value :: String ( c) => {
245
+ let col = insert_type. table . columns . iter ( ) . find ( |column| & column. name == c) . expect ( "OnConflict revalidation error. updateFields: unknown column name" ) ;
246
+ update_columns. insert ( Arc :: clone ( col) ) ;
247
+ }
248
+ _ => return Err ( "OnConflict revalidation error. Expected updateFields to be column names" . to_string ( ) ) ,
249
+ }
250
+ }
251
+ update_columns
252
+ }
253
+ _ => {
254
+ return Err (
255
+ "OnConflict revalidation error. Expected updateFields to be an array"
256
+ . to_string ( ) ,
257
+ )
258
+ }
259
+ } ;
260
+
261
+ Some ( OnConflictBuilder {
262
+ constraint : constraint. clone ( ) ,
263
+ update_fields,
264
+ filter,
265
+ } )
266
+ }
267
+ _ => return Err ( "Insert re-validation errror" . to_string ( ) ) ,
268
+ } ;
269
+ Ok ( on_conflict_builder)
270
+ }
271
+
179
272
fn read_argument_objects < ' a , T > (
180
273
field : & __Field ,
181
274
query_field : & graphql_parser:: query:: Field < ' a , T > ,
@@ -277,11 +370,14 @@ where
277
370
match & type_ {
278
371
__Type:: InsertResponse ( xtype) => {
279
372
// Raise for disallowed arguments
280
- restrict_allowed_arguments ( & [ "objects" ] , query_field) ?;
373
+ restrict_allowed_arguments ( & [ "objects" , "onConflict" ] , query_field) ?;
281
374
282
375
let objects: Vec < InsertRowBuilder > =
283
376
read_argument_objects ( field, query_field, variables, variable_definitions) ?;
284
377
378
+ let on_conflict: Option < OnConflictBuilder > =
379
+ read_argument_on_conflict ( field, query_field, variables, variable_definitions) ?;
380
+
285
381
let mut builder_fields: Vec < InsertSelection > = vec ! [ ] ;
286
382
287
383
let selection_fields = normalize_selection_set (
@@ -324,6 +420,7 @@ where
324
420
table : Arc :: clone ( & xtype. table ) ,
325
421
objects,
326
422
selections : builder_fields,
423
+ on_conflict,
327
424
} )
328
425
}
329
426
_ => Err ( format ! (
0 commit comments