Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hsweb-easy-orm-rdb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>42.7.8</version>
<scope>test</scope>
<optional>true</optional>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import java.math.BigDecimal;
import java.sql.JDBCType;
import java.sql.Timestamp;
import java.time.ZonedDateTime;

/**
Expand Down Expand Up @@ -43,6 +42,15 @@ public PostgresqlDialect() {
addDataTypeBuilder("json", meta -> "json");
addDataTypeBuilder("jsonb", meta -> "jsonb");

// DDL builders for the vector column types: each renders "<type>(<length>)".
// The length comes from the column metadata; 512 is the argument handed to getLength
// (presumably the fallback used when no length is configured - TODO confirm).
addDataTypeBuilder("vector", meta -> StringUtils.concat("vector(", meta.getLength(512), ")"));
addDataTypeBuilder("halfvec", meta -> StringUtils.concat("halfvec(", meta.getLength(512), ")"));
addDataTypeBuilder("sparsevec", meta -> StringUtils.concat("sparsevec(", meta.getLength(512), ")"));


// Associate each vector type id with its VectorType constant.
registerDataType("vector", VectorType.VECTOR);
registerDataType("halfvec", VectorType.HALF_VECTOR);
registerDataType("sparsevec", VectorType.SPARSE_VECTOR);


registerDataType("json", JsonType.INSTANCE);
registerDataType("jsonb", JsonbType.INSTANCE);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package org.hswebframework.ezorm.rdb.supports.postgres;

import org.hswebframework.ezorm.core.ValueCodec;
import org.hswebframework.ezorm.rdb.codec.EnumValueCodec;
import org.hswebframework.ezorm.rdb.metadata.RDBFeatures;
import org.hswebframework.ezorm.rdb.metadata.DefaultValueCodecFactory;
import org.hswebframework.ezorm.rdb.metadata.RDBSchemaMetadata;
import org.hswebframework.ezorm.rdb.metadata.RDBTableMetadata;
import org.hswebframework.ezorm.rdb.metadata.ValueCodecFactory;
import org.hswebframework.ezorm.rdb.metadata.dialect.Dialect;
import org.hswebframework.ezorm.rdb.operator.CompositeExceptionTranslation;
import org.hswebframework.ezorm.rdb.utils.FeatureUtils;

import java.util.Optional;

public class PostgresqlSchemaMetadata extends RDBSchemaMetadata {

public PostgresqlSchemaMetadata(String name) {
Expand All @@ -23,6 +27,15 @@ public PostgresqlSchemaMetadata(String name) {
addFeature(new CompositeExceptionTranslation()
.add(FeatureUtils.r2dbcIsAlive(), () -> PostgresqlR2DBCExceptionTranslation.of(this))
);

// Codec factory: a column whose data type is itself a ValueCodec uses that type as its codec;
// every other column is delegated to the shared default factory.
addFeature((ValueCodecFactory) column -> {
    if (column.getType() instanceof ValueCodec<?, ?> codec) {
        return Optional.of(codec);
    }
    return DefaultValueCodecFactory.COMMONS.createValueCodec(column);
});
}

@Override
Expand All @@ -40,6 +53,9 @@ public RDBTableMetadata newTable(String name) {
column.addFeature(PostgresqlEnumInFragmentBuilder.in);
column.addFeature(PostgresqlEnumInFragmentBuilder.notIn);
}
if (column.getValueCodec() instanceof VectorType) {
PostgresqlVectorFragmentBuilder.ALL.values().forEach(column::addFeature);
}
});
return metadata;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public class PostgresqlTableMetadataParser extends RDBTableMetadataParser {
* <pre>{@code
* select column_name::varchar as "name",
* udt_name::varchar as "data_type",
* case when udt_name in ('vector','halfvec','sparsevec') then format_type(a.atttypid,a.atttypmod)::varchar else null end as "column_type",
* character_maximum_length::int4 as "data_length",
* numeric_precision::int4 as "data_precision",
* numeric_scale::int4 as "data_scale",
Expand All @@ -37,6 +38,7 @@ public class PostgresqlTableMetadataParser extends RDBTableMetadataParser {
String.join(" ",
"select column_name::varchar as \"name\"",
", udt_name::varchar as \"data_type\"",
", case when udt_name in ('vector','halfvec','sparsevec') then format_type(a.atttypid,a.atttypmod)::varchar else null end as \"column_type\"",
", character_maximum_length::int4 as \"data_length\"",
", numeric_precision::int4 as \"data_precision\"",
", numeric_scale::int4 as \"data_scale\"",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.hswebframework.ezorm.rdb.supports.postgres;

import org.hswebframework.ezorm.core.param.Term;
import org.hswebframework.ezorm.rdb.metadata.RDBColumnMetadata;
import org.hswebframework.ezorm.rdb.operator.builder.fragments.EmptySqlFragments;
import org.hswebframework.ezorm.rdb.operator.builder.fragments.SqlFragments;
import org.hswebframework.ezorm.rdb.operator.builder.fragments.TermFragmentBuilder;
import org.hswebframework.ezorm.rdb.operator.builder.fragments.term.AbstractTermFragmentBuilder;

import java.util.HashMap;
import java.util.Map;

/**
 * Term fragment builder for vector distance conditions.
 * <p>
 * One instance exists per {@link VectorTermType}. The term value is interpreted as a
 * {@link VectorQueryParam}; the builder derives a column expression for the distance
 * between the column and the query vector and then delegates the comparison to the
 * term builder the column has registered for the param's term type.
 */
