Skip to content

Commit 38342ee

Browse files
authored
feat: 添加向量距离FunctionFragmentBuilder实现 (#99)
1 parent 07419bc commit 38342ee

6 files changed

Lines changed: 60 additions & 12 deletions

File tree

hsweb-easy-orm-rdb/src/main/java/org/hswebframework/ezorm/rdb/supports/postgres/PostgresqlSchemaMetadata.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ public RDBTableMetadata newTable(String name) {
5454
column.addFeature(PostgresqlEnumInFragmentBuilder.notIn);
5555
}
5656
if (column.getValueCodec() instanceof VectorType) {
57-
PostgresqlVectorFragmentBuilder.ALL.values().forEach(column::addFeature);
57+
addFeature(new PostgresqlVectorDistanceFunctionFragmentBuilder());
58+
PostgresqlVectorDistanceTermFragmentBuilder.ALL.values().forEach(column::addFeature);
5859
}
5960
});
6061
return metadata;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package org.hswebframework.ezorm.rdb.supports.postgres;
2+
3+
import lombok.Getter;
4+
import org.hswebframework.ezorm.rdb.metadata.RDBColumnMetadata;
5+
import org.hswebframework.ezorm.rdb.operator.builder.fragments.EmptySqlFragments;
6+
import org.hswebframework.ezorm.rdb.operator.builder.fragments.PrepareSqlFragments;
7+
import org.hswebframework.ezorm.rdb.operator.builder.fragments.SqlFragments;
8+
import org.hswebframework.ezorm.rdb.operator.builder.fragments.function.FunctionFragmentBuilder;
9+
10+
import java.util.Map;
11+
12+
@Getter
13+
public class PostgresqlVectorDistanceFunctionFragmentBuilder implements FunctionFragmentBuilder {
14+
15+
public static final String function_id = "vector_distance";
16+
public static final String opt_term_type_key = "termType";
17+
public static final String opt_vector_value_key = "vectorValue";
18+
19+
private final String function = function_id;
20+
21+
private final String name = "向量距离";
22+
23+
@Override
24+
public SqlFragments create(String columnFullName, RDBColumnMetadata metadata, Map<String, Object> opts) {
25+
VectorTermType termType = VectorTermType.of(opts.getOrDefault(opt_term_type_key, VectorTermType.vector_ip));
26+
Float[] array = VectorType.toFloatArray(opts.get(opt_vector_value_key));
27+
if (array == null || termType == null) {
28+
return EmptySqlFragments.INSTANCE;
29+
}
30+
31+
String vectorColumn = VectorUtils.getVectorDistanceColumn(columnFullName, termType, array);
32+
return PrepareSqlFragments.of(vectorColumn);
33+
}
34+
35+
36+
}

hsweb-easy-orm-rdb/src/main/java/org/hswebframework/ezorm/rdb/supports/postgres/PostgresqlVectorFragmentBuilder.java renamed to hsweb-easy-orm-rdb/src/main/java/org/hswebframework/ezorm/rdb/supports/postgres/PostgresqlVectorDistanceTermFragmentBuilder.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,26 @@
1010
import java.util.HashMap;
1111
import java.util.Map;
1212

13-
public class PostgresqlVectorFragmentBuilder extends AbstractTermFragmentBuilder {
14-
public static final Map<VectorTermType, PostgresqlVectorFragmentBuilder> ALL = new HashMap<>();
13+
public class PostgresqlVectorDistanceTermFragmentBuilder extends AbstractTermFragmentBuilder {
14+
public static final Map<VectorTermType, PostgresqlVectorDistanceTermFragmentBuilder> ALL = new HashMap<>();
1515

1616
static {
1717
for (VectorTermType value : VectorTermType.values()) {
18-
ALL.put(value, new PostgresqlVectorFragmentBuilder(value));
18+
ALL.put(value, new PostgresqlVectorDistanceTermFragmentBuilder(value));
1919
}
2020
}
2121

2222
private final VectorTermType type;
2323

24-
public PostgresqlVectorFragmentBuilder(VectorTermType type) {
25-
super(type.name(), "向量查询");
24+
public PostgresqlVectorDistanceTermFragmentBuilder(VectorTermType type) {
25+
super(type.name(), "向量距离查询");
2626
this.type = type;
2727
}
2828

2929
@Override
3030
public SqlFragments createFragments(String columnFullName, RDBColumnMetadata column, Term term) {
3131
VectorQueryParam vectorTerm = VectorQueryParam.of(term.getValue());
32-
String vectorColumn = VectorUtils.getVectorColumn(columnFullName, type, vectorTerm.getVector());
32+
String vectorColumn = VectorUtils.getVectorDistanceColumn(columnFullName, type, vectorTerm.getVector());
3333
return createTermFragments(column, vectorColumn, vectorTerm.getTerm(type, term.getColumn()));
3434
}
3535

hsweb-easy-orm-rdb/src/main/java/org/hswebframework/ezorm/rdb/supports/postgres/VectorTermType.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import lombok.AllArgsConstructor;
44
import lombok.Getter;
5+
import reactor.util.annotation.Nullable;
56

67
@Getter
78
@AllArgsConstructor
@@ -45,9 +46,20 @@ public Float toSqlValue(int alpha, Float value) {
4546
/**
4647
* 将统一 distance(0~1)转换为 SQL 条件值
4748
*
48-
* @param alpha 该l2模型距离分布的中位数
49+
* @param alpha 该l2模型距离分布的中位数
4950
* @param distance 距离 (0~1) 0最相似
5051
* @return 值
5152
*/
5253
public abstract Float toSqlValue(int alpha, Float distance);
54+
55+
@Nullable
56+
public static VectorTermType of(Object value) {
57+
if (value instanceof VectorTermType) {
58+
return (VectorTermType) value;
59+
}
60+
if (value instanceof String) {
61+
return VectorTermType.valueOf((String) value);
62+
}
63+
return null;
64+
}
5365
}

hsweb-easy-orm-rdb/src/main/java/org/hswebframework/ezorm/rdb/supports/postgres/VectorType.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.hswebframework.ezorm.rdb.metadata.RDBColumnMetadata;
1010
import org.hswebframework.ezorm.rdb.metadata.dialect.DataTypeBuilder;
1111
import org.postgresql.util.PGobject;
12+
import reactor.util.annotation.Nullable;
1213

1314
import java.lang.reflect.Array;
1415
import java.math.BigDecimal;
@@ -64,9 +65,6 @@ private Float[] toFloat(Object data) {
6465
if (data == null) {
6566
return null;
6667
}
67-
if (data instanceof Float[] values) {
68-
return values;
69-
}
7068
if (data instanceof Vector vector) {
7169
return toFloatArray(vector.getVector());
7270
}
@@ -76,6 +74,7 @@ private Float[] toFloat(Object data) {
7674
return toFloatArray(data);
7775
}
7876

77+
@Nullable
7978
public static Float[] toFloatArray(Object value) {
8079
if (value == null) {
8180
return null;

hsweb-easy-orm-rdb/src/main/java/org/hswebframework/ezorm/rdb/supports/postgres/VectorUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class VectorUtils {
1616
* @param vector 目标向量值
1717
* @return 向量距离计算字段
1818
*/
19-
public static String getVectorColumn(String columnFullName, VectorTermType type, Float[] vector) {
19+
public static String getVectorDistanceColumn(String columnFullName, VectorTermType type, Float[] vector) {
2020
return PrepareSqlFragments
2121
.of(columnFullName)
2222
.addSql(type.getOperation(),"?")

0 commit comments

Comments
 (0)