From a4ad4da3f11136e2eb03a181707498af07a10c31 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Wed, 7 May 2025 15:28:08 +0300 Subject: [PATCH] WIP --- internal/bind/params.go | 5 + internal/params/parameters.go | 7 ++ internal/value/cast.go | 38 +++++++- params_test.go | 14 +-- tests/integration/param_raw_protobuf_test.go | 99 ++++++++++++++++++++ 5 files changed, 152 insertions(+), 11 deletions(-) create mode 100644 tests/integration/param_raw_protobuf_test.go diff --git a/internal/bind/params.go b/internal/bind/params.go index e78d77f87..ef551cf8d 100644 --- a/internal/bind/params.go +++ b/internal/bind/params.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "net/url" "reflect" "sort" @@ -212,6 +213,10 @@ func toValue(v any) (_ value.Value, err error) { return value.VoidValue(), nil case value.Value: return x, nil + case Ydb.TypedValue: + return value.FromProtobuf(&x), nil + case *Ydb.TypedValue: + return value.FromProtobuf(x), nil case bool: return value.BoolValue(x), nil case int: diff --git a/internal/params/parameters.go b/internal/params/parameters.go index 8b0588293..d8aae83dd 100644 --- a/internal/params/parameters.go +++ b/internal/params/parameters.go @@ -423,7 +423,14 @@ func (p *Parameter) TzDatetime(v time.Time) Builder { } } +// Raw makes value from raw protobuf +// Deprecated: use FromProtobuf instead func (p *Parameter) Raw(pb *Ydb.TypedValue) Builder { + return p.FromProtobuf(pb) +} + +// FromProtobuf makes value from raw protobuf +func (p *Parameter) FromProtobuf(pb *Ydb.TypedValue) Builder { p.value = value.FromProtobuf(pb) return Builder{ diff --git a/internal/value/cast.go b/internal/value/cast.go index 7106b62fd..c80f5f4aa 100644 --- a/internal/value/cast.go +++ b/internal/value/cast.go @@ -1,14 +1,44 @@ package value +import ( + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" +) + func CastTo(v Value, dst interface{}) error { if dst == nil { return errNilDestination } - if ptr, has := dst.(*Value); has { - *ptr = v + + switch x := dst.(type) { + case *Value: + *x = v return nil - } + case **Value: + **x = v + + return nil + case *Ydb.Value: + *x = *v.toYDB() + + return nil + case **Ydb.Value: + *x = v.toYDB() - return v.castTo(dst) + return nil + case *Ydb.TypedValue: + x.Type = v.Type().ToYDB() + x.Value = v.toYDB() + + return nil + case **Ydb.TypedValue: + *x = &Ydb.TypedValue{ + Type: v.Type().ToYDB(), + Value: v.toYDB(), + } + + return nil + default: + return v.castTo(dst) + } } diff --git a/params_test.go b/params_test.go index a920b7f06..3eb5b4182 100644 --- a/params_test.go +++ b/params_test.go @@ -108,12 +108,12 @@ func makeParamsUsingParamsBuilder(tb testing.TB) params.Parameters { Build() } -func makeParamsUsingRawProtobuf(tb testing.TB) params.Parameters { +func makeParamsFromProtobuf(tb testing.TB) params.Parameters { return ydb.ParamsBuilder(). - Param("$a").Raw(a). - Param("$b").Raw(b). - Param("$c").Raw(c). - Param("$d").Raw(d). + Param("$a").FromProtobuf(a). + Param("$b").FromProtobuf(b). + Param("$c").FromProtobuf(c). + Param("$d").FromProtobuf(d). Build() } @@ -153,7 +153,7 @@ func TestParams(t *testing.T) { require.NoError(t, err) require.Equal(t, fmt.Sprint(exp), fmt.Sprint(pb)) t.Run("Raw", func(t *testing.T) { - params := makeParamsUsingRawProtobuf(t) + params := makeParamsFromProtobuf(t) pb, err := params.ToYDB() require.NoError(t, err) require.Equal(t, fmt.Sprint(exp), fmt.Sprint(pb)) @@ -184,7 +184,7 @@ func BenchmarkParams(b *testing.B) { b.Run("RawProtobuf", func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - params := makeParamsUsingRawProtobuf(b) + params := makeParamsFromProtobuf(b) _, _ = params.ToYDB() } }) diff --git a/tests/integration/param_raw_protobuf_test.go b/tests/integration/param_raw_protobuf_test.go new file mode 100644 index 000000000..7026dd886 --- /dev/null +++ b/tests/integration/param_raw_protobuf_test.go @@ -0,0 +1,99 @@ +//go:build integration +// +build integration + +package integration + +import ( + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + "github.com/ydb-platform/ydb-go-sdk/v3/query" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" +) + +func TestRawProtobuf(t *testing.T) { + raw := &Ydb.TypedValue{ + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + Value: &Ydb.Value{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 123, + }, + }, + } + + t.Run("query", func(t *testing.T) { + var ( + scope = newScope(t) + db = scope.Driver() + ) + + row, err := db.Query().QueryRow(scope.Ctx, ` + DECLARE $raw AS Uint64; + SELECT $raw;`, + query.WithParameters( + ydb.ParamsBuilder().Param("$raw").FromProtobuf(raw).Build(), + ), + ) + require.NoError(t, err) + + t.Run("*Ydb.TypedValue", func(t *testing.T) { + var act *Ydb.TypedValue + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.String(), act.String()) + }) + t.Run("**Ydb.TypedValue", func(t *testing.T) { + var act *Ydb.TypedValue + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.String(), act.String()) + }) + t.Run("*Ydb.Value", func(t *testing.T) { + var act *Ydb.Value + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.Value.String(), act.String()) + }) + t.Run("**Ydb.Value", func(t *testing.T) { + var act *Ydb.Value + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.Value.String(), act.String()) + }) + }) + t.Run("database/sql", func(t *testing.T) { + var ( + scope = newScope(t) + db = scope.SQLDriver( + ydb.WithAutoDeclare(), + ydb.WithPositionalArgs(), + ) + ) + + row := db.QueryRowContext(scope.Ctx, `SELECT ?`, raw) + require.NoError(t, row.Err()) + + t.Run("*Ydb.TypedValue", func(t *testing.T) { + var act *Ydb.TypedValue + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.String(), act.String()) + }) + t.Run("**Ydb.TypedValue", func(t *testing.T) { + var act *Ydb.TypedValue + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.String(), act.String()) + }) + t.Run("*Ydb.Value", func(t *testing.T) { + var act *Ydb.Value + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.Value.String(), act.String()) + }) + t.Run("**Ydb.Value", func(t *testing.T) { + var act *Ydb.Value + require.NoError(t, row.Scan(&act)) + require.Equal(t, raw.Value.String(), act.String()) + }) + }) +}