public class PostgresqlVectorFragmentBuilder extends AbstractTermFragmentBuilder {
    /**
     * All builders, keyed by the vector term type they handle.
     */
    public static final Map<VectorTermType, PostgresqlVectorFragmentBuilder> ALL = new HashMap<>();

    static {
        for (VectorTermType value : VectorTermType.values()) {
            ALL.put(value, new PostgresqlVectorFragmentBuilder(value));
        }
    }

    private final VectorTermType type;

    /**
     * @param type the vector term type; its enum name is used as the term type id
     */
    public PostgresqlVectorFragmentBuilder(VectorTermType type) {
        super(type.name(), "向量查询");
        this.type = type;
    }

    /**
     * Builds the SQL fragments for a vector term.
     *
     * @param columnFullName full name of the vector column
     * @param column         metadata of the vector column
     * @param term           the term; its value is converted with {@link VectorQueryParam#of(Object)}
     * @return the fragments of the delegated comparison, or {@link EmptySqlFragments#INSTANCE}
     * when the term carries no value or no matching term builder is registered on the column
     */
    @Override
    public SqlFragments createFragments(String columnFullName, RDBColumnMetadata column, Term term) {
        VectorQueryParam vectorTerm = VectorQueryParam.of(term.getValue());
        // VectorQueryParam.of returns null for a null term value; emit no condition instead of failing with an NPE.
        if (vectorTerm == null) {
            return EmptySqlFragments.INSTANCE;
        }
        String vectorColumn = VectorUtils.getVectorColumn(columnFullName, type, vectorTerm.getVector());
        return createTermFragments(column, vectorColumn, vectorTerm.getTerm(type, term.getColumn()));
    }


    /**
     * Delegates to the term builder registered on the column for {@code term}'s term type.
     *
     * @param column       metadata of the vector column
     * @param vectorColumn the expression passed to the delegate in place of the column name
     * @param term         the comparison term to render
     * @return the delegate's fragments, or {@link EmptySqlFragments#INSTANCE} if the column has no such builder
     */
    protected SqlFragments createTermFragments(RDBColumnMetadata column,
                                               String vectorColumn,
                                               Term term) {
        TermFragmentBuilder builder = column
            .findFeature(TermFragmentBuilder.createFeatureId(term.getTermType()))
            .orElse(null);

        if (builder != null) {
            return builder
                .createFragments(vectorColumn, column, term);
        }
        return EmptySqlFragments.INSTANCE;
    }


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.hswebframework.ezorm.rdb.supports.postgres;

import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.*;
import org.apache.commons.beanutils.BeanUtils;
import org.hswebframework.ezorm.core.param.Term;
import org.hswebframework.ezorm.core.param.TermType;
import org.hswebframework.ezorm.rdb.operator.builder.fragments.NativeSql;

import java.util.Map;

/**
 * Parameters of a vector similarity condition: the query vector plus the comparison
 * operator and the unified distance threshold to compare against.
 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor(staticName = "of")
public class VectorQueryParam {

    /**
     * The vector whose distance to the column is computed.
     */
    private Float[] vector;
    /**
     * Term type (operator) used to compare the computed distance with the threshold.
     */
    private String termType = TermType.lte;
    /**
     * Distance (0~1); 0 is the most similar.
     */
    private float distance = 0.25f;

