Skip to content

Commit ddf2826

Browse files
committed
on conflict builder
1 parent b2797e4 commit ddf2826

File tree

1 file changed

+99
-2
lines changed

1 file changed

+99
-2
lines changed

src/builder.rs

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@ use crate::parser_util::*;
44
use crate::sql_types::*;
55
use graphql_parser::query::*;
66
use serde::Serialize;
7-
use std::collections::HashMap;
7+
use std::collections::{HashMap, HashSet};
88
use std::hash::Hash;
99
use std::ops::Deref;
1010
use std::str::FromStr;
1111
use std::sync::Arc;
1212

13+
#[derive(Clone, Debug)]
14+
pub struct OnConflictBuilder {
15+
pub constraint: Index, // Could probably get away with a name ref
16+
pub update_fields: HashSet<Arc<Column>>, // Could probably get away with a name ref
17+
pub filter: FilterBuilder,
18+
}
19+
1320
#[derive(Clone, Debug)]
1421
pub struct InsertBuilder {
1522
pub alias: String,
@@ -22,6 +29,8 @@ pub struct InsertBuilder {
2229

2330
//fields
2431
pub selections: Vec<InsertSelection>,
32+
33+
pub on_conflict: Option<OnConflictBuilder>,
2534
}
2635

2736
#[derive(Clone, Debug)]
@@ -176,6 +185,90 @@ where
176185
parse_node_id(node_id_base64_encoded_json_string)
177186
}
178187

188+
fn read_argument_on_conflict<'a, T>(
189+
field: &__Field,
190+
query_field: &graphql_parser::query::Field<'a, T>,
191+
variables: &serde_json::Value,
192+
variable_definitions: &Vec<VariableDefinition<'a, T>>,
193+
) -> Result<Option<OnConflictBuilder>, String>
194+
where
195+
T: Text<'a> + Eq + AsRef<str>,
196+
{
197+
let validated: gson::Value = read_argument(
198+
"onConflict",
199+
field,
200+
query_field,
201+
variables,
202+
variable_definitions,
203+
)?;
204+
205+
let insert_type: InsertOnConflictType = match field.get_arg("onConflict") {
206+
None => return Ok(None),
207+
Some(x) => match x.type_().unmodified_type() {
208+
__Type::InsertOnConflictInput(insert_on_conflict) => insert_on_conflict,
209+
_ => return Err("Could not locate Insert Entity type".to_string()),
210+
},
211+
};
212+
213+
let filter: FilterBuilder =
214+
read_argument_filter(field, query_field, variables, variable_definitions)?;
215+
216+
let on_conflict_builder = match validated {
217+
gson::Value::Absent | gson::Value::Null => None,
218+
gson::Value::Object(contents) => {
219+
let constraint = match contents
220+
.get("constraint")
221+
.expect("OnConflict revalidation error. Expected constraint")
222+
{
223+
gson::Value::String(ix_name) => insert_type
224+
.table
225+
.indexes
226+
.iter()
227+
.find(|ix| &ix.name == ix_name)
228+
.expect("OnConflict revalidation error. constraint: unknown constraint name"),
229+
_ => {
230+
return Err(
231+
"OnConflict revalidation error. Expected constraint as String".to_string(),
232+
)
233+
}
234+
};
235+
236+
let update_fields = match contents
237+
.get("updateFields")
238+
.expect("OnConflict revalidation error. Expected updateFields")
239+
{
240+
gson::Value::Array(col_names) => {
241+
let mut update_columns: HashSet<Arc<Column>> = HashSet::new();
242+
for col_name in col_names {
243+
match col_name {
244+
gson::Value::String(c) => {
245+
let col = insert_type.table.columns.iter().find(|column| &column.name == c).expect("OnConflict revalidation error. updateFields: unknown column name");
246+
update_columns.insert(Arc::clone(col));
247+
}
248+
_ => return Err("OnConflict revalidation error. Expected updateFields to be column names".to_string()),
249+
}
250+
}
251+
update_columns
252+
}
253+
_ => {
254+
return Err(
255+
"OnConflict revalidation error. Expected updateFields to be an array"
256+
.to_string(),
257+
)
258+
}
259+
};
260+
261+
Some(OnConflictBuilder {
262+
constraint: constraint.clone(),
263+
update_fields,
264+
filter,
265+
})
266+
}
267+
_ => return Err("Insert re-validation errror".to_string()),
268+
};
269+
Ok(on_conflict_builder)
270+
}
271+
179272
fn read_argument_objects<'a, T>(
180273
field: &__Field,
181274
query_field: &graphql_parser::query::Field<'a, T>,
@@ -277,11 +370,14 @@ where
277370
match &type_ {
278371
__Type::InsertResponse(xtype) => {
279372
// Raise for disallowed arguments
280-
restrict_allowed_arguments(&["objects"], query_field)?;
373+
restrict_allowed_arguments(&["objects", "onConflict"], query_field)?;
281374

282375
let objects: Vec<InsertRowBuilder> =
283376
read_argument_objects(field, query_field, variables, variable_definitions)?;
284377

378+
let on_conflict: Option<OnConflictBuilder> =
379+
read_argument_on_conflict(field, query_field, variables, variable_definitions)?;
380+
285381
let mut builder_fields: Vec<InsertSelection> = vec![];
286382

287383
let selection_fields = normalize_selection_set(
@@ -324,6 +420,7 @@ where
324420
table: Arc::clone(&xtype.table),
325421
objects,
326422
selections: builder_fields,
423+
on_conflict,
327424
})
328425
}
329426
_ => Err(format!(

0 commit comments

Comments
 (0)