diff --git a/connectors/postgresql-connector/pom.xml b/connectors/postgresql-connector/pom.xml index 605a046..84c493b 100644 --- a/connectors/postgresql-connector/pom.xml +++ b/connectors/postgresql-connector/pom.xml @@ -57,9 +57,25 @@ provided - com.datasqrl.flinkrunner + ${project.groupId} + json-type + ${project.version} + + + ${project.groupId} + vector-type + ${project.version} + + + ${project.groupId} flexible-json-format - 1.0.0-SNAPSHOT + ${project.version} + + + ${project.groupId} + system-functions-discovery + ${project.version} + provided org.apache.flink diff --git a/connectors/postgresql-connector/src/main/java/com/datasqrl/connector/postgresql/type/PostgresVectorTypeSerializer.java b/connectors/postgresql-connector/src/main/java/com/datasqrl/connector/postgresql/type/PostgresVectorTypeSerializer.java new file mode 100644 index 0000000..e4cf05c --- /dev/null +++ b/connectors/postgresql-connector/src/main/java/com/datasqrl/connector/postgresql/type/PostgresVectorTypeSerializer.java @@ -0,0 +1,77 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.connector.postgresql.type; + +import com.datasqrl.connector.postgresql.type.JdbcTypeSerializer.GenericDeserializationConverter; +import com.datasqrl.connector.postgresql.type.JdbcTypeSerializer.GenericSerializationConverter; +import com.datasqrl.types.vector.FlinkVectorType; +import com.datasqrl.types.vector.FlinkVectorTypeSerializer; +import java.util.Arrays; +import org.apache.flink.connector.jdbc.converter.AbstractJdbcRowConverter.JdbcDeserializationConverter; +import org.apache.flink.connector.jdbc.converter.AbstractJdbcRowConverter.JdbcSerializationConverter; +import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.types.logical.LogicalType; +import org.postgresql.util.PGobject; + +public class PostgresVectorTypeSerializer + implements JdbcTypeSerializer { + + @Override + public String getDialectId() { + return "postgres"; + } + + @Override + public Class getConversionClass() { + return FlinkVectorType.class; + } + + @Override + public String dialectTypeName() { + return "vector"; + } + + @Override + public GenericDeserializationConverter getDeserializerConverter() { + return () -> + (val) -> { + FlinkVectorType t = (FlinkVectorType) val; + return t.getValue(); + }; + } + + @Override + public GenericSerializationConverter getSerializerConverter( + LogicalType type) { + FlinkVectorTypeSerializer flinkVectorTypeSerializer = new FlinkVectorTypeSerializer(); + return () -> + (val, index, statement) -> { + if (val != null && !val.isNullAt(index)) { + RawValueData object = val.getRawValue(index); + FlinkVectorType vec = object.toObject(flinkVectorTypeSerializer); + + if (vec != null) { + PGobject pgObject = new PGobject(); + pgObject.setType("vector"); + pgObject.setValue(Arrays.toString(vec.getValue())); + statement.setObject(index, pgObject); + return; + } + } + statement.setObject(index, null); + }; + } +} diff --git a/connectors/postgresql-connector/src/main/resources/META-INF/services/com.datasqrl.connector.postgresql.type.JdbcTypeSerializer b/connectors/postgresql-connector/src/main/resources/META-INF/services/com.datasqrl.connector.postgresql.type.JdbcTypeSerializer index 0d0b9e6..fe5d8c8 100644 --- a/connectors/postgresql-connector/src/main/resources/META-INF/services/com.datasqrl.connector.postgresql.type.JdbcTypeSerializer +++ b/connectors/postgresql-connector/src/main/resources/META-INF/services/com.datasqrl.connector.postgresql.type.JdbcTypeSerializer @@ -1,2 +1,3 @@ com.datasqrl.connector.postgresql.type.PostgresRowTypeSerializer -com.datasqrl.connector.postgresql.type.PostgresJsonTypeSerializer \ No newline at end of file +com.datasqrl.connector.postgresql.type.PostgresJsonTypeSerializer +com.datasqrl.connector.postgresql.type.PostgresVectorTypeSerializer \ No newline at end of file diff --git a/connectors/postgresql-connector/src/test/java/com/datasqrl/connector/postgresql/jdbc/FlinkJdbcTest.java b/connectors/postgresql-connector/src/test/java/com/datasqrl/connector/postgresql/jdbc/FlinkJdbcTest.java index 1f87d77..1b901f9 100644 --- a/connectors/postgresql-connector/src/test/java/com/datasqrl/connector/postgresql/jdbc/FlinkJdbcTest.java +++ b/connectors/postgresql-connector/src/test/java/com/datasqrl/connector/postgresql/jdbc/FlinkJdbcTest.java @@ -17,13 +17,21 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import com.datasqrl.types.json.FlinkJsonTypeSerializer; +import com.datasqrl.types.json.FlinkJsonTypeSerializerSnapshot; +import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.Statement; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.ResultKind; +import org.apache.flink.table.api.TableResult; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.utils.EncodingUtils; import org.apache.flink.test.junit5.MiniClusterExtension; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,6 +40,82 @@ @ExtendWith(MiniClusterExtension.class) public class FlinkJdbcTest { + public static void main(String[] args) throws IOException { + var input = + new DataInputDeserializer( + EncodingUtils.decodeBase64ToBytes( + "ADFjb20uZGF0YXNxcmwuanNvbi5GbGlua0pzb25UeXBlU2VyaWFsaXplclNuYXBzaG90AAAAAQApY29tLmRhdGFzcXJsLmpzb24uRmxpbmtKc29uVHlwZVNlcmlhbGl6ZXI=")); + System.out.println(input.readUTF()); + System.out.println(input.readInt()); + System.out.println(input.readUTF()); + + var output = new DataOutputSerializer(53); + output.writeUTF(FlinkJsonTypeSerializerSnapshot.class.getName()); + output.writeInt(1); + output.writeUTF(FlinkJsonTypeSerializer.class.getName()); + System.out.println(EncodingUtils.encodeBytesToBase64(output.getSharedBuffer())); + } + + @Test + public void testFlinkWithPostgres() throws Exception { + // Start PostgreSQL container + try (PostgreSQLContainer postgres = new PostgreSQLContainer<>("postgres:14")) { + postgres.start(); + // Establish a connection and create the PostgreSQL table + try (Connection conn = + DriverManager.getConnection( + postgres.getJdbcUrl(), postgres.getUsername(), postgres.getPassword()); + Statement stmt = conn.createStatement()) { + String createTableSQL = "CREATE TABLE test_table (" + " \"arrayOfRows\" JSONB " + ")"; + stmt.executeUpdate(createTableSQL); + } + + // Set up Flink environment + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env); + + // Define the schema + String createSourceTable = + "CREATE TABLE datagen_source (" + + " arrayOfRows ARRAY> " + + ") WITH (" + + " 'connector' = 'datagen'," + + " 'number-of-rows' = '10'" + + ")"; + + String createSinkTable = + "CREATE TABLE jdbc_sink (" + + " arrayOfRows RAW('com.datasqrl.types.json.FlinkJsonType', 'ADdjb20uZGF0YXNxcmwudHlwZXMuanNvbi5GbGlua0pzb25UeXBlU2VyaWFsaXplclNuYXBzaG90AAAAAQAvY29tLmRhdGFzcXJsLnR5cGVzLmpzb24uRmxpbmtKc29uVHlwZVNlcmlhbGl6ZXI=') " + + ") WITH (" + + " 'connector' = 'jdbc-sqrl', " + + " 'url' = '" + + postgres.getJdbcUrl() + + "', " + + " 'table-name' = 'test_table', " + + " 'username' = '" + + postgres.getUsername() + + "', " + + " 'password' = '" + + postgres.getPassword() + + "'" + + ")"; + + // Register tables in the environment + tableEnv.executeSql( + "CREATE TEMPORARY FUNCTION IF NOT EXISTS `tojson` AS 'com.datasqrl.types.json.functions.ToJson' LANGUAGE JAVA"); + tableEnv.executeSql(createSourceTable); + tableEnv.executeSql(createSinkTable); + + // Set up a simple Flink job + TableResult tableResult = + tableEnv.executeSql( + "INSERT INTO jdbc_sink SELECT tojson(arrayOfRows) AS arrayOfRows FROM datagen_source"); + tableResult.print(); + + assertEquals(ResultKind.SUCCESS_WITH_CONTENT, tableResult.getResultKind()); + } + } + @Test public void testWriteAndReadToPostgres() throws Exception { try (PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:14")) { diff --git a/flink-sql-runner/pom.xml b/flink-sql-runner/pom.xml index 4fcb9af..ab11a2c 100644 --- a/flink-sql-runner/pom.xml +++ b/flink-sql-runner/pom.xml @@ -202,38 +202,38 @@ runtime test - + - ${project.groupId} - flexible-csv-format - ${project.version} - runtime - - - ${project.groupId} - flexible-json-format - ${project.version} - runtime - - - ${project.groupId} - system-functions-discovery - ${project.version} - runtime - - - ${project.groupId} - vector-type - ${project.version} - runtime - - - ${project.groupId} - postgresql-connector - ${project.version} - runtime - + ${project.groupId} + flexible-csv-format + ${project.version} + runtime + + + ${project.groupId} + flexible-json-format + ${project.version} + runtime + + + ${project.groupId} + system-functions-discovery + ${project.version} + runtime + + + ${project.groupId} + vector-type + ${project.version} + runtime + + + ${project.groupId} + postgresql-connector + ${project.version} + runtime + diff --git a/pom.xml b/pom.xml index e805409..8a4bb60 100644 --- a/pom.xml +++ b/pom.xml @@ -450,6 +450,57 @@ ${project.basedir}/m2e-target + + + org.codehaus.mojo + build-helper-maven-plugin + 3.4.0 + + + add-source + + add-source + + generate-sources + + + target/generated-sources/annotations + target/generated-sources/java + + + + + add-google-auto + + add-resource + + generate-sources + + + + target/classes + + **/*.class + + + + + + + add-test-source + + add-test-source + + generate-test-sources + + + target/generated-test-sources/test-annotations + + + + + + diff --git a/testing/system-functions-sample/pom.xml b/testing/system-functions-sample/pom.xml index d0b74c1..24817a5 100644 --- a/testing/system-functions-sample/pom.xml +++ b/testing/system-functions-sample/pom.xml @@ -35,9 +35,9 @@ provided - com.datasqrl.flinkrunner + ${project.groupId} system-functions-discovery - 1.0.0-SNAPSHOT + ${project.version} provided diff --git a/types/json-type/pom.xml b/types/json-type/pom.xml index 82182ff..12a24fe 100644 --- a/types/json-type/pom.xml +++ b/types/json-type/pom.xml @@ -34,5 +34,27 @@ ${flink.version} provided + + org.apache.flink + flink-table-runtime + ${flink.version} + provided + + + ${project.groupId} + system-functions-discovery + ${project.version} + provided + + + com.jayway.jsonpath + json-path + 2.8.0 + + + com.google.auto.service + auto-service + 1.1.1 + diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/ArrayAgg.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/ArrayAgg.java new file mode 100644 index 0000000..336086a --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/ArrayAgg.java @@ -0,0 +1,36 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import java.util.List; +import lombok.Value; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.table.annotation.DataTypeHint; + +@Value +public class ArrayAgg { + + @DataTypeHint(value = "RAW") + private List objects; + + public void add(JsonNode value) { + objects.add(value); + } + + public void remove(JsonNode value) { + objects.remove(value); + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonArray.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonArray.java new file mode 100644 index 0000000..6c1426f --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonArray.java @@ -0,0 +1,64 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import static com.datasqrl.types.json.functions.JsonFunctions.createJsonArgumentTypeStrategy; +import static com.datasqrl.types.json.functions.JsonFunctions.createJsonType; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ArrayNode; +import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.table.types.inference.InputTypeStrategies; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.TypeInference; +import org.apache.flink.table.types.inference.TypeStrategies; +import org.apache.flink.util.jackson.JacksonMapperFactory; + +/** Creates a JSON array from the list of JSON objects and scalar values. */ +@AutoService(AutoRegisterSystemFunction.class) +public class JsonArray extends ScalarFunction implements AutoRegisterSystemFunction { + private static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper(); + + public FlinkJsonType eval(Object... objects) { + ArrayNode arrayNode = mapper.createArrayNode(); + + for (Object value : objects) { + if (value instanceof FlinkJsonType) { + FlinkJsonType type = (FlinkJsonType) value; + arrayNode.add(type.json); + } else { + arrayNode.addPOJO(value); + } + } + + return new FlinkJsonType(arrayNode); + } + + @Override + public TypeInference getTypeInference(DataTypeFactory typeFactory) { + InputTypeStrategy inputTypeStrategy = + InputTypeStrategies.varyingSequence(createJsonArgumentTypeStrategy(typeFactory)); + + return TypeInference.newBuilder() + .inputTypeStrategy(inputTypeStrategy) + .outputTypeStrategy(TypeStrategies.explicit(createJsonType(typeFactory))) + .build(); + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonArrayAgg.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonArrayAgg.java new file mode 100644 index 0000000..1042544 --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonArrayAgg.java @@ -0,0 +1,102 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.types.json.FlinkJsonType; +import java.util.ArrayList; +import lombok.SneakyThrows; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ArrayNode; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.util.jackson.JacksonMapperFactory; + +/** Aggregation function that aggregates JSON objects into a JSON array. */ +public class JsonArrayAgg extends AggregateFunction { + + private static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper(); + + @Override + public ArrayAgg createAccumulator() { + return new ArrayAgg(new ArrayList<>()); + } + + public void accumulate(ArrayAgg accumulator, String value) { + accumulator.add(mapper.getNodeFactory().textNode(value)); + } + + @SneakyThrows + public void accumulate(ArrayAgg accumulator, FlinkJsonType value) { + if (value != null) { + accumulator.add(value.json); + } else { + accumulator.add(null); + } + } + + public void accumulate(ArrayAgg accumulator, Double value) { + accumulator.add(mapper.getNodeFactory().numberNode(value)); + } + + public void accumulate(ArrayAgg accumulator, Long value) { + accumulator.add(mapper.getNodeFactory().numberNode(value)); + } + + public void accumulate(ArrayAgg accumulator, Integer value) { + accumulator.add(mapper.getNodeFactory().numberNode(value)); + } + + public void retract(ArrayAgg accumulator, String value) { + accumulator.remove(mapper.getNodeFactory().textNode(value)); + } + + @SneakyThrows + public void retract(ArrayAgg accumulator, FlinkJsonType value) { + if (value != null) { + accumulator.remove(value.json); + } else { + accumulator.remove(null); + } + } + + public void retract(ArrayAgg accumulator, Double value) { + accumulator.remove(mapper.getNodeFactory().numberNode(value)); + } + + public void retract(ArrayAgg accumulator, Long value) { + accumulator.remove(mapper.getNodeFactory().numberNode(value)); + } + + public void retract(ArrayAgg accumulator, Integer value) { + accumulator.remove(mapper.getNodeFactory().numberNode(value)); + } + + public void merge(ArrayAgg accumulator, java.lang.Iterable iterable) { + iterable.forEach(o -> accumulator.getObjects().addAll(o.getObjects())); + } + + @Override + public FlinkJsonType getValue(ArrayAgg accumulator) { + ArrayNode arrayNode = mapper.createArrayNode(); + for (Object o : accumulator.getObjects()) { + if (o instanceof FlinkJsonType) { + arrayNode.add(((FlinkJsonType) o).json); + } else { + arrayNode.addPOJO(o); + } + } + return new FlinkJsonType(arrayNode); + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonConcat.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonConcat.java new file mode 100644 index 0000000..27e2800 --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonConcat.java @@ -0,0 +1,45 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.table.functions.ScalarFunction; + +/** + * Merges two JSON objects into one. If two objects share the same key, the value from the later + * object is used. + */ +@AutoService(AutoRegisterSystemFunction.class) +public class JsonConcat extends ScalarFunction implements AutoRegisterSystemFunction { + + public FlinkJsonType eval(FlinkJsonType json1, FlinkJsonType json2) { + if (json1 == null || json2 == null) { + return null; + } + try { + ObjectNode node1 = (ObjectNode) json1.getJson(); + ObjectNode node2 = (ObjectNode) json2.getJson(); + + node1.setAll(node2); + return new FlinkJsonType(node1); + } catch (Exception e) { + return null; + } + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonExists.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonExists.java new file mode 100644 index 0000000..7602c30 --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonExists.java @@ -0,0 +1,38 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.table.runtime.functions.SqlJsonUtils; + +/** For a given JSON object, checks whether the provided JSON path exists */ +@AutoService(AutoRegisterSystemFunction.class) +public class JsonExists extends ScalarFunction implements AutoRegisterSystemFunction { + + public Boolean eval(FlinkJsonType json, String path) { + if (json == null) { + return null; + } + try { + return SqlJsonUtils.jsonExists(json.json.toString(), path); + } catch (Exception e) { + return false; + } + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonExtract.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonExtract.java new file mode 100644 index 0000000..98feedc --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonExtract.java @@ -0,0 +1,102 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.ReadContext; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.table.functions.ScalarFunction; + +/** + * Extracts a value from the JSON object based on the provided JSON path. An optional third argument + * can be provided to specify a default value when the given JSON path does not yield a value for + * the JSON object. + */ +@AutoService(AutoRegisterSystemFunction.class) +public class JsonExtract extends ScalarFunction implements AutoRegisterSystemFunction { + + public String eval(FlinkJsonType input, String pathSpec) { + if (input == null) { + return null; + } + try { + JsonNode jsonNode = input.getJson(); + ReadContext ctx = JsonPath.parse(jsonNode.toString()); + Object value = ctx.read(pathSpec); + if (value == null) { + return null; + } + return value.toString(); + } catch (Exception e) { + return null; + } + } + + public String eval(FlinkJsonType input, String pathSpec, String defaultValue) { + if (input == null) { + return null; + } + try { + ReadContext ctx = JsonPath.parse(input.getJson().toString()); + JsonPath parse = JsonPath.compile(pathSpec); + return ctx.read(parse, String.class); + } catch (Exception e) { + return defaultValue; + } + } + + public Boolean eval(FlinkJsonType input, String pathSpec, Boolean defaultValue) { + if (input == null) { + return null; + } + try { + ReadContext ctx = JsonPath.parse(input.getJson().toString()); + JsonPath parse = JsonPath.compile(pathSpec); + return ctx.read(parse, Boolean.class); + } catch (Exception e) { + return defaultValue; + } + } + + public Double eval(FlinkJsonType input, String pathSpec, Double defaultValue) { + if (input == null) { + return null; + } + try { + ReadContext ctx = JsonPath.parse(input.getJson().toString()); + JsonPath parse = JsonPath.compile(pathSpec); + return ctx.read(parse, Double.class); + } catch (Exception e) { + return defaultValue; + } + } + + public Integer eval(FlinkJsonType input, String pathSpec, Integer defaultValue) { + if (input == null) { + return null; + } + try { + ReadContext ctx = JsonPath.parse(input.getJson().toString()); + JsonPath parse = JsonPath.compile(pathSpec); + return ctx.read(parse, Integer.class); + } catch (Exception e) { + return defaultValue; + } + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonFunctions.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonFunctions.java new file mode 100644 index 0000000..09a7a85 --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonFunctions.java @@ -0,0 +1,49 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.types.json.FlinkJsonType; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.ArgumentTypeStrategy; +import org.apache.flink.table.types.inference.InputTypeStrategies; +import org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies; + +public class JsonFunctions { + + public static final ToJson TO_JSON = new ToJson(); + public static final JsonToString JSON_TO_STRING = new JsonToString(); + public static final JsonObject JSON_OBJECT = new JsonObject(); + public static final JsonArray JSON_ARRAY = new JsonArray(); + public static final JsonExtract JSON_EXTRACT = new JsonExtract(); + public static final JsonQuery JSON_QUERY = new JsonQuery(); + public static final JsonExists JSON_EXISTS = new JsonExists(); + public static final JsonArrayAgg JSON_ARRAYAGG = new JsonArrayAgg(); + public static final JsonObjectAgg JSON_OBJECTAGG = new JsonObjectAgg(); + public static final JsonConcat JSON_CONCAT = new JsonConcat(); + + public static ArgumentTypeStrategy createJsonArgumentTypeStrategy(DataTypeFactory typeFactory) { + return InputTypeStrategies.or( + SpecificInputTypeStrategies.JSON_ARGUMENT, + InputTypeStrategies.explicit(createJsonType(typeFactory))); + } + + public static DataType createJsonType(DataTypeFactory typeFactory) { + DataType dataType = DataTypes.of(FlinkJsonType.class).toDataType(typeFactory); + return dataType; + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonObject.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonObject.java new file mode 100644 index 0000000..593b8ef --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonObject.java @@ -0,0 +1,82 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import static com.datasqrl.types.json.functions.JsonFunctions.createJsonArgumentTypeStrategy; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.table.types.inference.InputTypeStrategies; +import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.TypeInference; +import org.apache.flink.table.types.inference.TypeStrategies; +import org.apache.flink.util.jackson.JacksonMapperFactory; + +/** + * Creates a JSON object from key-value pairs, where the key is mapped to a field with the + * associated value. Key-value pairs are provided as a list of even length, with the first element + * of each pair being the key and the second being the value. If multiple key-value pairs have the + * same key, the last pair is added to the JSON object. + */ +@AutoService(AutoRegisterSystemFunction.class) +public class JsonObject extends ScalarFunction implements AutoRegisterSystemFunction { + static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper(); + + public FlinkJsonType eval(Object... objects) { + if (objects.length % 2 != 0) { + throw new IllegalArgumentException("Arguments should be in key-value pairs"); + } + + ObjectNode objectNode = mapper.createObjectNode(); + + for (int i = 0; i < objects.length; i += 2) { + if (!(objects[i] instanceof String)) { + throw new IllegalArgumentException("Key must be a string"); + } + String key = (String) objects[i]; + Object value = objects[i + 1]; + if (value instanceof FlinkJsonType) { + FlinkJsonType type = (FlinkJsonType) value; + objectNode.put(key, type.json); + } else { + objectNode.putPOJO(key, value); + } + } + + return new FlinkJsonType(objectNode); + } + + @Override + public TypeInference getTypeInference(DataTypeFactory typeFactory) { + InputTypeStrategy anyJsonCompatibleArg = + InputTypeStrategies.repeatingSequence(createJsonArgumentTypeStrategy(typeFactory)); + + InputTypeStrategy inputTypeStrategy = + InputTypeStrategies.compositeSequence().finishWithVarying(anyJsonCompatibleArg); + + return TypeInference.newBuilder() + .inputTypeStrategy(inputTypeStrategy) + .outputTypeStrategy( + TypeStrategies.explicit(DataTypes.of(FlinkJsonType.class).toDataType(typeFactory))) + .build(); + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonObjectAgg.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonObjectAgg.java new file mode 100644 index 0000000..b9a817a --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonObjectAgg.java @@ -0,0 +1,112 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.types.json.FlinkJsonType; +import com.datasqrl.types.json.FlinkJsonTypeSerializer; +import java.util.LinkedHashMap; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.table.annotation.DataTypeHint; +import org.apache.flink.table.annotation.FunctionHint; +import org.apache.flink.table.annotation.InputGroup; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.util.jackson.JacksonMapperFactory; + +/** + * Aggregation function that merges JSON objects into a single JSON object. If two JSON objects + * share the same field name, the value of the later one is used in the aggregated result. + */ +@FunctionHint( + output = + @DataTypeHint( + value = "RAW", + bridgedTo = FlinkJsonType.class, + rawSerializer = FlinkJsonTypeSerializer.class)) +public class JsonObjectAgg extends AggregateFunction { + + private static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper(); + + @Override + public ObjectAgg createAccumulator() { + return new ObjectAgg(new LinkedHashMap<>()); + } + + public void accumulate(ObjectAgg accumulator, String key, String value) { + accumulateObject(accumulator, key, value); + } + + public void accumulate( + ObjectAgg accumulator, String key, @DataTypeHint(inputGroup = InputGroup.ANY) Object value) { + if (value instanceof FlinkJsonType) { + accumulateObject(accumulator, key, ((FlinkJsonType) value).getJson()); + } else { + accumulator.add(key, mapper.getNodeFactory().pojoNode(value)); + } + } + + public void accumulate(ObjectAgg accumulator, String key, Double value) { + accumulateObject(accumulator, key, value); + } + + public void accumulate(ObjectAgg accumulator, String key, Long value) { + accumulateObject(accumulator, key, value); + } + + public void accumulate(ObjectAgg accumulator, String key, Integer value) { + accumulateObject(accumulator, key, value); + } + + public void accumulateObject(ObjectAgg accumulator, String key, Object value) { + accumulator.add(key, mapper.getNodeFactory().pojoNode(value)); + } + + public void retract(ObjectAgg accumulator, String key, String value) { + retractObject(accumulator, key); + } + + public void retract( + ObjectAgg accumulator, String key, @DataTypeHint(inputGroup = InputGroup.ANY) Object value) { + retractObject(accumulator, key); + } + + public void retract(ObjectAgg accumulator, String key, Double value) { + retractObject(accumulator, key); + } + + public void retract(ObjectAgg accumulator, String key, Long value) { + retractObject(accumulator, key); + } + + public void retract(ObjectAgg accumulator, String key, Integer value) { + retractObject(accumulator, key); + } + + public void retractObject(ObjectAgg accumulator, String key) { + accumulator.remove(key); + } + + public void merge(ObjectAgg accumulator, java.lang.Iterable iterable) { + iterable.forEach(o -> accumulator.getObjects().putAll(o.getObjects())); + } + + @Override + public FlinkJsonType getValue(ObjectAgg accumulator) { + ObjectNode objectNode = mapper.createObjectNode(); + accumulator.getObjects().forEach(objectNode::putPOJO); + return new FlinkJsonType(objectNode); + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonQuery.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonQuery.java new file mode 100644 index 0000000..442b5ef --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonQuery.java @@ -0,0 +1,49 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.ReadContext; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.util.jackson.JacksonMapperFactory; + +/** + * For a given JSON object, executes a JSON path query against the object and returns the result as + * string. + */ +@AutoService(AutoRegisterSystemFunction.class) +public class JsonQuery extends ScalarFunction implements AutoRegisterSystemFunction { + static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper(); + + public String eval(FlinkJsonType input, String pathSpec) { + if (input == null) { + return null; + } + try { + JsonNode jsonNode = input.getJson(); + ReadContext ctx = JsonPath.parse(jsonNode.toString()); + Object result = ctx.read(pathSpec); + return mapper.writeValueAsString(result); // Convert the result back to JSON string + } catch (Exception e) { + return null; + } + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonToString.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonToString.java new file mode 100644 index 0000000..94ac706 --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/JsonToString.java @@ -0,0 +1,32 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import org.apache.flink.table.functions.ScalarFunction; + +@AutoService(AutoRegisterSystemFunction.class) +public class JsonToString extends ScalarFunction implements AutoRegisterSystemFunction { + + public String eval(FlinkJsonType json) { + if (json == null) { + return null; + } + return json.getJson().toString(); + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/ObjectAgg.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/ObjectAgg.java new file mode 100644 index 0000000..7c4aa5c --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/ObjectAgg.java @@ -0,0 +1,42 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import java.util.Map; +import lombok.Getter; +import lombok.Value; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.table.annotation.DataTypeHint; + +@Value +public class ObjectAgg { + + @DataTypeHint(value = "RAW") + @Getter + Map objects; + + public void add(String key, JsonNode value) { + if (key != null) { + objects.put(key, value); + } + } + + public void remove(String key) { + if (key != null) { + objects.remove(key); + } + } +} diff --git a/types/json-type/src/main/java/com/datasqrl/types/json/functions/ToJson.java b/types/json-type/src/main/java/com/datasqrl/types/json/functions/ToJson.java new file mode 100644 index 0000000..5370f8b --- /dev/null +++ b/types/json-type/src/main/java/com/datasqrl/types/json/functions/ToJson.java @@ -0,0 +1,85 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.json.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.json.FlinkJsonType; +import com.google.auto.service.AutoService; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ArrayNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.table.annotation.DataTypeHint; +import org.apache.flink.table.annotation.InputGroup; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.types.Row; +import org.apache.flink.util.jackson.JacksonMapperFactory; + +/** Parses a JSON object from string */ +@AutoService(AutoRegisterSystemFunction.class) +public class ToJson extends ScalarFunction implements AutoRegisterSystemFunction { + + public static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper(); + + public FlinkJsonType eval(String json) { + if (json == null) { + return null; + } + try { + return new FlinkJsonType(mapper.readTree(json)); + } catch (JsonProcessingException e) { + return null; + } + } + + public FlinkJsonType eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object json) { + if (json == null) { + return null; + } + if (json instanceof FlinkJsonType) { + return (FlinkJsonType) json; + } + + return new FlinkJsonType(unboxFlinkToJsonNode(json)); + } + + JsonNode unboxFlinkToJsonNode(Object json) { + if (json instanceof Row) { + Row row = (Row) json; + ObjectNode objectNode = mapper.createObjectNode(); + String[] fieldNames = + row.getFieldNames(true).toArray(new String[0]); // Get field names in an array + for (String fieldName : fieldNames) { + Object field = row.getField(fieldName); + objectNode.set(fieldName, unboxFlinkToJsonNode(field)); // Recursively unbox each field + } + return objectNode; + } else if (json instanceof Row[]) { + Row[] rows = (Row[]) json; + ArrayNode arrayNode = mapper.createArrayNode(); + for (Row row : rows) { + if (row == null) { + arrayNode.addNull(); + } else { + arrayNode.add(unboxFlinkToJsonNode(row)); // Recursively unbox each row in the array + } + } + return arrayNode; + } + return mapper.valueToTree(json); // Directly serialize other types + } +} diff --git a/types/json-type/src/test/java/com/datasqrl/json/JsonConversionTest.java b/types/json-type/src/test/java/com/datasqrl/json/JsonConversionTest.java new file mode 100644 index 0000000..24de3cc --- /dev/null +++ b/types/json-type/src/test/java/com/datasqrl/json/JsonConversionTest.java @@ -0,0 +1,490 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// package com.datasqrl.json; +// +// import static com.datasqrl.function.SqrlFunction.getFunctionNameFromClass; +// import static com.datasqrl.plan.local.analyze.RetailSqrlModule.createTableSource; +// import static org.junit.jupiter.api.Assertions.assertEquals; +// +// import com.datasqrl.calcite.Dialect; +// import com.datasqrl.calcite.function.SqrlTableMacro; +// import com.datasqrl.calcite.type.TypeFactory; +// import com.datasqrl.canonicalizer.Name; +// import com.datasqrl.canonicalizer.NameCanonicalizer; +// import com.datasqrl.canonicalizer.NamePath; +// import com.datasqrl.config.SourceFactory; +// import com.datasqrl.engine.database.relational.ddl.PostgresDDLFactory; +// import com.datasqrl.engine.database.relational.ddl.statements.CreateTableDDL; +// import com.datasqrl.error.ErrorCollector; +// import com.datasqrl.function.SqrlFunction; +// import com.datasqrl.functions.json.StdJsonLibraryImpl; +// import com.datasqrl.graphql.AbstractGraphqlTest; +// import com.datasqrl.io.DataSystemConnectorFactory; +// import com.datasqrl.io.InMemSourceFactory; +// import com.datasqrl.io.mem.MemoryConnectorFactory; +// import com.datasqrl.types.json.FlinkJsonType; +// import com.datasqrl.loaders.TableSourceNamespaceObject; +// import com.datasqrl.module.NamespaceObject; +// import com.datasqrl.module.SqrlModule; +// import com.datasqrl.plan.global.PhysicalDAGPlan.EngineSink; +// import com.datasqrl.plan.local.analyze.MockModuleLoader; +// import com.datasqrl.plan.table.CalciteTableFactory; +// import com.datasqrl.plan.table.TableConverter; +// import com.datasqrl.plan.table.TableIdFactory; +// import com.datasqrl.plan.validate.ScriptPlanner; +// import com.datasqrl.util.SnapshotTest; +// import com.google.auto.service.AutoService; +// import com.ibm.icu.impl.Pair; +// import java.io.IOException; +// import java.sql.Connection; +// import java.sql.ResultSet; +// import java.sql.Statement; +// import java.util.*; +// +// import lombok.SneakyThrows; +// import lombok.Value; +// import lombok.extern.slf4j.Slf4j; +// import org.apache.calcite.rel.RelNode; +// import org.apache.calcite.sql.ScriptNode; +// import org.apache.calcite.sql.SqrlStatement; +// import org.apache.flink.api.common.typeinfo.Types; +// import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +// import org.apache.flink.table.api.Table; +// import org.apache.flink.table.api.TableResult; +// import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +// import org.apache.flink.table.functions.FunctionDefinition; +// import org.apache.flink.table.functions.UserDefinedFunction; +// import org.apache.flink.test.junit5.MiniClusterExtension; +// import org.apache.flink.types.Row; +// import org.junit.jupiter.api.AfterEach; +// import org.junit.jupiter.api.Assertions; +// import org.junit.jupiter.api.BeforeAll; +// import org.junit.jupiter.api.BeforeEach; +// import org.junit.jupiter.api.Test; +// import org.junit.jupiter.api.TestInfo; +// import org.junit.jupiter.api.extension.ExtendWith; +// import org.postgresql.util.PGobject; +// import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper; +// +/// ** +// * A test suite to convert SQRL queries to their respective dialects +// */ +// @Slf4j +// @ExtendWith(MiniClusterExtension.class) +// public class JsonConversionTest extends AbstractGraphqlTest { +// +// protected SnapshotTest.Snapshot snapshot; +// ObjectMapper objectMapper = new ObjectMapper(); +// private ScriptPlanner planner; +// +// @BeforeAll +// public static void setupAll() { +// createPostgresTable(); +// insertDataIntoPostgresTable(); +// } +// +// @SneakyThrows +// private static void createPostgresTable() { +// try (Connection conn = AbstractGraphqlTest.getPostgresConnection(); Statement stmt = +// conn.createStatement()) { +// String createTableSQL = +// "CREATE TABLE IF NOT EXISTS jsondata$2 (" + "id INT, " + "json jsonb);"; +// stmt.execute(createTableSQL); +// } +// } +// +// @SneakyThrows +// private static void insertDataIntoPostgresTable() { +// try (Connection conn = AbstractGraphqlTest.getPostgresConnection(); Statement stmt = +// conn.createStatement()) { +// String insertSQL = "INSERT INTO jsondata$2 (id, json) VALUES " +// + "(1, '{\"example\":[1,2,3]}'),(2, '{\"example\":[4,5,6]}');"; +// stmt.execute(insertSQL); +// } +// } +//// +//// @AfterAll +//// public static void tearDownAll() { +////// testDatabase.stop(); +//// } +// +// @BeforeEach +// public void setup(TestInfo testInfo) throws IOException { +// initialize(IntegrationTestSettings.getInMemory(), null, Optional.empty(), +// ErrorCollector.root(), +// createJson(), false); +// +// this.snapshot = SnapshotTest.Snapshot.of(getClass(), testInfo); +// +// this.planner = injector.getInstance(ScriptPlanner.class); +// runStatement("IMPORT json-data.jsondata TIMESTAMP _ingest_time"); +// } +// +// private void runStatement(String statement) { +// planner.validateStatement(parse(statement)); +// } +// +// private SqrlStatement parse(String statement) { +// return (SqrlStatement) ((ScriptNode)framework.getQueryPlanner().parse(Dialect.SQRL, +// statement)) +// .getStatements().get(0); +// } +// +// public Map createJson() { +// CalciteTableFactory tableFactory = new CalciteTableFactory(new TableIdFactory(new +// HashMap<>()), +// new TableConverter(new TypeFactory(), framework)); +// SqrlModule module = new SqrlModule() { +// +// private final Map tables = new HashMap(); +// +// @Override +// public Optional getNamespaceObject(Name name) { +// NamespaceObject obj = new TableSourceNamespaceObject( +// RetailSqrlModule.createTableSource(JsonData.class, "data", "json-data"), +// tableFactory); +// return Optional.of(obj); +// } +// +// @Override +// public List getNamespaceObjects() { +// return new ArrayList<>(tables.values()); +// } +// }; +// +// return Map.of(NamePath.of("json-data"), module); +// } +// +// @SneakyThrows +// private Object executePostgresQuery(String query) { +// try (Connection conn = AbstractGraphqlTest.getPostgresConnection(); Statement stmt = +// conn.createStatement()) { +// System.out.println(query); +// ResultSet rs = stmt.executeQuery(query); +// // Assuming the result is a single value for simplicity +// return rs.next() ? rs.getObject(1) : null; +// } +// } +// +// @SneakyThrows +// public Object jsonFunctionTest(String query) { +// StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); +// env.setParallelism(1); +// +// // Assuming you have a method to create or get Flink SQL environment +// StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env); +// +// List inputRows = Arrays.asList(Row.of(1, "{\"example\":[1,2,3]}"), +// Row.of(2, "{\"example\":[4,5,6]}")); +// +// // Create a Table from the list of rows +// Table inputTable = tableEnv.fromDataStream(env.fromCollection(inputRows, +// Types.ROW_NAMED(new String[]{"id", "json"}, Types.INT, Types.STRING))); +// +// for (FunctionDefinition sqrlFunction : StdJsonLibraryImpl.json) { +// UserDefinedFunction userDefinedFunction = (UserDefinedFunction) sqrlFunction; +// tableEnv.createFunction(getFunctionNameFromClass(sqrlFunction.getClass()), +// userDefinedFunction.getClass()); +// } +// +// // Register the Table under a name +// tableEnv.createTemporaryView("jsondata$2", inputTable); +// +// // Run your query +// Table result = tableEnv.sqlQuery(query); +// TableResult execute = result.execute(); +// List rows = new ArrayList<>(); +// execute.collect().forEachRemaining(rows::add); +// +// return rows.get(rows.size() - 1).getField(0); +// } +// +// @AfterEach +// public void tearDown() { +// snapshot.createOrValidate(); +// } +// +// @Test +// public void jsonArrayTest() { +// testJsonReturn("jsonArray('a', null, 'b', 123)"); +// } +// +// @Test +// public void jsonArrayAgg() { +// testJsonReturn("jsonArrayAgg(jsonExtract(toJson(json), '$.example[0]', 0))"); +// } +// +// @Test +// public void jsonObjectAgg() { +// testJsonReturn("jsonObjectAgg('key', toJson(json))"); +// } +// +// @Test +// public void jsonArrayAgg2() { +// testJsonReturn("jsonArrayAgg(id)"); +// } +// +// @Test +// public void jsonArrayAggNull() { +// testJsonReturn("jsonArrayAgg(toJson(null))"); +// } +// +// @Test +// public void jsonArrayArray() { +// testJsonReturn("JSONARRAY(JSONARRAY(1))"); +// } +// +// @Test +// public void jsonExistsTest() { +// testScalarReturn("jsonExists(toJson('{\"a\": true}'), '$.a')"); +// } +// +// @Test +// public void jsonExtractTest() { +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.a', 'default')"); +// } +// +// @Test +// public void jsonConcat() { +// testJsonReturn("jsonConcat(toJson('{\"a\": \"hello\"}'), toJson('{\"b\": \"hello\"}'))"); +// } +// +// @Test +// public void jsonObjectTest() { +// testJsonReturn("jsonObject('key1', 'value1', 'key2', 123)"); +// } +// +// @Test +// public void jsonQueryTest() { +// testJsonReturn("jsonQuery(toJson('{\"a\": {\"b\": 1}}'), '$.a')"); +// } +// +// @Test +// public void jsonArrayWithNulls() { +// // Testing JSON array creation with null values +// testJsonReturn("jsonArray('a', null, 'b', null)"); +// } +// +// @Test +// public void jsonObjectWithNulls() { +// // Testing JSON object creation with null values +// testJsonReturn("jsonObject('key1', null, 'key2', 'value2')"); +// } +// +// @Test +// public void jsonExtract() { +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.a')"); +// } +// +// @Test +// public void jsonExtractWithDefaultString() { +// // Test with a default string value +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.b', 'defaultString')"); +// } +// +// @Test +// public void jsonExtractWithDefaultInteger() { +// // Test with a default integer value +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.b', 123)"); +// } +// +// @Test +// public void jsonExtractWithDefaultBoolean() { +// // Test with a default boolean value +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.b', true)"); +// } +// +// @Test +// public void jsonExtractWithDefaultBoolean2() { +// // Test with a default boolean value +// testScalarReturn("jsonExtract(toJson('{\"a\": false}'), '$.a', true)"); +// } +// +// @Test +// public void jsonExtractWithDefaultDouble3() { +// // Test with a default boolean value +// testScalarReturn("jsonExtract(toJson('{\"a\": 0.2}'), '$.a', 0.0)"); +// } +// +// @Test +// public void jsonExtractWithDefaultDouble4() { +// // Test with a default boolean value +// testScalarReturn("jsonExtract(toJson('{\"a\": 0.2}'), '$.a', 0)"); +// } +// +// @Test +// public void jsonExtractWithDefaultNull() { +// // Test with a default null value +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.b', null)"); +// } +// +// @Test +// public void jsonExtractWithNonexistentPath() { +// // Test extraction from a nonexistent path (should return default value) +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.nonexistent', 'default')"); +// } +// +// @Test +// public void jsonExtractWithEmptyJson() { +// // Test extraction from an empty JSON object +// testScalarReturn("jsonExtract(toJson('{}'), '$.a', 'default')"); +// } +// +// @Test +// public void jsonExtractWithComplexJsonPath() { +// // Test extraction with a complex JSON path +// testScalarReturn( +// "jsonExtract(toJson('{\"a\": {\"b\": {\"c\": \"value\"}}}'), '$.a.b.c', 'default')"); +// } +// +// @Test +// public void jsonExtractWithArrayPath() { +// // Test extraction where the path leads to an array +// testScalarReturn("jsonExtract(toJson('{\"a\": [1, 2, 3]}'), '$.a[1]', 'default')"); +// } +// +// @Test +// public void jsonExtractWithNumericDefault() { +// // Test extraction with a numeric default value +// testScalarReturn("jsonExtract(toJson('{\"a\": \"hello\"}'), '$.b', 0)"); +// } +// +// @Test +// public void jsonObject() { +// // Test extraction with a numeric default value +// testJsonReturn("jsonObject('key', toJson('{\"a\": \"hello\"}'), 'key2', 0)"); +// } +// +// @Test +// public void jsonArrayWithMixedDataTypes() { +// // Testing JSON array creation with mixed data types +// testJsonReturn("jsonArray('a', 1, true, null, 3.14)"); +// } +// +// @Test +// public void jsonArrayWithNestedArrays() { +// // Testing JSON array creation with nested arrays +// testJsonReturn("jsonArray('a', jsonArray('nested', 1), 'b', jsonArray('nested', 2))"); +// } +// +// @Test +// public void jsonArrayWithEmptyValues() { +// // Testing JSON array creation with empty values +// testJsonReturn("jsonArray('', '', '', '')"); +// } +// +// @Test +// public void jsonObjectWithMixedDataTypes() { +// // Testing JSON object creation with mixed data types +// testJsonReturn("jsonObject('string', 'text', 'number', 123, 'boolean', true)"); +// } +// +// @Test +// public void jsonObjectWithNestedObjects() { +// // Testing JSON object creation with nested objects +// testJsonReturn("jsonObject('key1', jsonObject('nestedKey', 'nestedValue'), 'key2', +// 'value2')"); +// } +// +// @Test +// public void jsonObjectWithEmptyKeys() { +// // Testing JSON object creation with empty keys +// testJsonReturn("jsonObject('', 'value1', '', 'value2')"); +// } +// +// @SneakyThrows +// private void testJsonReturn(String function) { +// Pair x = executeScript(function); +// Assertions.assertEquals(objectMapper.readTree((String) x.first), +// objectMapper.readTree((String) x.second)); +// } +// +// @SneakyThrows +// private void testScalarReturn(String function) { +// Pair x = executeScript(function); +// assertEquals(x.first.toString().trim(), x.second.toString().trim()); +// } +// +// public Pair executeScript(String fncName) { +// runStatement("IMPORT json.*"); +// runStatement("X(@a: Int) := SELECT " + fncName + " AS json FROM jsondata"); +// return convert("X"); +// } +// +// @SneakyThrows +// private Pair convert(String fncName) { +// SqrlTableMacro x = framework.getQueryPlanner().getSchema().getTableFunction(fncName); +// RelNode relNode = x.getViewTransform().get(); +// +// RelNode pgRelNode = framework.getQueryPlanner().convertRelToDialect(Dialect.POSTGRES, +// relNode); +// String pgQuery = framework.getQueryPlanner().relToString(Dialect.POSTGRES, +// pgRelNode).getSql(); +// snapshot.addContent(pgQuery, "postgres"); +// +// // Execute Postgres query +// Object pgResult = executePostgresQuery(pgQuery); +// //Unbox result +// pgResult = pgResult instanceof PGobject ? ((PGobject) pgResult).getValue() : pgResult; +// pgResult = pgResult == null ? "" : pgResult.toString(); +// +// CreateTableDDL pg = new PostgresDDLFactory().createTable( +// new EngineSink("pg", new int[]{0}, relNode.getRowType(), OptionalInt.of(0), null)); +// +// snapshot.addContent((String) pgResult, "Postgres Result"); +// +// RelNode flinkRelNode = framework.getQueryPlanner().convertRelToDialect(Dialect.FLINK, +// relNode); +// String query = framework.getQueryPlanner().relToString(Dialect.FLINK, flinkRelNode).getSql(); +// snapshot.addContent(query, "flink"); +// +// Object flinkResult = jsonFunctionTest(query); +// if (flinkResult instanceof FlinkJsonType) { +// flinkResult = ((FlinkJsonType) flinkResult).getJson(); +// } +// flinkResult = flinkResult == null ? "" : flinkResult.toString(); +// snapshot.addContent((String) flinkResult, "Flink Result"); +// return Pair.of(pgResult, flinkResult); +// } +// +// //todo: Hacky way to get different in-mem sources to load +// @AutoService(SourceFactory.class) +// public static class InMemJson extends InMemSourceFactory { +// +// static Map> tableData = Map.of("data", +// List.of(new JsonData(1, "{\"example\":[1,2,3]}"), +// new JsonData(2, "{\"example\":[4,5,6]}"))); +// +// public InMemJson() { +// super("data", tableData); +// } +// } +// +// @Value +// public static class JsonData { +// +// int id; +// String json; +// } +// +// @AutoService(DataSystemConnectorFactory.class) +// public static class InMemJsonConnector extends MemoryConnectorFactory { +// +// public InMemJsonConnector() { +// super("data"); +// } +// } +// } diff --git a/types/json-type/src/test/java/com/datasqrl/json/JsonFunctionsTest.java b/types/json-type/src/test/java/com/datasqrl/json/JsonFunctionsTest.java new file mode 100644 index 0000000..66e7692 --- /dev/null +++ b/types/json-type/src/test/java/com/datasqrl/json/JsonFunctionsTest.java @@ -0,0 +1,522 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.json; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.datasqrl.types.json.FlinkJsonType; +import com.datasqrl.types.json.functions.ArrayAgg; +import com.datasqrl.types.json.functions.JsonFunctions; +import com.datasqrl.types.json.functions.ObjectAgg; +import lombok.SneakyThrows; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.types.Row; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +class JsonFunctionsTest { + ObjectMapper mapper = new ObjectMapper(); + + @SneakyThrows + JsonNode readTree(String val) { + return mapper.readTree(val); + } + + @Nested + class ToJsonTest { + + @Test + void testUnicodeJson() { + Row row = Row.withNames(); + row.setField("key", "”value”"); + Row[] rows = new Row[] {row}; + FlinkJsonType result = JsonFunctions.TO_JSON.eval(rows); + assertNotNull(result); + assertEquals("[{\"key\":\"”value”\"}]", result.getJson().toString()); + } + + @Test + void testValidJson() { + String json = "{\"key\":\"value\"}"; + FlinkJsonType result = JsonFunctions.TO_JSON.eval(json); + assertNotNull(result); + assertEquals(json, result.getJson().toString()); + } + + @Test + void testInvalidJson() { + String json = "Not a JSON"; + FlinkJsonType result = JsonFunctions.TO_JSON.eval(json); + assertNull(result); + } + + @Test + void testNullInput() { + assertNull(JsonFunctions.TO_JSON.eval(null)); + } + } + + @Nested + class JsonToStringTest { + + @Test + void testNonNullJson() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + String result = JsonFunctions.JSON_TO_STRING.eval(json); + assertEquals("{\"key\":\"value\"}", result); + } + + @Test + void testNullJson() { + String result = JsonFunctions.JSON_TO_STRING.eval(null); + assertNull(result); + } + } + + @Nested + class JsonObjectTest { + + @Test + void testValidKeyValuePairs() { + FlinkJsonType result = JsonFunctions.JSON_OBJECT.eval("key1", "value1", "key2", "value2"); + assertNotNull(result); + assertEquals("{\"key1\":\"value1\",\"key2\":\"value2\"}", result.getJson().toString()); + } + + @Test + void testInvalidNumberOfArguments() { + assertThrows( + IllegalArgumentException.class, + () -> JsonFunctions.JSON_OBJECT.eval("key1", "value1", "key2")); + } + + @Test + void testNullKeyOrValue() { + FlinkJsonType resultWithNullValue = JsonFunctions.JSON_OBJECT.eval("key1", null); + assertNotNull(resultWithNullValue); + assertEquals("{\"key1\":null}", resultWithNullValue.getJson().toString()); + } + } + + @Nested + class JsonArrayTest { + + @Test + void testArrayWithJsonObjects() { + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key1\": \"value1\"}")); + FlinkJsonType json2 = new FlinkJsonType(readTree("{\"key2\": \"value2\"}")); + FlinkJsonType result = JsonFunctions.JSON_ARRAY.eval(json1, json2); + assertNotNull(result); + assertEquals("[{\"key1\":\"value1\"},{\"key2\":\"value2\"}]", result.getJson().toString()); + } + + @Test + void testArrayWithMixedTypes() { + FlinkJsonType result = JsonFunctions.JSON_ARRAY.eval("stringValue", 123, true); + assertNotNull(result); + assertEquals("[\"stringValue\",123,true]", result.getJson().toString()); + } + + @Test + void testArrayWithNullValues() { + FlinkJsonType result = JsonFunctions.JSON_ARRAY.eval((Object) null); + assertNotNull(result); + assertEquals("[null]", result.getJson().toString()); + } + } + + @Nested + class JsonExtractTest { + + @Test + void testValidPath() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + String result = JsonFunctions.JSON_EXTRACT.eval(json, "$.key"); + assertEquals("value", result); + } + + @Test + void testValidPathBoolean() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": true}")); + String result = JsonFunctions.JSON_EXTRACT.eval(json, "$.key"); + assertEquals("true", result); + } + + // Testing eval method with a default value for String + @Test + void testStringPathWithDefaultValue() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + String defaultValue = "default"; + String result = JsonFunctions.JSON_EXTRACT.eval(json, "$.nonexistentKey", defaultValue); + assertEquals(defaultValue, result); + } + + // Testing eval method with a default value for boolean + @Test + void testBooleanPathNormalWithDefaultValue() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": true}")); + boolean defaultValue = false; + boolean result = JsonFunctions.JSON_EXTRACT.eval(json, "$.key", defaultValue); + assertTrue(result); + } + + @Test + void testBooleanPathWithDefaultValue() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": true}")); + boolean defaultValue = false; + boolean result = JsonFunctions.JSON_EXTRACT.eval(json, "$.nonexistentKey", defaultValue); + assertFalse(result); + } + + // Testing eval method with a default value for boolean:false + @Test + void testBooleanPathWithDefaultValueTrue() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": true}")); + boolean defaultValue = true; + boolean result = JsonFunctions.JSON_EXTRACT.eval(json, "$.nonexistentKey", defaultValue); + assertTrue(result); + } + + // Testing eval method with a default value for Double + @Test + void testDoublePathWithDefaultValue() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": 1.23}")); + Double defaultValue = 4.56; + Double result = JsonFunctions.JSON_EXTRACT.eval(json, "$.key", defaultValue); + assertEquals(1.23, result); + } + + // Testing eval method with a default value for Integer + @Test + void testIntegerPathWithDefaultValue() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": 123}")); + Integer defaultValue = 456; + Integer result = JsonFunctions.JSON_EXTRACT.eval(json, "$.key", defaultValue); + assertEquals(123, result); + } + + @Test + void testInvalidPath() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + String result = JsonFunctions.JSON_EXTRACT.eval(json, "$.nonexistentKey"); + assertNull(result); + } + } + + @Nested + class JsonQueryTest { + + @Test + void testValidQuery() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + String result = JsonFunctions.JSON_QUERY.eval(json, "$.key"); + assertEquals("\"value\"", result); // Note the JSON representation of a string value + } + + // Test for a more complex JSON path query + @Test + void testComplexQuery() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key1\": {\"key2\": \"value\"}}")); + String result = JsonFunctions.JSON_QUERY.eval(json, "$.key1.key2"); + assertEquals("\"value\"", result); // JSON representation of the result + } + + // Test for an invalid query + @Test + void testInvalidQuery() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + String result = JsonFunctions.JSON_QUERY.eval(json, "$.invalidKey"); + assertNull(result); + } + } + + @Nested + class JsonExistsTest { + + @Test + void testPathExists() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + Boolean result = JsonFunctions.JSON_EXISTS.eval(json, "$.key"); + assertTrue(result); + } + + // Test for a path that exists + @Test + void testPathExistsComplex() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key1\": {\"key2\": \"value\"}}")); + Boolean result = JsonFunctions.JSON_EXISTS.eval(json, "$.key1.key2"); + assertTrue(result); + } + + @Test + void testPathDoesNotExistComplex() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key1\": {\"key2\": \"value\"}}")); + Boolean result = JsonFunctions.JSON_EXISTS.eval(json, "$.key1.nonexistentKey"); + assertFalse(result); + } + + @Test + void testPathDoesNotExist() { + FlinkJsonType json = new FlinkJsonType(readTree("{\"key\": \"value\"}")); + Boolean result = JsonFunctions.JSON_EXISTS.eval(json, "$.nonexistentKey"); + assertFalse(result); + } + + @Test + void testNullInput() { + Boolean result = JsonFunctions.JSON_EXISTS.eval(null, "$.key"); + assertNull(result); + } + } + + @Nested + class JsonConcatTest { + + @Test + void testSimpleMerge() { + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key1\": \"value1\"}")); + FlinkJsonType json2 = new FlinkJsonType(readTree("{\"key2\": \"value2\"}")); + FlinkJsonType result = JsonFunctions.JSON_CONCAT.eval(json1, json2); + assertEquals("{\"key1\":\"value1\",\"key2\":\"value2\"}", result.getJson().toString()); + } + + @Test + void testOverlappingKeys() { + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key\": \"value1\"}")); + FlinkJsonType json2 = new FlinkJsonType(readTree("{\"key\": \"value2\"}")); + FlinkJsonType result = JsonFunctions.JSON_CONCAT.eval(json1, json2); + assertEquals("{\"key\":\"value2\"}", result.getJson().toString()); + } + + @Test + void testNullInput() { + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key1\": \"value1\"}")); + FlinkJsonType result = JsonFunctions.JSON_CONCAT.eval(json1, null); + assertNull(result); + } + + @Test + void testNullInput2() { + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key1\": \"value1\"}")); + FlinkJsonType result = JsonFunctions.JSON_CONCAT.eval(null, json1); + assertNull(result); + } + } + + @Nested + class JsonArrayAggTest { + + @Test + void testAggregateJsonTypes() { + ArrayAgg accumulator = JsonFunctions.JSON_ARRAYAGG.createAccumulator(); + JsonFunctions.JSON_ARRAYAGG.accumulate( + accumulator, new FlinkJsonType(readTree("{\"key1\": \"value1\"}"))); + JsonFunctions.JSON_ARRAYAGG.accumulate( + accumulator, new FlinkJsonType(readTree("{\"key2\": \"value2\"}"))); + + FlinkJsonType result = JsonFunctions.JSON_ARRAYAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("[{\"key1\":\"value1\"},{\"key2\":\"value2\"}]", result.getJson().toString()); + } + + @Test + void testAggregateMixedTypes() { + ArrayAgg accumulator = JsonFunctions.JSON_ARRAYAGG.createAccumulator(); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, "stringValue"); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, 123); + + FlinkJsonType result = JsonFunctions.JSON_ARRAYAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("[\"stringValue\",123]", result.getJson().toString()); + } + + @Test + void testAccumulateNullValues() { + ArrayAgg accumulator = JsonFunctions.JSON_ARRAYAGG.createAccumulator(); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, (FlinkJsonType) null); + FlinkJsonType result = JsonFunctions.JSON_ARRAYAGG.getValue(accumulator); + assertEquals("[null]", result.getJson().toString()); + } + + @Test + void testArrayWithNullElements() { + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key1\": \"value1\"}")); + FlinkJsonType json2 = null; // null JSON object + FlinkJsonType result = JsonFunctions.JSON_ARRAY.eval(json1, json2); + assertNotNull(result); + // Depending on implementation, the result might include the null or ignore it + assertEquals("[{\"key1\":\"value1\"},null]", result.getJson().toString()); + } + + @Test + void testRetractJsonTypes() { + ArrayAgg accumulator = JsonFunctions.JSON_ARRAYAGG.createAccumulator(); + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key\": \"value1\"}")); + FlinkJsonType json2 = new FlinkJsonType(readTree("{\"key\": \"value2\"}")); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, json1); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, json2); + + // Now retract one of the JSON objects + JsonFunctions.JSON_ARRAYAGG.retract(accumulator, json1); + + FlinkJsonType result = JsonFunctions.JSON_ARRAYAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("[{\"key\":\"value2\"}]", result.getJson().toString()); + } + + @Test + void testRetractNullJsonType() { + ArrayAgg accumulator = JsonFunctions.JSON_ARRAYAGG.createAccumulator(); + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key\": \"value1\"}")); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, json1); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, (FlinkJsonType) null); + + // Now retract a null JSON object + JsonFunctions.JSON_ARRAYAGG.retract(accumulator, (FlinkJsonType) null); + + FlinkJsonType result = JsonFunctions.JSON_ARRAYAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("[{\"key\":\"value1\"}]", result.getJson().toString()); + } + + @Test + void testRetractNullFromNonExisting() { + ArrayAgg accumulator = JsonFunctions.JSON_ARRAYAGG.createAccumulator(); + FlinkJsonType json1 = new FlinkJsonType(readTree("{\"key\": \"value1\"}")); + JsonFunctions.JSON_ARRAYAGG.accumulate(accumulator, json1); + + // Attempt to retract a null value that was never accumulated + JsonFunctions.JSON_ARRAYAGG.retract(accumulator, (FlinkJsonType) null); + + FlinkJsonType result = JsonFunctions.JSON_ARRAYAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("[{\"key\":\"value1\"}]", result.getJson().toString()); + } + } + + @Nested + class JsonObjectAggTest { + + @Test + void testAggregateJsonTypes() { + ObjectAgg accumulator = JsonFunctions.JSON_OBJECTAGG.createAccumulator(); + JsonFunctions.JSON_OBJECTAGG.accumulate( + accumulator, "key1", new FlinkJsonType(readTree("{\"nestedKey1\": \"nestedValue1\"}"))); + JsonFunctions.JSON_OBJECTAGG.accumulate( + accumulator, "key2", new FlinkJsonType(readTree("{\"nestedKey2\": \"nestedValue2\"}"))); + + FlinkJsonType result = JsonFunctions.JSON_OBJECTAGG.getValue(accumulator); + assertNotNull(result); + assertEquals( + "{\"key1\":{\"nestedKey1\":\"nestedValue1\"},\"key2\":{\"nestedKey2\":\"nestedValue2\"}}", + result.getJson().toString()); + } + + @Test + void testAggregateWithOverwritingKeys() { + ObjectAgg accumulator = JsonFunctions.JSON_OBJECTAGG.createAccumulator(); + JsonFunctions.JSON_OBJECTAGG.accumulate(accumulator, "key", "value1"); + JsonFunctions.JSON_OBJECTAGG.accumulate(accumulator, "key", "value2"); + + FlinkJsonType result = JsonFunctions.JSON_OBJECTAGG.getValue(accumulator); + assertNotNull(result); + assertEquals( + "{\"key\":\"value2\"}", + result.getJson().toString()); // The last value for the same key should be retained + } + + @Test + void testNullKey() { + assertThrows( + IllegalArgumentException.class, () -> JsonFunctions.JSON_OBJECT.eval(null, "value1")); + } + + @Test + void testNullValue() { + FlinkJsonType result = JsonFunctions.JSON_OBJECT.eval("key1", null); + assertNotNull(result); + assertEquals("{\"key1\":null}", result.getJson().toString()); + } + + @Test + void testNullKeyValue() { + assertThrows( + IllegalArgumentException.class, () -> JsonFunctions.JSON_OBJECT.eval(null, null)); + } + + @Test + void testArrayOfNullValues() { + FlinkJsonType result = + JsonFunctions.JSON_OBJECT.eval("key1", new Object[] {null, null, null}); + assertNotNull(result); + // The expected output might vary based on how the function is designed to handle this case + assertEquals("{\"key1\":[null,null,null]}", result.getJson().toString()); + } + + @Test + void testRetractJsonTypes() { + ObjectAgg accumulator = JsonFunctions.JSON_OBJECTAGG.createAccumulator(); + JsonFunctions.JSON_OBJECTAGG.accumulate( + accumulator, "key1", new FlinkJsonType(readTree("{\"nestedKey1\": \"nestedValue1\"}"))); + JsonFunctions.JSON_OBJECTAGG.accumulate( + accumulator, "key2", new FlinkJsonType(readTree("{\"nestedKey2\": \"nestedValue2\"}"))); + + // Now retract a key-value pair + JsonFunctions.JSON_OBJECTAGG.retract( + accumulator, "key1", new FlinkJsonType(readTree("{\"nestedKey1\": \"nestedValue1\"}"))); + + FlinkJsonType result = JsonFunctions.JSON_OBJECTAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("{\"key2\":{\"nestedKey2\":\"nestedValue2\"}}", result.getJson().toString()); + } + + @Test + void testRetractNullJsonValue() { + ObjectAgg accumulator = JsonFunctions.JSON_OBJECTAGG.createAccumulator(); + JsonFunctions.JSON_OBJECTAGG.accumulate( + accumulator, "key1", new FlinkJsonType(readTree("{\"nestedKey1\": \"nestedValue1\"}"))); + JsonFunctions.JSON_OBJECTAGG.accumulate(accumulator, "key2", (FlinkJsonType) null); + + // Now retract a null value + JsonFunctions.JSON_OBJECTAGG.retract(accumulator, "key2", (FlinkJsonType) null); + + FlinkJsonType result = JsonFunctions.JSON_OBJECTAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("{\"key1\":{\"nestedKey1\":\"nestedValue1\"}}", result.getJson().toString()); + } + + @Test + void testRetractNullKey() { + ObjectAgg accumulator = JsonFunctions.JSON_OBJECTAGG.createAccumulator(); + JsonFunctions.JSON_OBJECTAGG.accumulate( + accumulator, "key1", new FlinkJsonType(readTree("{\"nestedKey1\": \"nestedValue1\"}"))); + JsonFunctions.JSON_OBJECTAGG.accumulate(accumulator, null, "someValue"); + + // Attempt to retract a key-value pair where the key is null + JsonFunctions.JSON_OBJECTAGG.retract(accumulator, null, "someValue"); + + FlinkJsonType result = JsonFunctions.JSON_OBJECTAGG.getValue(accumulator); + assertNotNull(result); + assertEquals("{\"key1\":{\"nestedKey1\":\"nestedValue1\"}}", result.getJson().toString()); + } + } +} diff --git a/types/vector-type/pom.xml b/types/vector-type/pom.xml index f54137a..45402f5 100644 --- a/types/vector-type/pom.xml +++ b/types/vector-type/pom.xml @@ -34,5 +34,22 @@ ${flink.version} provided + + org.apache.flink + flink-table-api-java-bridge + ${flink.version} + provided + + + ${project.groupId} + system-functions-discovery + ${project.version} + provided + + + com.google.auto.service + auto-service + 1.1.1 + diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/AsciiTextTestEmbed.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/AsciiTextTestEmbed.java new file mode 100644 index 0000000..e1c1fc1 --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/AsciiTextTestEmbed.java @@ -0,0 +1,36 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.vector.FlinkVectorType; +import com.google.auto.service.AutoService; +import org.apache.flink.table.functions.ScalarFunction; + +/** A unuseful embedding function counts each character (modulo 256). Used for testing only. */ +@AutoService(AutoRegisterSystemFunction.class) +public class AsciiTextTestEmbed extends ScalarFunction implements AutoRegisterSystemFunction { + + private static final int VECTOR_LENGTH = 256; + + public FlinkVectorType eval(String text) { + double[] vector = new double[256]; + for (char c : text.toCharArray()) { + vector[c % VECTOR_LENGTH] += 1; + } + return new FlinkVectorType(vector); + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/Center.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/Center.java new file mode 100644 index 0000000..c50a64b --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/Center.java @@ -0,0 +1,62 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import static com.datasqrl.types.vector.functions.VectorFunctions.VEC_TO_DOUBLE; +import static com.datasqrl.types.vector.functions.VectorFunctions.convert; + +import com.datasqrl.types.vector.FlinkVectorType; +import org.apache.flink.table.functions.AggregateFunction; + +/** + * Aggregates vectors by computing the centroid, i.e. summing up all vectors and dividing the + * resulting vector by the number of vectors. + */ +public class Center extends AggregateFunction { + + @Override + public CenterAccumulator createAccumulator() { + return new CenterAccumulator(); + } + + @Override + public FlinkVectorType getValue(CenterAccumulator acc) { + if (acc.count == 0) { + return null; + } else { + return convert(acc.get()); + } + } + + public void accumulate(CenterAccumulator acc, FlinkVectorType vector) { + acc.add(VEC_TO_DOUBLE.eval(vector)); + } + + public void retract(CenterAccumulator acc, FlinkVectorType vector) { + acc.substract(VEC_TO_DOUBLE.eval(vector)); + } + + public void merge(CenterAccumulator acc, Iterable iter) { + for (CenterAccumulator a : iter) { + acc.addAll(a); + } + } + + public void resetAccumulator(CenterAccumulator acc) { + acc.count = 0; + acc.sum = null; + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CenterAccumulator.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CenterAccumulator.java new file mode 100644 index 0000000..a70fba0 --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CenterAccumulator.java @@ -0,0 +1,69 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +// import com.google.common.base.Preconditions; + +// mutable accumulator of structured type for the aggregate function +public class CenterAccumulator { + + public double[] sum = null; + public int count = 0; + + public synchronized void add(double[] values) { + if (count == 0) { + sum = values.clone(); + count = 1; + } else { + // Preconditions.checkArgument(values.length == sum.length); + for (int i = 0; i < values.length; i++) { + sum[i] += values[i]; + } + count++; + } + } + + public synchronized void addAll(CenterAccumulator other) { + if (other.count == 0) { + return; + } + if (this.count == 0) { + this.sum = new double[other.sum.length]; + } + // Preconditions.checkArgument(this.sum.length == other.sum.length); + for (int i = 0; i < other.sum.length; i++) { + this.sum[i] += other.sum[i]; + } + this.count += other.count; + } + + public double[] get() { + // Preconditions.checkArgument(count > 0); + double[] result = new double[sum.length]; + for (int i = 0; i < sum.length; i++) { + result[i] = sum[i] / count; + } + return result; + } + + public synchronized void substract(double[] values) { + // Preconditions.checkArgument(values.length == sum.length); + for (int i = 0; i < values.length; i++) { + sum[i] -= values[i]; + } + count--; + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CosineDistance.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CosineDistance.java new file mode 100644 index 0000000..f390230 --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CosineDistance.java @@ -0,0 +1,26 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import com.datasqrl.types.vector.FlinkVectorType; + +/** Computes the cosine distance between two vectors */ +public class CosineDistance extends CosineSimilarity { + + public double eval(FlinkVectorType vectorA, FlinkVectorType vectorB) { + return 1 - super.eval(vectorA, vectorB); + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CosineSimilarity.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CosineSimilarity.java new file mode 100644 index 0000000..32bff20 --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/CosineSimilarity.java @@ -0,0 +1,42 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import static com.datasqrl.types.vector.functions.VectorFunctions.VEC_TO_DOUBLE; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.vector.FlinkVectorType; +import com.google.auto.service.AutoService; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.RealVector; +import org.apache.flink.table.functions.ScalarFunction; + +/** Computes the cosine similarity between two vectors */ +@AutoService(AutoRegisterSystemFunction.class) +public class CosineSimilarity extends ScalarFunction implements AutoRegisterSystemFunction { + + public double eval(FlinkVectorType vectorA, FlinkVectorType vectorB) { + // Create RealVectors from the input arrays + RealVector vA = new ArrayRealVector(VEC_TO_DOUBLE.eval(vectorA), false); + RealVector vB = new ArrayRealVector(VEC_TO_DOUBLE.eval(vectorB), false); + + // Calculate the cosine similarity + double dotProduct = vA.dotProduct(vB); + double normalization = vA.getNorm() * vB.getNorm(); + + return dotProduct / normalization; + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/DoubleToVector.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/DoubleToVector.java new file mode 100644 index 0000000..277f456 --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/DoubleToVector.java @@ -0,0 +1,30 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.vector.FlinkVectorType; +import com.google.auto.service.AutoService; +import org.apache.flink.table.functions.ScalarFunction; + +/** Converts a double array to a vector */ +@AutoService(AutoRegisterSystemFunction.class) +public class DoubleToVector extends ScalarFunction implements AutoRegisterSystemFunction { + + public FlinkVectorType eval(double[] array) { + return new FlinkVectorType(array); + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/EuclideanDistance.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/EuclideanDistance.java new file mode 100644 index 0000000..dd97d42 --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/EuclideanDistance.java @@ -0,0 +1,37 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import static com.datasqrl.types.vector.functions.VectorFunctions.VEC_TO_DOUBLE; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.vector.FlinkVectorType; +import com.google.auto.service.AutoService; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.RealVector; +import org.apache.flink.table.functions.ScalarFunction; + +/** Computes the euclidean distance between two vectors */ +@AutoService(AutoRegisterSystemFunction.class) +public class EuclideanDistance extends ScalarFunction implements AutoRegisterSystemFunction { + + public double eval(FlinkVectorType vectorA, FlinkVectorType vectorB) { + // Create RealVectors from the input arrays + RealVector vA = new ArrayRealVector(VEC_TO_DOUBLE.eval(vectorA), false); + RealVector vB = new ArrayRealVector(VEC_TO_DOUBLE.eval(vectorB), false); + return vA.getDistance(vB); + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/README.md b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/README.md new file mode 100644 index 0000000..6d16b4d --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/README.md @@ -0,0 +1,10 @@ +| Function Documentation | +|-------------------------| +| `AsciiTextTestEmbed(string) → vector`

Convert text to a vector of length 256 where each character's ASCII value is mapped.
Example: `AsciiTextTestEmbed('hello') → [0, 0, 0, ..., 1, 0, 1, 2, ...]` | +| `Center(vector) → vector`

Aggregate function to compute the center of multiple vectors.
Example: `Center([1.0, 2.0], [3.0, 4.0]) → [2.0, 3.0]` | +| `CosineDistance(vector, vector) → double`

Compute the cosine distance between two vectors.
Example: `CosineDistance([1.0, 0.0], [0.0, 1.0]) → 1.0` | +| `CosineSimilarity(vector, vector) → double`

Compute the cosine similarity between two vectors.
Example: `CosineSimilarity([1.0, 0.0], [0.0, 1.0]) → 0.0` | +| `DoubleToVector(array) → vector`

Convert an array of doubles to a vector.
Example: `DoubleToVector([1.0, 2.0, 3.0]) → [1.0, 2.0, 3.0]` | +| `EuclideanDistance(vector, vector) → double`

Compute the Euclidean distance between two vectors.
Example: `EuclideanDistance([1.0, 0.0], [0.0, 1.0]) → 1.41421356237` | +| `OnnxEmbed(string, string) → vector`

Convert text to a vector using an ONNX model.
Example: `OnnxEmbed('hello', '/path/to/model') → [0.5, 0.1, ...]` | +| `VectorToDouble(vector) → array`

Convert a vector to an array of doubles.
Example: `VectorToDouble([1.0, 2.0, 3.0]) → [1.0, 2.0, 3.0]` | \ No newline at end of file diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/VectorFunctions.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/VectorFunctions.java new file mode 100644 index 0000000..6279e3a --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/VectorFunctions.java @@ -0,0 +1,50 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import com.datasqrl.types.vector.FlinkVectorType; +import java.util.Set; +import org.apache.flink.table.functions.FunctionDefinition; + +public class VectorFunctions { + + public static final CosineSimilarity COSINE_SIMILARITY = new CosineSimilarity(); + public static final CosineDistance COSINE_DISTANCE = new CosineDistance(); + + public static final EuclideanDistance EUCLIDEAN_DISTANCE = new EuclideanDistance(); + + public static final VectorToDouble VEC_TO_DOUBLE = new VectorToDouble(); + + public static final DoubleToVector DOUBLE_TO_VECTOR = new DoubleToVector(); + + public static final AsciiTextTestEmbed ASCII_TEXT_TEST_EMBED = new AsciiTextTestEmbed(); + + public static final Center CENTER = new Center(); + + public static final Set functions = + Set.of( + COSINE_SIMILARITY, + COSINE_DISTANCE, + EUCLIDEAN_DISTANCE, + VEC_TO_DOUBLE, + DOUBLE_TO_VECTOR, + ASCII_TEXT_TEST_EMBED, + CENTER); + + public static FlinkVectorType convert(double[] vector) { + return new FlinkVectorType(vector); + } +} diff --git a/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/VectorToDouble.java b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/VectorToDouble.java new file mode 100644 index 0000000..5dc1033 --- /dev/null +++ b/types/vector-type/src/main/java/com/datasqrl/types/vector/functions/VectorToDouble.java @@ -0,0 +1,30 @@ +/* + * Copyright © 2024 DataSQRL (contact@datasqrl.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datasqrl.types.vector.functions; + +import com.datasqrl.function.AutoRegisterSystemFunction; +import com.datasqrl.types.vector.FlinkVectorType; +import com.google.auto.service.AutoService; +import org.apache.flink.table.functions.ScalarFunction; + +/** Converts a vector to a double array */ +@AutoService(AutoRegisterSystemFunction.class) +public class VectorToDouble extends ScalarFunction implements AutoRegisterSystemFunction { + + public double[] eval(FlinkVectorType vectorType) { + return vectorType.getValue(); + } +}