    /**
     * Median of the distance distribution of the L2 model.
     */
    private int alpha = 50;

    /**
     * Creates a param with default operator, distance and alpha.
     *
     * @param vector raw vector value, converted with {@code VectorType.toFloatArray}
     */
    public VectorQueryParam(Object vector) {
        this.vector = VectorType.toFloatArray(vector);
    }

    /**
     * Builds the comparison term for the given vector term type.
     *
     * @param type   supplies the conversion of the unified distance into the SQL value
     * @param column column name of the resulting term
     * @return a term using {@link #termType} whose value is a single-placeholder native SQL
     * bound to {@code type.toSqlValue(alpha, distance)}
     */
    @JsonIgnore
    public Term getTerm(VectorTermType type, String column) {
        return Term.of(column, termType, NativeSql.of("?", type.toSqlValue(alpha, distance)));
    }

    /**
     * Converts an arbitrary term value into a {@link VectorQueryParam}.
     * <ul>
     * <li>{@code null} yields {@code null};</li>
     * <li>a {@link VectorQueryParam} is returned as is;</li>
     * <li>a {@link Map} populates the matching properties; the {@code "vector"} entry is
     * converted with {@code VectorType.toFloatArray};</li>
     * <li>anything else is treated as the raw vector.</li>
     * </ul>
     *
     * @param value the term value, may be {@code null}
     * @return the param, or {@code null} when {@code value} is {@code null}
     */
    @SneakyThrows
    public static VectorQueryParam of(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof VectorQueryParam) {
            return ((VectorQueryParam) value);
        }
        if (value instanceof Map<?, ?> v) {
            VectorQueryParam term = new VectorQueryParam();
            // Collect the scalar settings only: null values are skipped so the field defaults
            // survive, and "vector" is excluded because it is converted separately below.
            Map<String, Object> scalars = new java.util.HashMap<>();
            for (Map.Entry<?, ?> entry : v.entrySet()) {
                if (entry.getKey() instanceof String key
                    && entry.getValue() != null
                    && !"vector".equals(key)) {
                    scalars.put(key, entry.getValue());
                }
            }
            // Commons BeanUtils expects (dest, orig): the new param is the destination.
            BeanUtils.copyProperties(term, scalars);
            term.setVector(VectorType.toFloatArray(v.get("vector")));
            return term;
        }
        return new VectorQueryParam(value);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.hswebframework.ezorm.rdb.supports.postgres;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
 * Vector distance term types. Each constant carries its SQL distance operator and knows
 * how to translate the unified distance (0~1, 0 = most similar) into the value that the
 * operator's result is compared with.
 */
@Getter
@AllArgsConstructor
public enum VectorTermType {
    /**
     * L2 distance, not recommended.
     * Range [0, +infinity), 0 is the most similar.
     */
    vector_l2("<->") {
        @Override
        public Float toSqlValue(int alpha, Float distance) {
            float unified = distance;
            return alpha * unified / (1 - unified);
        }
    },
    /**
     * Cosine distance.
     * Range [0, 2], 0 is the most similar.
     */
    vector_cos("<=>") {
        @Override
        public Float toSqlValue(int alpha, Float distance) {
            return distance * 2;
        }
    },
    /**
     * Inner product (after normalization).
     * Range [-1, 1], -1 is the most similar.
     */
    vector_ip("<#>") {
        @Override
        public Float toSqlValue(int alpha, Float distance) {
            // Linear map of 0~1 onto -1~1; a distance of 0 yields exactly -1.
            float doubled = 2 * distance;
            return doubled - 1;
        }
    };

    /**
     * SQL operator of this distance type.
     */
    private final String operation;

    /**
     * Converts the unified distance (0~1) into the value used in the SQL condition.
     *
     * @param alpha    median of the distance distribution of the L2 model
     * @param distance distance (0~1), 0 is the most similar
     * @return the value to compare with
     */
    public abstract Float toSqlValue(int alpha, Float distance);
}
Loading
Loading