From 2ade7fe3b18bbd7979bf5b751dbb3006e69789f6 Mon Sep 17 00:00:00 2001 From: "chenzhe.29" Date: Mon, 3 Nov 2025 19:21:35 +0800 Subject: [PATCH 1/8] support thinking config --- idl/thrift/coze/loop/llm/domain/runtime.thrift | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/idl/thrift/coze/loop/llm/domain/runtime.thrift b/idl/thrift/coze/loop/llm/domain/runtime.thrift index 0ae685dc8..b61b809a6 100644 --- a/idl/thrift/coze/loop/llm/domain/runtime.thrift +++ b/idl/thrift/coze/loop/llm/domain/runtime.thrift @@ -13,6 +13,11 @@ struct ModelConfig { 8: optional i32 top_k 9: optional double presence_penalty 10: optional double frequency_penalty + 11: optional Thinking thinking +} + +struct Thinking { + 1: optional ThinkingOption thinking_option } struct Message { @@ -118,3 +123,8 @@ typedef string ImageURLDetail (ts.enum="true") const ImageURLDetail image_url_detail_auto = "auto" const ImageURLDetail image_url_detail_low = "low" const ImageURLDetail image_url_detail_high = "high" + +typedef string ThinkingOption (ts.enum="true") +const ThinkingOption thinking_option_disabled = "disabled" +const ThinkingOption thinking_option_enabled = "enabled" +const ThinkingOption thinking_option_auto = "auto" \ No newline at end of file From c08b60564fae7fe0347c5817622866c2e6e856f9 Mon Sep 17 00:00:00 2001 From: "chenzhe.29" Date: Tue, 4 Nov 2025 17:51:00 +0800 Subject: [PATCH 2/8] add thinking --- idl/thrift/coze/loop/llm/domain/runtime.thrift | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/idl/thrift/coze/loop/llm/domain/runtime.thrift b/idl/thrift/coze/loop/llm/domain/runtime.thrift index b61b809a6..729a205e3 100644 --- a/idl/thrift/coze/loop/llm/domain/runtime.thrift +++ b/idl/thrift/coze/loop/llm/domain/runtime.thrift @@ -1,6 +1,7 @@ namespace go coze.loop.llm.domain.runtime include "common.thrift" +include "manage.thrift" struct ModelConfig { 1: required i64 model_id (api.js_conv='true', go.tag='json:"model_id"')// 模型id @@ -13,11 +14,15 @@ struct ModelConfig { 8: optional i32 top_k 9: optional double presence_penalty 10: optional double frequency_penalty - 11: optional Thinking thinking + + // 与ParamSchema对应 + 100: optional list param_config_values } -struct Thinking { - 1: optional ThinkingOption thinking_option +struct ParamConfigValue { + 1: optional string name // 传给下游模型的key,与ParamSchema.name对齐 + 2: optional string label // 展示名称 + 3: optional manage.ParamOption value // 传给下游模型的value,与ParamSchema.options对齐 } struct Message { @@ -122,9 +127,4 @@ const ChatMessagePartType chat_message_part_type_image_url = "image_url" typedef string ImageURLDetail (ts.enum="true") const ImageURLDetail image_url_detail_auto = "auto" const ImageURLDetail image_url_detail_low = "low" -const ImageURLDetail image_url_detail_high = "high" - -typedef string ThinkingOption (ts.enum="true") -const ThinkingOption thinking_option_disabled = "disabled" -const ThinkingOption thinking_option_enabled = "enabled" -const ThinkingOption thinking_option_auto = "auto" \ No newline at end of file +const ImageURLDetail image_url_detail_high = "high" \ No newline at end of file From e4650569d38e01bc5829948c4c0ac90e873589de Mon Sep 17 00:00:00 2001 From: "chenzhe.29" Date: Tue, 4 Nov 2025 17:54:51 +0800 Subject: [PATCH 3/8] gen code --- .../api/handler/coze/loop/apis/wire_gen.go | 1 + .../coze/loop/llm/domain/runtime/k-runtime.go | 318 +++++++++++++ .../coze/loop/llm/domain/runtime/runtime.go | 450 +++++++++++++++++- .../llm/domain/runtime/runtime_validator.go | 8 + 4 files changed, 767 insertions(+), 10 deletions(-) diff --git a/backend/api/handler/coze/loop/apis/wire_gen.go b/backend/api/handler/coze/loop/apis/wire_gen.go index 7b5ae44ac..f090ab564 100644 --- a/backend/api/handler/coze/loop/apis/wire_gen.go +++ b/backend/api/handler/coze/loop/apis/wire_gen.go @@ -8,6 +8,7 @@ package apis import ( "context" + "github.com/cloudwego/kitex/pkg/endpoint" "github.com/coze-dev/coze-loop/backend/infra/ck" "github.com/coze-dev/coze-loop/backend/infra/db" diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go index a977efdd3..6388b2e4f 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/k-runtime.go @@ -12,10 +12,12 @@ import ( kutils "github.com/cloudwego/kitex/pkg/utils" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/common" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" ) var ( _ = common.KitexUnusedProtection + _ = manage.KitexUnusedProtection ) // unused protection @@ -186,6 +188,20 @@ func (p *ModelConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 100: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField100(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -358,6 +374,31 @@ func (p *ModelConfig) FastReadField10(buf []byte) (int, error) { return offset, nil } +func (p *ModelConfig) FastReadField100(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]*ParamConfigValue, 0, size) + values := make([]ParamConfigValue, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + if l, err := _elem.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _field = append(_field, _elem) + } + p.ParamConfigValues = _field + return offset, nil +} + func (p *ModelConfig) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -375,6 +416,7 @@ func (p *ModelConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField5(buf[offset:], w) offset += p.fastWriteField6(buf[offset:], w) offset += p.fastWriteField7(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -393,6 +435,7 @@ func (p *ModelConfig) BLength() int { l += p.field8Length() l += p.field9Length() l += p.field10Length() + l += p.field100Length() } l += thrift.Binary.FieldStopLength() return l @@ -493,6 +536,22 @@ func (p *ModelConfig) fastWriteField10(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ModelConfig) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetParamConfigValues() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 100) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.ParamConfigValues { + length++ + offset += v.FastWriteNocopy(buf[offset:], w) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) + } + return offset +} + func (p *ModelConfig) field1Length() int { l := 0 l += thrift.Binary.FieldBeginLength() @@ -585,6 +644,19 @@ func (p *ModelConfig) field10Length() int { return l } +func (p *ModelConfig) field100Length() int { + l := 0 + if p.IsSetParamConfigValues() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.ParamConfigValues { + _ = v + l += v.BLength() + } + } + return l +} + func (p *ModelConfig) DeepCopy(s interface{}) error { src, ok := s.(*ModelConfig) if !ok { @@ -648,6 +720,252 @@ func (p *ModelConfig) DeepCopy(s interface{}) error { p.FrequencyPenalty = &tmp } + if src.ParamConfigValues != nil { + p.ParamConfigValues = make([]*ParamConfigValue, 0, len(src.ParamConfigValues)) + for _, elem := range src.ParamConfigValues { + var _elem *ParamConfigValue + if elem != nil { + _elem = &ParamConfigValue{} + if err := _elem.DeepCopy(elem); err != nil { + return err + } + } + + p.ParamConfigValues = append(p.ParamConfigValues, _elem) + } + } + + return nil +} + +func (p *ParamConfigValue) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ParamConfigValue[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ParamConfigValue) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Name = _field + return offset, nil +} + +func (p *ParamConfigValue) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Label = _field + return offset, nil +} + +func (p *ParamConfigValue) FastReadField3(buf []byte) (int, error) { + offset := 0 + _field := manage.NewParamOption() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Value = _field + return offset, nil +} + +func (p *ParamConfigValue) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ParamConfigValue) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ParamConfigValue) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ParamConfigValue) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Name) + } + return offset +} + +func (p *ParamConfigValue) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetLabel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Label) + } + return offset +} + +func (p *ParamConfigValue) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetValue() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 3) + offset += p.Value.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ParamConfigValue) field1Length() int { + l := 0 + if p.IsSetName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Name) + } + return l +} + +func (p *ParamConfigValue) field2Length() int { + l := 0 + if p.IsSetLabel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Label) + } + return l +} + +func (p *ParamConfigValue) field3Length() int { + l := 0 + if p.IsSetValue() { + l += thrift.Binary.FieldBeginLength() + l += p.Value.BLength() + } + return l +} + +func (p *ParamConfigValue) DeepCopy(s interface{}) error { + src, ok := s.(*ParamConfigValue) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Name != nil { + var tmp string + if *src.Name != "" { + tmp = kutils.StringDeepCopy(*src.Name) + } + p.Name = &tmp + } + + if src.Label != nil { + var tmp string + if *src.Label != "" { + tmp = kutils.StringDeepCopy(*src.Label) + } + p.Label = &tmp + } + + var _value *manage.ParamOption + if src.Value != nil { + _value = &manage.ParamOption{} + if err := _value.DeepCopy(src.Value); err != nil { + return err + } + } + p.Value = _value + return nil } diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go index e4129248d..b69ed1e34 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/apache/thrift/lib/go/thrift" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/common" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" "strings" ) @@ -73,6 +74,8 @@ type ModelConfig struct { TopK *int32 `thrift:"top_k,8,optional" frugal:"8,optional,i32" form:"top_k" json:"top_k,omitempty" query:"top_k"` PresencePenalty *float64 `thrift:"presence_penalty,9,optional" frugal:"9,optional,double" form:"presence_penalty" json:"presence_penalty,omitempty" query:"presence_penalty"` FrequencyPenalty *float64 `thrift:"frequency_penalty,10,optional" frugal:"10,optional,double" form:"frequency_penalty" json:"frequency_penalty,omitempty" query:"frequency_penalty"` + // 与ParamSchema对应 + ParamConfigValues []*ParamConfigValue `thrift:"param_config_values,100,optional" frugal:"100,optional,list" form:"param_config_values" json:"param_config_values,omitempty" query:"param_config_values"` } func NewModelConfig() *ModelConfig { @@ -196,6 +199,18 @@ func (p *ModelConfig) GetFrequencyPenalty() (v float64) { } return *p.FrequencyPenalty } + +var ModelConfig_ParamConfigValues_DEFAULT []*ParamConfigValue + +func (p *ModelConfig) GetParamConfigValues() (v []*ParamConfigValue) { + if p == nil { + return + } + if !p.IsSetParamConfigValues() { + return ModelConfig_ParamConfigValues_DEFAULT + } + return p.ParamConfigValues +} func (p *ModelConfig) SetModelID(val int64) { p.ModelID = val } @@ -226,18 +241,22 @@ func (p *ModelConfig) SetPresencePenalty(val *float64) { func (p *ModelConfig) SetFrequencyPenalty(val *float64) { p.FrequencyPenalty = val } +func (p *ModelConfig) SetParamConfigValues(val []*ParamConfigValue) { + p.ParamConfigValues = val +} var fieldIDToName_ModelConfig = map[int16]string{ - 1: "model_id", - 2: "temperature", - 3: "max_tokens", - 4: "top_p", - 5: "stop", - 6: "tool_choice", - 7: "response_format", - 8: "top_k", - 9: "presence_penalty", - 10: "frequency_penalty", + 1: "model_id", + 2: "temperature", + 3: "max_tokens", + 4: "top_p", + 5: "stop", + 6: "tool_choice", + 7: "response_format", + 8: "top_k", + 9: "presence_penalty", + 10: "frequency_penalty", + 100: "param_config_values", } func (p *ModelConfig) IsSetTemperature() bool { @@ -276,6 +295,10 @@ func (p *ModelConfig) IsSetFrequencyPenalty() bool { return p.FrequencyPenalty != nil } +func (p *ModelConfig) IsSetParamConfigValues() bool { + return p.ParamConfigValues != nil +} + func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -376,6 +399,14 @@ func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 100: + if fieldTypeId == thrift.LIST { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -530,6 +561,29 @@ func (p *ModelConfig) ReadField10(iprot thrift.TProtocol) error { p.FrequencyPenalty = _field return nil } +func (p *ModelConfig) ReadField100(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]*ParamConfigValue, 0, size) + values := make([]ParamConfigValue, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + + if err := _elem.Read(iprot); err != nil { + return err + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.ParamConfigValues = _field + return nil +} func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -577,6 +631,10 @@ func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { fieldId = 10 goto WriteFieldError } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -781,6 +839,32 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 10 end error: ", p), err) } +func (p *ModelConfig) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetParamConfigValues() { + if err = oprot.WriteFieldBegin("param_config_values", thrift.LIST, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRUCT, len(p.ParamConfigValues)); err != nil { + return err + } + for _, v := range p.ParamConfigValues { + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} func (p *ModelConfig) String() string { if p == nil { @@ -826,6 +910,9 @@ func (p *ModelConfig) DeepEqual(ano *ModelConfig) bool { if !p.Field10DeepEqual(ano.FrequencyPenalty) { return false } + if !p.Field100DeepEqual(ano.ParamConfigValues) { + return false + } return true } @@ -940,6 +1027,349 @@ func (p *ModelConfig) Field10DeepEqual(src *float64) bool { } return true } +func (p *ModelConfig) Field100DeepEqual(src []*ParamConfigValue) bool { + + if len(p.ParamConfigValues) != len(src) { + return false + } + for i, v := range p.ParamConfigValues { + _src := src[i] + if !v.DeepEqual(_src) { + return false + } + } + return true +} + +type ParamConfigValue struct { + // 传给下游模型的key,与ParamSchema.name对齐 + Name *string `thrift:"name,1,optional" frugal:"1,optional,string" form:"name" json:"name,omitempty" query:"name"` + // 展示名称 + Label *string `thrift:"label,2,optional" frugal:"2,optional,string" form:"label" json:"label,omitempty" query:"label"` + // 传给下游模型的value,与ParamSchema.options对齐 + Value *manage.ParamOption `thrift:"value,3,optional" frugal:"3,optional,manage.ParamOption" form:"value" json:"value,omitempty" query:"value"` +} + +func NewParamConfigValue() *ParamConfigValue { + return &ParamConfigValue{} +} + +func (p *ParamConfigValue) InitDefault() { +} + +var ParamConfigValue_Name_DEFAULT string + +func (p *ParamConfigValue) GetName() (v string) { + if p == nil { + return + } + if !p.IsSetName() { + return ParamConfigValue_Name_DEFAULT + } + return *p.Name +} + +var ParamConfigValue_Label_DEFAULT string + +func (p *ParamConfigValue) GetLabel() (v string) { + if p == nil { + return + } + if !p.IsSetLabel() { + return ParamConfigValue_Label_DEFAULT + } + return *p.Label +} + +var ParamConfigValue_Value_DEFAULT *manage.ParamOption + +func (p *ParamConfigValue) GetValue() (v *manage.ParamOption) { + if p == nil { + return + } + if !p.IsSetValue() { + return ParamConfigValue_Value_DEFAULT + } + return p.Value +} +func (p *ParamConfigValue) SetName(val *string) { + p.Name = val +} +func (p *ParamConfigValue) SetLabel(val *string) { + p.Label = val +} +func (p *ParamConfigValue) SetValue(val *manage.ParamOption) { + p.Value = val +} + +var fieldIDToName_ParamConfigValue = map[int16]string{ + 1: "name", + 2: "label", + 3: "value", +} + +func (p *ParamConfigValue) IsSetName() bool { + return p.Name != nil +} + +func (p *ParamConfigValue) IsSetLabel() bool { + return p.Label != nil +} + +func (p *ParamConfigValue) IsSetValue() bool { + return p.Value != nil +} + +func (p *ParamConfigValue) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ParamConfigValue[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ParamConfigValue) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Name = _field + return nil +} +func (p *ParamConfigValue) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Label = _field + return nil +} +func (p *ParamConfigValue) ReadField3(iprot thrift.TProtocol) error { + _field := manage.NewParamOption() + if err := _field.Read(iprot); err != nil { + return err + } + p.Value = _field + return nil +} + +func (p *ParamConfigValue) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ParamConfigValue"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ParamConfigValue) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetName() { + if err = oprot.WriteFieldBegin("name", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Name); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ParamConfigValue) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetLabel() { + if err = oprot.WriteFieldBegin("label", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Label); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *ParamConfigValue) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetValue() { + if err = oprot.WriteFieldBegin("value", thrift.STRUCT, 3); err != nil { + goto WriteFieldBeginError + } + if err := p.Value.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} + +func (p *ParamConfigValue) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ParamConfigValue(%+v)", *p) + +} + +func (p *ParamConfigValue) DeepEqual(ano *ParamConfigValue) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Name) { + return false + } + if !p.Field2DeepEqual(ano.Label) { + return false + } + if !p.Field3DeepEqual(ano.Value) { + return false + } + return true +} + +func (p *ParamConfigValue) Field1DeepEqual(src *string) bool { + + if p.Name == src { + return true + } else if p.Name == nil || src == nil { + return false + } + if strings.Compare(*p.Name, *src) != 0 { + return false + } + return true +} +func (p *ParamConfigValue) Field2DeepEqual(src *string) bool { + + if p.Label == src { + return true + } else if p.Label == nil || src == nil { + return false + } + if strings.Compare(*p.Label, *src) != 0 { + return false + } + return true +} +func (p *ParamConfigValue) Field3DeepEqual(src *manage.ParamOption) bool { + + if !p.Value.DeepEqual(src) { + return false + } + return true +} type Message struct { Role Role `thrift:"role,1,required" frugal:"1,required,string" form:"role,required" json:"role,required" query:"role,required"` diff --git a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go index 9be9670f5..11b5abe28 100644 --- a/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go +++ b/backend/kitex_gen/coze/loop/llm/domain/runtime/runtime_validator.go @@ -29,6 +29,14 @@ func (p *ModelConfig) IsValid() error { } return nil } +func (p *ParamConfigValue) IsValid() error { + if p.Value != nil { + if err := p.Value.IsValid(); err != nil { + return fmt.Errorf("field Value not valid, %w", err) + } + } + return nil +} func (p *Message) IsValid() error { if p.ResponseMeta != nil { if err := p.ResponseMeta.IsValid(); err != nil { From a7ac3b4b057e777b0feeece51ce6e68f73f0dd7d Mon Sep 17 00:00:00 2001 From: "chenzhe.29" Date: Tue, 4 Nov 2025 19:00:28 +0800 Subject: [PATCH 4/8] add parameters --- backend/modules/llm/domain/entity/runtime_option.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/backend/modules/llm/domain/entity/runtime_option.go b/backend/modules/llm/domain/entity/runtime_option.go index c7aba9625..b18ee5702 100644 --- a/backend/modules/llm/domain/entity/runtime_option.go +++ b/backend/modules/llm/domain/entity/runtime_option.go @@ -26,6 +26,8 @@ type Options struct { PresencePenalty *float32 // FrequencyPenalty is the frequency penalty for the model, which controls the diversity of the model. FrequencyPenalty *float32 + // Parameters is the extra parameters for the model. + Parameters map[string]string } type Option struct { @@ -167,3 +169,11 @@ func WithPresencePenalty(p float32) Option { }, } } + +func WithParameters(p map[string]string) Option { + return Option{ + apply: func(opts *Options) { + opts.Parameters = p + }, + } +} From ef3abe625a22a10fb4c5dbf3d90fc249116587cc Mon Sep 17 00:00:00 2001 From: "chenzhe.29" Date: Tue, 4 Nov 2025 19:20:30 +0800 Subject: [PATCH 5/8] add parameters --- .../modules/llm/application/convertor/runtime_option.go | 5 ++++- backend/modules/llm/application/runtime.go | 9 ++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/backend/modules/llm/application/convertor/runtime_option.go b/backend/modules/llm/application/convertor/runtime_option.go index 277c80a8a..731b112ac 100644 --- a/backend/modules/llm/application/convertor/runtime_option.go +++ b/backend/modules/llm/application/convertor/runtime_option.go @@ -10,7 +10,7 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" ) -func ModelAndTools2OptionDOs(modelCfg *druntime.ModelConfig, tools []*druntime.Tool) []entity.Option { +func ModelAndTools2OptionDOs(modelCfg *druntime.ModelConfig, tools []*druntime.Tool, parameters map[string]string) []entity.Option { var opts []entity.Option if modelCfg != nil { if modelCfg.Temperature != nil { @@ -47,6 +47,9 @@ func ModelAndTools2OptionDOs(modelCfg *druntime.ModelConfig, tools []*druntime.T }) opts = append(opts, entity.WithTools(toolsDTO)) } + if parameters != nil { + opts = append(opts, entity.WithParameters(parameters)) + } return opts } diff --git a/backend/modules/llm/application/runtime.go b/backend/modules/llm/application/runtime.go index e4301564e..3605c3589 100644 --- a/backend/modules/llm/application/runtime.go +++ b/backend/modules/llm/application/runtime.go @@ -10,9 +10,6 @@ import ( "strconv" "time" - "github.com/coze-dev/cozeloop-go/spec/tracespec" - "github.com/pkg/errors" - "github.com/coze-dev/coze-loop/backend/infra/limiter" "github.com/coze-dev/coze-loop/backend/infra/looptracer" "github.com/coze-dev/coze-loop/backend/infra/redis" @@ -30,6 +27,8 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/json" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/coze-dev/cozeloop-go/spec/tracespec" + "github.com/pkg/errors" ) type runtimeApp struct { @@ -77,7 +76,7 @@ func (r *runtimeApp) Chat(ctx context.Context, req *runtime.ChatRequest) (resp * if err != nil { return resp, errorx.NewByCode(llm_errorx.RequestNotValidCode, errorx.WithExtraMsg(err.Error())) } - options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools()) + options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools(), nil) var respMsg *entity.Message // 5. start span var span looptracer.Span @@ -137,7 +136,7 @@ func (r *runtimeApp) ChatStream(ctx context.Context, req *runtime.ChatRequest, s if err != nil { return errorx.NewByCode(llm_errorx.RequestNotValidCode, errorx.WithExtraMsg(err.Error())) } - options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools()) + options := convertor.ModelAndTools2OptionDOs(req.GetModelConfig(), req.GetTools(), nil) // 4. start trace var span looptracer.Span ctx, span = looptracer.GetTracer().StartSpan(ctx, model.Name, tracespec.VModelSpanType, looptracer.WithSpanWorkspaceID(strconv.FormatInt(req.GetBizParam().GetWorkspaceID(), 10))) From c225525b546fc4102de3d07369fdfc3d35d4959f Mon Sep 17 00:00:00 2001 From: "caijialin.626" Date: Thu, 6 Nov 2025 16:00:01 +0800 Subject: [PATCH 6/8] [feat][prompt] add thrift --- .../loop/prompt/domain/prompt/k-prompt.go | 492 ++++++++++++++++++ .../prompt/domain/prompt/prompt_validator.go | 11 + .../coze/loop/prompt/domain/prompt.thrift | 13 + 3 files changed, 516 insertions(+) diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go index 90905c282..8e64adb2f 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go @@ -3287,6 +3287,20 @@ func (p *ModelConfig) FastRead(buf []byte) (int, error) { goto SkipFieldError } } + case 100: + if fieldTypeId == thrift.LIST { + l, err = p.FastReadField100(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } default: l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) offset += l @@ -3417,6 +3431,31 @@ func (p *ModelConfig) FastReadField8(buf []byte) (int, error) { return offset, nil } +func (p *ModelConfig) FastReadField100(buf []byte) (int, error) { + offset := 0 + + _, size, l, err := thrift.Binary.ReadListBegin(buf[offset:]) + offset += l + if err != nil { + return offset, err + } + _field := make([]*ParamConfigValue, 0, size) + values := make([]ParamConfigValue, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + if l, err := _elem.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + + _field = append(_field, _elem) + } + p.ParamConfigValues = _field + return offset, nil +} + func (p *ModelConfig) FastWrite(buf []byte) int { return p.FastWriteNocopy(buf, nil) } @@ -3432,6 +3471,7 @@ func (p *ModelConfig) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { offset += p.fastWriteField6(buf[offset:], w) offset += p.fastWriteField7(buf[offset:], w) offset += p.fastWriteField8(buf[offset:], w) + offset += p.fastWriteField100(buf[offset:], w) } offset += thrift.Binary.WriteFieldStop(buf[offset:]) return offset @@ -3448,6 +3488,7 @@ func (p *ModelConfig) BLength() int { l += p.field6Length() l += p.field7Length() l += p.field8Length() + l += p.field100Length() } l += thrift.Binary.FieldStopLength() return l @@ -3525,6 +3566,22 @@ func (p *ModelConfig) fastWriteField8(buf []byte, w thrift.NocopyWriter) int { return offset } +func (p *ModelConfig) fastWriteField100(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetParamConfigValues() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.LIST, 100) + listBeginOffset := offset + offset += thrift.Binary.ListBeginLength() + var length int + for _, v := range p.ParamConfigValues { + length++ + offset += v.FastWriteNocopy(buf[offset:], w) + } + thrift.Binary.WriteListBegin(buf[listBeginOffset:], thrift.STRUCT, length) + } + return offset +} + func (p *ModelConfig) field1Length() int { l := 0 if p.IsSetModelID() { @@ -3597,6 +3654,19 @@ func (p *ModelConfig) field8Length() int { return l } +func (p *ModelConfig) field100Length() int { + l := 0 + if p.IsSetParamConfigValues() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.ListBeginLength() + for _, v := range p.ParamConfigValues { + _ = v + l += v.BLength() + } + } + return l +} + func (p *ModelConfig) DeepCopy(s interface{}) error { src, ok := s.(*ModelConfig) if !ok { @@ -3643,6 +3713,428 @@ func (p *ModelConfig) DeepCopy(s interface{}) error { p.JSONMode = &tmp } + if src.ParamConfigValues != nil { + p.ParamConfigValues = make([]*ParamConfigValue, 0, len(src.ParamConfigValues)) + for _, elem := range src.ParamConfigValues { + var _elem *ParamConfigValue + if elem != nil { + _elem = &ParamConfigValue{} + if err := _elem.DeepCopy(elem); err != nil { + return err + } + } + + p.ParamConfigValues = append(p.ParamConfigValues, _elem) + } + } + + return nil +} + +func (p *ParamConfigValue) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 3: + if fieldTypeId == thrift.STRUCT { + l, err = p.FastReadField3(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ParamConfigValue[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ParamConfigValue) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Name = _field + return offset, nil +} + +func (p *ParamConfigValue) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Label = _field + return offset, nil +} + +func (p *ParamConfigValue) FastReadField3(buf []byte) (int, error) { + offset := 0 + _field := NewParamOption() + if l, err := _field.FastRead(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + } + p.Value = _field + return offset, nil +} + +func (p *ParamConfigValue) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ParamConfigValue) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + offset += p.fastWriteField3(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ParamConfigValue) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + l += p.field3Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ParamConfigValue) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetName() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Name) + } + return offset +} + +func (p *ParamConfigValue) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetLabel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Label) + } + return offset +} + +func (p *ParamConfigValue) fastWriteField3(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetValue() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRUCT, 3) + offset += p.Value.FastWriteNocopy(buf[offset:], w) + } + return offset +} + +func (p *ParamConfigValue) field1Length() int { + l := 0 + if p.IsSetName() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Name) + } + return l +} + +func (p *ParamConfigValue) field2Length() int { + l := 0 + if p.IsSetLabel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Label) + } + return l +} + +func (p *ParamConfigValue) field3Length() int { + l := 0 + if p.IsSetValue() { + l += thrift.Binary.FieldBeginLength() + l += p.Value.BLength() + } + return l +} + +func (p *ParamConfigValue) DeepCopy(s interface{}) error { + src, ok := s.(*ParamConfigValue) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Name != nil { + var tmp string + if *src.Name != "" { + tmp = kutils.StringDeepCopy(*src.Name) + } + p.Name = &tmp + } + + if src.Label != nil { + var tmp string + if *src.Label != "" { + tmp = kutils.StringDeepCopy(*src.Label) + } + p.Label = &tmp + } + + var _value *ParamOption + if src.Value != nil { + _value = &ParamOption{} + if err := _value.DeepCopy(src.Value); err != nil { + return err + } + } + p.Value = _value + + return nil +} + +func (p *ParamOption) FastRead(buf []byte) (int, error) { + + var err error + var offset int + var l int + var fieldTypeId thrift.TType + var fieldId int16 + for { + fieldTypeId, fieldId, l, err = thrift.Binary.ReadFieldBegin(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField1(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + case 2: + if fieldTypeId == thrift.STRING { + l, err = p.FastReadField2(buf[offset:]) + offset += l + if err != nil { + goto ReadFieldError + } + } else { + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + default: + l, err = thrift.Binary.Skip(buf[offset:], fieldTypeId) + offset += l + if err != nil { + goto SkipFieldError + } + } + } + + return offset, nil +ReadFieldBeginError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ParamOption[fieldId]), err) +SkipFieldError: + return offset, thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) +} + +func (p *ParamOption) FastReadField1(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Value = _field + return offset, nil +} + +func (p *ParamOption) FastReadField2(buf []byte) (int, error) { + offset := 0 + + var _field *string + if v, l, err := thrift.Binary.ReadString(buf[offset:]); err != nil { + return offset, err + } else { + offset += l + _field = &v + } + p.Label = _field + return offset, nil +} + +func (p *ParamOption) FastWrite(buf []byte) int { + return p.FastWriteNocopy(buf, nil) +} + +func (p *ParamOption) FastWriteNocopy(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p != nil { + offset += p.fastWriteField1(buf[offset:], w) + offset += p.fastWriteField2(buf[offset:], w) + } + offset += thrift.Binary.WriteFieldStop(buf[offset:]) + return offset +} + +func (p *ParamOption) BLength() int { + l := 0 + if p != nil { + l += p.field1Length() + l += p.field2Length() + } + l += thrift.Binary.FieldStopLength() + return l +} + +func (p *ParamOption) fastWriteField1(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetValue() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 1) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Value) + } + return offset +} + +func (p *ParamOption) fastWriteField2(buf []byte, w thrift.NocopyWriter) int { + offset := 0 + if p.IsSetLabel() { + offset += thrift.Binary.WriteFieldBegin(buf[offset:], thrift.STRING, 2) + offset += thrift.Binary.WriteStringNocopy(buf[offset:], w, *p.Label) + } + return offset +} + +func (p *ParamOption) field1Length() int { + l := 0 + if p.IsSetValue() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Value) + } + return l +} + +func (p *ParamOption) field2Length() int { + l := 0 + if p.IsSetLabel() { + l += thrift.Binary.FieldBeginLength() + l += thrift.Binary.StringLengthNocopy(*p.Label) + } + return l +} + +func (p *ParamOption) DeepCopy(s interface{}) error { + src, ok := s.(*ParamOption) + if !ok { + return fmt.Errorf("%T's type not matched %T", s, p) + } + + if src.Value != nil { + var tmp string + if *src.Value != "" { + tmp = kutils.StringDeepCopy(*src.Value) + } + p.Value = &tmp + } + + if src.Label != nil { + var tmp string + if *src.Label != "" { + tmp = kutils.StringDeepCopy(*src.Label) + } + p.Label = &tmp + } + return nil } diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go index f6639cac5..c17b3cd64 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt_validator.go @@ -112,6 +112,17 @@ func (p *ToolCallConfig) IsValid() error { func (p *ModelConfig) IsValid() error { return nil } +func (p *ParamConfigValue) IsValid() error { + if p.Value != nil { + if err := p.Value.IsValid(); err != nil { + return fmt.Errorf("field Value not valid, %w", err) + } + } + return nil +} +func (p *ParamOption) IsValid() error { + return nil +} func (p *Message) IsValid() error { return nil } diff --git a/idl/thrift/coze/loop/prompt/domain/prompt.thrift b/idl/thrift/coze/loop/prompt/domain/prompt.thrift index f70477467..170e5869d 100644 --- a/idl/thrift/coze/loop/prompt/domain/prompt.thrift +++ b/idl/thrift/coze/loop/prompt/domain/prompt.thrift @@ -98,6 +98,19 @@ struct ModelConfig { 6: optional double presence_penalty 7: optional double frequency_penalty 8: optional bool json_mode + + 100: optional list param_config_values +} + +struct ParamConfigValue { + 1: optional string name // 传给下游模型的key,与ParamSchema.name对齐 + 2: optional string label // 展示名称 + 3: optional ParamOption value // 传给下游模型的value,与ParamSchema.options对齐 +} + +struct ParamOption { + 1: optional string value // 实际值 + 2: optional string label // 展示值 } struct Message { From 451afcbaafad1b894a2a4185cd17b391e00508b5 Mon Sep 17 00:00:00 2001 From: "caijialin.626" Date: Thu, 6 Nov 2025 16:24:49 +0800 Subject: [PATCH 7/8] [feat][prompt] add paramConfigValue --- .../coze/loop/prompt/domain/prompt/prompt.go | 720 +++++++++++++++++- .../prompt/application/convertor/prompt.go | 104 ++- .../prompt/domain/entity/prompt_detail.go | 28 +- .../prompt/infra/rpc/convertor/chat.go | 55 +- 4 files changed, 858 insertions(+), 49 deletions(-) diff --git a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go index ed908b4aa..f74514d70 100644 --- a/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go +++ b/backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go @@ -4462,14 +4462,15 @@ func (p *ToolCallConfig) Field1DeepEqual(src *ToolChoiceType) bool { } type ModelConfig struct { - ModelID *int64 `thrift:"model_id,1,optional" frugal:"1,optional,i64" json:"model_id" form:"model_id" query:"model_id"` - MaxTokens *int32 `thrift:"max_tokens,2,optional" frugal:"2,optional,i32" form:"max_tokens" json:"max_tokens,omitempty" query:"max_tokens"` - Temperature *float64 `thrift:"temperature,3,optional" frugal:"3,optional,double" form:"temperature" json:"temperature,omitempty" query:"temperature"` - TopK *int32 `thrift:"top_k,4,optional" frugal:"4,optional,i32" form:"top_k" json:"top_k,omitempty" query:"top_k"` - TopP *float64 `thrift:"top_p,5,optional" frugal:"5,optional,double" form:"top_p" json:"top_p,omitempty" query:"top_p"` - PresencePenalty *float64 `thrift:"presence_penalty,6,optional" frugal:"6,optional,double" form:"presence_penalty" json:"presence_penalty,omitempty" query:"presence_penalty"` - FrequencyPenalty *float64 `thrift:"frequency_penalty,7,optional" frugal:"7,optional,double" form:"frequency_penalty" json:"frequency_penalty,omitempty" query:"frequency_penalty"` - JSONMode *bool `thrift:"json_mode,8,optional" frugal:"8,optional,bool" form:"json_mode" json:"json_mode,omitempty" query:"json_mode"` + ModelID *int64 `thrift:"model_id,1,optional" frugal:"1,optional,i64" json:"model_id" form:"model_id" query:"model_id"` + MaxTokens *int32 `thrift:"max_tokens,2,optional" frugal:"2,optional,i32" form:"max_tokens" json:"max_tokens,omitempty" query:"max_tokens"` + Temperature *float64 `thrift:"temperature,3,optional" frugal:"3,optional,double" form:"temperature" json:"temperature,omitempty" query:"temperature"` + TopK *int32 `thrift:"top_k,4,optional" frugal:"4,optional,i32" form:"top_k" json:"top_k,omitempty" query:"top_k"` + TopP *float64 `thrift:"top_p,5,optional" frugal:"5,optional,double" form:"top_p" json:"top_p,omitempty" query:"top_p"` + PresencePenalty *float64 `thrift:"presence_penalty,6,optional" frugal:"6,optional,double" form:"presence_penalty" json:"presence_penalty,omitempty" query:"presence_penalty"` + FrequencyPenalty *float64 `thrift:"frequency_penalty,7,optional" frugal:"7,optional,double" form:"frequency_penalty" json:"frequency_penalty,omitempty" query:"frequency_penalty"` + JSONMode *bool `thrift:"json_mode,8,optional" frugal:"8,optional,bool" form:"json_mode" json:"json_mode,omitempty" query:"json_mode"` + ParamConfigValues []*ParamConfigValue `thrift:"param_config_values,100,optional" frugal:"100,optional,list" form:"param_config_values" json:"param_config_values,omitempty" query:"param_config_values"` } func NewModelConfig() *ModelConfig { @@ -4574,6 +4575,18 @@ func (p *ModelConfig) GetJSONMode() (v bool) { } return *p.JSONMode } + +var ModelConfig_ParamConfigValues_DEFAULT []*ParamConfigValue + +func (p *ModelConfig) GetParamConfigValues() (v []*ParamConfigValue) { + if p == nil { + return + } + if !p.IsSetParamConfigValues() { + return ModelConfig_ParamConfigValues_DEFAULT + } + return p.ParamConfigValues +} func (p *ModelConfig) SetModelID(val *int64) { p.ModelID = val } @@ -4598,16 +4611,20 @@ func (p *ModelConfig) SetFrequencyPenalty(val *float64) { func (p *ModelConfig) SetJSONMode(val *bool) { p.JSONMode = val } +func (p *ModelConfig) SetParamConfigValues(val []*ParamConfigValue) { + p.ParamConfigValues = val +} var fieldIDToName_ModelConfig = map[int16]string{ - 1: "model_id", - 2: "max_tokens", - 3: "temperature", - 4: "top_k", - 5: "top_p", - 6: "presence_penalty", - 7: "frequency_penalty", - 8: "json_mode", + 1: "model_id", + 2: "max_tokens", + 3: "temperature", + 4: "top_k", + 5: "top_p", + 6: "presence_penalty", + 7: "frequency_penalty", + 8: "json_mode", + 100: "param_config_values", } func (p *ModelConfig) IsSetModelID() bool { @@ -4642,6 +4659,10 @@ func (p *ModelConfig) IsSetJSONMode() bool { return p.JSONMode != nil } +func (p *ModelConfig) IsSetParamConfigValues() bool { + return p.ParamConfigValues != nil +} + func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { var fieldTypeId thrift.TType var fieldId int16 @@ -4724,6 +4745,14 @@ func (p *ModelConfig) Read(iprot thrift.TProtocol) (err error) { } else if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError } + case 100: + if fieldTypeId == thrift.LIST { + if err = p.ReadField100(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } default: if err = iprot.Skip(fieldTypeId); err != nil { goto SkipFieldError @@ -4841,6 +4870,29 @@ func (p *ModelConfig) ReadField8(iprot thrift.TProtocol) error { p.JSONMode = _field return nil } +func (p *ModelConfig) ReadField100(iprot thrift.TProtocol) error { + _, size, err := iprot.ReadListBegin() + if err != nil { + return err + } + _field := make([]*ParamConfigValue, 0, size) + values := make([]ParamConfigValue, size) + for i := 0; i < size; i++ { + _elem := &values[i] + _elem.InitDefault() + + if err := _elem.Read(iprot); err != nil { + return err + } + + _field = append(_field, _elem) + } + if err := iprot.ReadListEnd(); err != nil { + return err + } + p.ParamConfigValues = _field + return nil +} func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { var fieldId int16 @@ -4880,6 +4932,10 @@ func (p *ModelConfig) Write(oprot thrift.TProtocol) (err error) { fieldId = 8 goto WriteFieldError } + if err = p.writeField100(oprot); err != nil { + fieldId = 100 + goto WriteFieldError + } } if err = oprot.WriteFieldStop(); err != nil { goto WriteFieldStopError @@ -5042,6 +5098,32 @@ WriteFieldBeginError: WriteFieldEndError: return thrift.PrependError(fmt.Sprintf("%T write field 8 end error: ", p), err) } +func (p *ModelConfig) writeField100(oprot thrift.TProtocol) (err error) { + if p.IsSetParamConfigValues() { + if err = oprot.WriteFieldBegin("param_config_values", thrift.LIST, 100); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteListBegin(thrift.STRUCT, len(p.ParamConfigValues)); err != nil { + return err + } + for _, v := range p.ParamConfigValues { + if err := v.Write(oprot); err != nil { + return err + } + } + if err := oprot.WriteListEnd(); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 100 end error: ", p), err) +} func (p *ModelConfig) String() string { if p == nil { @@ -5081,6 +5163,9 @@ func (p *ModelConfig) DeepEqual(ano *ModelConfig) bool { if !p.Field8DeepEqual(ano.JSONMode) { return false } + if !p.Field100DeepEqual(ano.ParamConfigValues) { + return false + } return true } @@ -5180,6 +5265,609 @@ func (p *ModelConfig) Field8DeepEqual(src *bool) bool { } return true } +func (p *ModelConfig) Field100DeepEqual(src []*ParamConfigValue) bool { + + if len(p.ParamConfigValues) != len(src) { + return false + } + for i, v := range p.ParamConfigValues { + _src := src[i] + if !v.DeepEqual(_src) { + return false + } + } + return true +} + +type ParamConfigValue struct { + // 传给下游模型的key,与ParamSchema.name对齐 + Name *string `thrift:"name,1,optional" frugal:"1,optional,string" form:"name" json:"name,omitempty" query:"name"` + // 展示名称 + Label *string `thrift:"label,2,optional" frugal:"2,optional,string" form:"label" json:"label,omitempty" query:"label"` + // 传给下游模型的value,与ParamSchema.options对齐 + Value *ParamOption `thrift:"value,3,optional" frugal:"3,optional,ParamOption" form:"value" json:"value,omitempty" query:"value"` +} + +func NewParamConfigValue() *ParamConfigValue { + return &ParamConfigValue{} +} + +func (p *ParamConfigValue) InitDefault() { +} + +var ParamConfigValue_Name_DEFAULT string + +func (p *ParamConfigValue) GetName() (v string) { + if p == nil { + return + } + if !p.IsSetName() { + return ParamConfigValue_Name_DEFAULT + } + return *p.Name +} + +var ParamConfigValue_Label_DEFAULT string + +func (p *ParamConfigValue) GetLabel() (v string) { + if p == nil { + return + } + if !p.IsSetLabel() { + return ParamConfigValue_Label_DEFAULT + } + return *p.Label +} + +var ParamConfigValue_Value_DEFAULT *ParamOption + +func (p *ParamConfigValue) GetValue() (v *ParamOption) { + if p == nil { + return + } + if !p.IsSetValue() { + return ParamConfigValue_Value_DEFAULT + } + return p.Value +} +func (p *ParamConfigValue) SetName(val *string) { + p.Name = val +} +func (p *ParamConfigValue) SetLabel(val *string) { + p.Label = val +} +func (p *ParamConfigValue) SetValue(val *ParamOption) { + p.Value = val +} + +var fieldIDToName_ParamConfigValue = map[int16]string{ + 1: "name", + 2: "label", + 3: "value", +} + +func (p *ParamConfigValue) IsSetName() bool { + return p.Name != nil +} + +func (p *ParamConfigValue) IsSetLabel() bool { + return p.Label != nil +} + +func (p *ParamConfigValue) IsSetValue() bool { + return p.Value != nil +} + +func (p *ParamConfigValue) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 3: + if fieldTypeId == thrift.STRUCT { + if err = p.ReadField3(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ParamConfigValue[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ParamConfigValue) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Name = _field + return nil +} +func (p *ParamConfigValue) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Label = _field + return nil +} +func (p *ParamConfigValue) ReadField3(iprot thrift.TProtocol) error { + _field := NewParamOption() + if err := _field.Read(iprot); err != nil { + return err + } + p.Value = _field + return nil +} + +func (p *ParamConfigValue) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ParamConfigValue"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + if err = p.writeField3(oprot); err != nil { + fieldId = 3 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ParamConfigValue) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetName() { + if err = oprot.WriteFieldBegin("name", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Name); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ParamConfigValue) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetLabel() { + if err = oprot.WriteFieldBegin("label", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Label); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} +func (p *ParamConfigValue) writeField3(oprot thrift.TProtocol) (err error) { + if p.IsSetValue() { + if err = oprot.WriteFieldBegin("value", thrift.STRUCT, 3); err != nil { + goto WriteFieldBeginError + } + if err := p.Value.Write(oprot); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) +} + +func (p *ParamConfigValue) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ParamConfigValue(%+v)", *p) + +} + +func (p *ParamConfigValue) DeepEqual(ano *ParamConfigValue) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Name) { + return false + } + if !p.Field2DeepEqual(ano.Label) { + return false + } + if !p.Field3DeepEqual(ano.Value) { + return false + } + return true +} + +func (p *ParamConfigValue) Field1DeepEqual(src *string) bool { + + if p.Name == src { + return true + } else if p.Name == nil || src == nil { + return false + } + if strings.Compare(*p.Name, *src) != 0 { + return false + } + return true +} +func (p *ParamConfigValue) Field2DeepEqual(src *string) bool { + + if p.Label == src { + return true + } else if p.Label == nil || src == nil { + return false + } + if strings.Compare(*p.Label, *src) != 0 { + return false + } + return true +} +func (p *ParamConfigValue) Field3DeepEqual(src *ParamOption) bool { + + if !p.Value.DeepEqual(src) { + return false + } + return true +} + +type ParamOption struct { + // 实际值 + Value *string `thrift:"value,1,optional" frugal:"1,optional,string" form:"value" json:"value,omitempty" query:"value"` + // 展示值 + Label *string `thrift:"label,2,optional" frugal:"2,optional,string" form:"label" json:"label,omitempty" query:"label"` +} + +func NewParamOption() *ParamOption { + return &ParamOption{} +} + +func (p *ParamOption) InitDefault() { +} + +var ParamOption_Value_DEFAULT string + +func (p *ParamOption) GetValue() (v string) { + if p == nil { + return + } + if !p.IsSetValue() { + return ParamOption_Value_DEFAULT + } + return *p.Value +} + +var ParamOption_Label_DEFAULT string + +func (p *ParamOption) GetLabel() (v string) { + if p == nil { + return + } + if !p.IsSetLabel() { + return ParamOption_Label_DEFAULT + } + return *p.Label +} +func (p *ParamOption) SetValue(val *string) { + p.Value = val +} +func (p *ParamOption) SetLabel(val *string) { + p.Label = val +} + +var fieldIDToName_ParamOption = map[int16]string{ + 1: "value", + 2: "label", +} + +func (p *ParamOption) IsSetValue() bool { + return p.Value != nil +} + +func (p *ParamOption) IsSetLabel() bool { + return p.Label != nil +} + +func (p *ParamOption) Read(iprot thrift.TProtocol) (err error) { + var fieldTypeId thrift.TType + var fieldId int16 + + if _, err = iprot.ReadStructBegin(); err != nil { + goto ReadStructBeginError + } + + for { + _, fieldTypeId, fieldId, err = iprot.ReadFieldBegin() + if err != nil { + goto ReadFieldBeginError + } + if fieldTypeId == thrift.STOP { + break + } + + switch fieldId { + case 1: + if fieldTypeId == thrift.STRING { + if err = p.ReadField1(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + case 2: + if fieldTypeId == thrift.STRING { + if err = p.ReadField2(iprot); err != nil { + goto ReadFieldError + } + } else if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + default: + if err = iprot.Skip(fieldTypeId); err != nil { + goto SkipFieldError + } + } + if err = iprot.ReadFieldEnd(); err != nil { + goto ReadFieldEndError + } + } + if err = iprot.ReadStructEnd(); err != nil { + goto ReadStructEndError + } + + return nil +ReadStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) +ReadFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) +ReadFieldError: + return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_ParamOption[fieldId]), err) +SkipFieldError: + return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + +ReadFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) +ReadStructEndError: + return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) +} + +func (p *ParamOption) ReadField1(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Value = _field + return nil +} +func (p *ParamOption) ReadField2(iprot thrift.TProtocol) error { + + var _field *string + if v, err := iprot.ReadString(); err != nil { + return err + } else { + _field = &v + } + p.Label = _field + return nil +} + +func (p *ParamOption) Write(oprot thrift.TProtocol) (err error) { + var fieldId int16 + if err = oprot.WriteStructBegin("ParamOption"); err != nil { + goto WriteStructBeginError + } + if p != nil { + if err = p.writeField1(oprot); err != nil { + fieldId = 1 + goto WriteFieldError + } + if err = p.writeField2(oprot); err != nil { + fieldId = 2 + goto WriteFieldError + } + } + if err = oprot.WriteFieldStop(); err != nil { + goto WriteFieldStopError + } + if err = oprot.WriteStructEnd(); err != nil { + goto WriteStructEndError + } + return nil +WriteStructBeginError: + return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) +WriteFieldError: + return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) +WriteFieldStopError: + return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) +WriteStructEndError: + return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) +} + +func (p *ParamOption) writeField1(oprot thrift.TProtocol) (err error) { + if p.IsSetValue() { + if err = oprot.WriteFieldBegin("value", thrift.STRING, 1); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Value); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) +} +func (p *ParamOption) writeField2(oprot thrift.TProtocol) (err error) { + if p.IsSetLabel() { + if err = oprot.WriteFieldBegin("label", thrift.STRING, 2); err != nil { + goto WriteFieldBeginError + } + if err := oprot.WriteString(*p.Label); err != nil { + return err + } + if err = oprot.WriteFieldEnd(); err != nil { + goto WriteFieldEndError + } + } + return nil +WriteFieldBeginError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) +WriteFieldEndError: + return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) +} + +func (p *ParamOption) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("ParamOption(%+v)", *p) + +} + +func (p *ParamOption) DeepEqual(ano *ParamOption) bool { + if p == ano { + return true + } else if p == nil || ano == nil { + return false + } + if !p.Field1DeepEqual(ano.Value) { + return false + } + if !p.Field2DeepEqual(ano.Label) { + return false + } + return true +} + +func (p *ParamOption) Field1DeepEqual(src *string) bool { + + if p.Value == src { + return true + } else if p.Value == nil || src == nil { + return false + } + if strings.Compare(*p.Value, *src) != 0 { + return false + } + return true +} +func (p *ParamOption) Field2DeepEqual(src *string) bool { + + if p.Label == src { + return true + } else if p.Label == nil || src == nil { + return false + } + if strings.Compare(*p.Label, *src) != 0 { + return false + } + return true +} type Message struct { Role *Role `thrift:"role,1,optional" frugal:"1,optional,string" form:"role" json:"role,omitempty" query:"role"` diff --git a/backend/modules/prompt/application/convertor/prompt.go b/backend/modules/prompt/application/convertor/prompt.go index 716d33777..e04e80530 100644 --- a/backend/modules/prompt/application/convertor/prompt.go +++ b/backend/modules/prompt/application/convertor/prompt.go @@ -386,14 +386,50 @@ func ModelConfigDTO2DO(dto *prompt.ModelConfig) *entity.ModelConfig { } return &entity.ModelConfig{ - ModelID: dto.GetModelID(), - MaxTokens: dto.MaxTokens, - Temperature: dto.Temperature, - TopK: dto.TopK, - TopP: dto.TopP, - PresencePenalty: dto.PresencePenalty, - FrequencyPenalty: dto.FrequencyPenalty, - JSONMode: dto.JSONMode, + ModelID: dto.GetModelID(), + MaxTokens: dto.MaxTokens, + Temperature: dto.Temperature, + TopK: dto.TopK, + TopP: dto.TopP, + PresencePenalty: dto.PresencePenalty, + FrequencyPenalty: dto.FrequencyPenalty, + JSONMode: dto.JSONMode, + ParamConfigValues: BatchParamConfigValueDTO2DO(dto.ParamConfigValues), + } +} + +func BatchParamConfigValueDTO2DO(dtos []*prompt.ParamConfigValue) []*entity.ParamConfigValue { + if dtos == nil { + return nil + } + result := make([]*entity.ParamConfigValue, 0, len(dtos)) + for _, dto := range dtos { + if dto == nil { + continue + } + result = append(result, ParamConfigValueDTO2DO(dto)) + } + return result +} + +func ParamConfigValueDTO2DO(dto *prompt.ParamConfigValue) *entity.ParamConfigValue { + if dto == nil { + return nil + } + return &entity.ParamConfigValue{ + Name: ptr.From(dto.Name), + Label: ptr.From(dto.Label), + Value: ParamOptionDTO2DO(dto.Value), + } +} + +func ParamOptionDTO2DO(dto *prompt.ParamOption) *entity.ParamOption { + if dto == nil { + return nil + } + return &entity.ParamOption{ + Value: ptr.From(dto.Value), + Label: ptr.From(dto.Label), } } @@ -765,14 +801,50 @@ func ModelConfigDO2DTO(do *entity.ModelConfig) *prompt.ModelConfig { return nil } return &prompt.ModelConfig{ - ModelID: ptr.Of(do.ModelID), - MaxTokens: do.MaxTokens, - Temperature: do.Temperature, - TopK: do.TopK, - TopP: do.TopP, - PresencePenalty: do.PresencePenalty, - FrequencyPenalty: do.FrequencyPenalty, - JSONMode: do.JSONMode, + ModelID: ptr.Of(do.ModelID), + MaxTokens: do.MaxTokens, + Temperature: do.Temperature, + TopK: do.TopK, + TopP: do.TopP, + PresencePenalty: do.PresencePenalty, + FrequencyPenalty: do.FrequencyPenalty, + JSONMode: do.JSONMode, + ParamConfigValues: BatchParamConfigValueDO2DTO(do.ParamConfigValues), + } +} + +func BatchParamConfigValueDO2DTO(dos []*entity.ParamConfigValue) []*prompt.ParamConfigValue { + if dos == nil { + return nil + } + result := make([]*prompt.ParamConfigValue, 0, len(dos)) + for _, do := range dos { + if do == nil { + continue + } + result = append(result, ParamConfigValueDO2DTO(do)) + } + return result +} + +func ParamConfigValueDO2DTO(do *entity.ParamConfigValue) *prompt.ParamConfigValue { + if do == nil { + return nil + } + return &prompt.ParamConfigValue{ + Name: ptr.Of(do.Name), + Label: ptr.Of(do.Label), + Value: ParamOptionDO2DTO(do.Value), + } +} + +func ParamOptionDO2DTO(do *entity.ParamOption) *prompt.ParamOption { + if do == nil { + return nil + } + return &prompt.ParamOption{ + Value: ptr.Of(do.Value), + Label: ptr.Of(do.Label), } } diff --git a/backend/modules/prompt/domain/entity/prompt_detail.go b/backend/modules/prompt/domain/entity/prompt_detail.go index bfd4b34f0..faec9421b 100644 --- a/backend/modules/prompt/domain/entity/prompt_detail.go +++ b/backend/modules/prompt/domain/entity/prompt_detail.go @@ -158,14 +158,26 @@ type FunctionCall struct { } type ModelConfig struct { - ModelID int64 `json:"model_id"` - MaxTokens *int32 `json:"max_tokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopK *int32 `json:"top_k,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - PresencePenalty *float64 `json:"presence_penalty,omitempty"` - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - JSONMode *bool `json:"json_mode,omitempty"` + ModelID int64 `json:"model_id"` + MaxTokens *int32 `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK *int32 `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + JSONMode *bool `json:"json_mode,omitempty"` + ParamConfigValues []*ParamConfigValue `json:"param_config_values,omitempty"` +} + +type ParamConfigValue struct { + Name string `json:"name"` + Label string `json:"label"` + Value *ParamOption `json:"value,omitempty"` +} + +type ParamOption struct { + Value string `json:"value"` + Label string `json:"label"` } func (pt *PromptTemplate) formatMessages(messages []*Message, variableVals []*VariableVal) ([]*Message, error) { diff --git a/backend/modules/prompt/infra/rpc/convertor/chat.go b/backend/modules/prompt/infra/rpc/convertor/chat.go index b3aa3b069..b28c68a18 100644 --- a/backend/modules/prompt/infra/rpc/convertor/chat.go +++ b/backend/modules/prompt/infra/rpc/convertor/chat.go @@ -10,6 +10,7 @@ import ( "github.com/vincent-petithory/dataurl" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/common" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" runtimedto "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/runtime" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/runtime" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/component/rpc" @@ -52,15 +53,16 @@ func ModelConfigDO2DTO(modelConfig *entity.ModelConfig, toolCallConfig *entity.T } } return &runtimedto.ModelConfig{ - ModelID: modelConfig.ModelID, - Temperature: modelConfig.Temperature, - MaxTokens: maxTokens, - TopP: modelConfig.TopP, - ToolChoice: toolChoice, - ResponseFormat: responseFormat, - TopK: modelConfig.TopK, - PresencePenalty: modelConfig.PresencePenalty, - FrequencyPenalty: modelConfig.FrequencyPenalty, + ModelID: modelConfig.ModelID, + Temperature: modelConfig.Temperature, + MaxTokens: maxTokens, + TopP: modelConfig.TopP, + ToolChoice: toolChoice, + ResponseFormat: responseFormat, + TopK: modelConfig.TopK, + PresencePenalty: modelConfig.PresencePenalty, + FrequencyPenalty: modelConfig.FrequencyPenalty, + ParamConfigValues: BatchParamConfigValueDO2DTO(modelConfig.ParamConfigValues), } } @@ -385,3 +387,38 @@ func TokenUsageDTO2DO(dto *runtimedto.TokenUsage) *entity.TokenUsage { OutputTokens: ptr.From(dto.CompletionTokens), } } + +func BatchParamConfigValueDO2DTO(dos []*entity.ParamConfigValue) []*runtimedto.ParamConfigValue { + if dos == nil { + return nil + } + result := make([]*runtimedto.ParamConfigValue, 0, len(dos)) + for _, do := range dos { + if do == nil { + continue + } + result = append(result, ParamConfigValueDO2DTO(do)) + } + return result +} + +func ParamConfigValueDO2DTO(do *entity.ParamConfigValue) *runtimedto.ParamConfigValue { + if do == nil { + return nil + } + return &runtimedto.ParamConfigValue{ + Name: ptr.Of(do.Name), + Label: ptr.Of(do.Label), + Value: ParamOptionDO2DTO(do.Value), + } +} + +func ParamOptionDO2DTO(do *entity.ParamOption) *manage.ParamOption { + if do == nil { + return nil + } + return &manage.ParamOption{ + Value: ptr.Of(do.Value), + Label: ptr.Of(do.Label), + } +} From 413bb6164f95e9867647006c885eaec73ed2fef4 Mon Sep 17 00:00:00 2001 From: "caijialin.626" Date: Thu, 6 Nov 2025 19:07:29 +0800 Subject: [PATCH 8/8] [feat][prompt] add paramConfigValue ut --- .../application/convertor/prompt_test.go | 457 ++++++++++++++++++ .../prompt/infra/rpc/convertor/chat_test.go | 37 ++ 2 files changed, 494 insertions(+) diff --git a/backend/modules/prompt/application/convertor/prompt_test.go b/backend/modules/prompt/application/convertor/prompt_test.go index c2ce98bc4..863f9f314 100644 --- a/backend/modules/prompt/application/convertor/prompt_test.go +++ b/backend/modules/prompt/application/convertor/prompt_test.go @@ -116,6 +116,24 @@ func mockPromptCases() []promptTestCase { ModelID: ptr.Of(int64(789)), Temperature: ptr.Of(0.7), MaxTokens: ptr.Of(int32(1000)), + ParamConfigValues: []*prompt.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + { + Name: ptr.Of("top_p"), + Label: ptr.Of("Top P"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.9"), + Label: ptr.Of("0.9"), + }, + }, + }, }, Tools: []*prompt.Tool{ { @@ -211,6 +229,24 @@ func mockPromptCases() []promptTestCase { ModelID: 789, Temperature: ptr.Of(0.7), MaxTokens: ptr.Of(int32(1000)), + ParamConfigValues: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + { + Name: "top_p", + Label: "Top P", + Value: &entity.ParamOption{ + Value: "0.9", + Label: "0.9", + }, + }, + }, }, Tools: []*entity.Tool{ { @@ -564,3 +600,424 @@ func TestMessageDO2DTO(t *testing.T) { }) } } + +type paramOptionTestCase struct { + name string + dto *prompt.ParamOption + do *entity.ParamOption +} + +func mockParamOptionCases() []paramOptionTestCase { + return []paramOptionTestCase{ + { + name: "nil input", + dto: nil, + do: nil, + }, + { + name: "empty param option", + dto: &prompt.ParamOption{ + Value: ptr.Of(""), + Label: ptr.Of(""), + }, + do: &entity.ParamOption{ + Value: "", + Label: "", + }, + }, + { + name: "basic param option", + dto: &prompt.ParamOption{ + Value: ptr.Of("value1"), + Label: ptr.Of("Label 1"), + }, + do: &entity.ParamOption{ + Value: "value1", + Label: "Label 1", + }, + }, + { + name: "param option with special characters", + dto: &prompt.ParamOption{ + Value: ptr.Of("option_value_123"), + Label: ptr.Of("Option Label (Special: 测试)"), + }, + do: &entity.ParamOption{ + Value: "option_value_123", + Label: "Option Label (Special: 测试)", + }, + }, + } +} + +func TestParamOptionDTO2DO(t *testing.T) { + for _, tt := range mockParamOptionCases() { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.do, ParamOptionDTO2DO(tt.dto)) + }) + } +} + +func TestParamOptionDO2DTO(t *testing.T) { + for _, tt := range mockParamOptionCases() { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.dto, ParamOptionDO2DTO(tt.do)) + }) + } +} + +type paramConfigValueTestCase struct { + name string + dto *prompt.ParamConfigValue + do *entity.ParamConfigValue +} + +func mockParamConfigValueCases() []paramConfigValueTestCase { + return []paramConfigValueTestCase{ + { + name: "nil input", + dto: nil, + do: nil, + }, + { + name: "empty param config value", + dto: &prompt.ParamConfigValue{ + Name: ptr.Of(""), + Label: ptr.Of(""), + Value: nil, + }, + do: &entity.ParamConfigValue{ + Name: "", + Label: "", + Value: nil, + }, + }, + { + name: "basic param config value", + dto: &prompt.ParamConfigValue{ + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + do: &entity.ParamConfigValue{ + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + }, + { + name: "param config value with complex option", + dto: &prompt.ParamConfigValue{ + Name: ptr.Of("top_p"), + Label: ptr.Of("Top P"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.9"), + Label: ptr.Of("Top P: 0.9 (Recommended)"), + }, + }, + do: &entity.ParamConfigValue{ + Name: "top_p", + Label: "Top P", + Value: &entity.ParamOption{ + Value: "0.9", + Label: "Top P: 0.9 (Recommended)", + }, + }, + }, + { + name: "param config value without value", + dto: &prompt.ParamConfigValue{ + Name: ptr.Of("max_tokens"), + Label: ptr.Of("Max Tokens"), + Value: nil, + }, + do: &entity.ParamConfigValue{ + Name: "max_tokens", + Label: "Max Tokens", + Value: nil, + }, + }, + } +} + +func TestParamConfigValueDTO2DO(t *testing.T) { + for _, tt := range mockParamConfigValueCases() { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.do, ParamConfigValueDTO2DO(tt.dto)) + }) + } +} + +func TestParamConfigValueDO2DTO(t *testing.T) { + for _, tt := range mockParamConfigValueCases() { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.dto, ParamConfigValueDO2DTO(tt.do)) + }) + } +} + +func TestBatchParamConfigValueDTO2DO(t *testing.T) { + tests := []struct { + name string + dtos []*prompt.ParamConfigValue + dos []*entity.ParamConfigValue + }{ + { + name: "nil input", + dtos: nil, + dos: nil, + }, + { + name: "empty slice", + dtos: []*prompt.ParamConfigValue{}, + dos: []*entity.ParamConfigValue{}, + }, + { + name: "single param config value", + dtos: []*prompt.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + }, + dos: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + }, + }, + { + name: "multiple param config values", + dtos: []*prompt.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + { + Name: ptr.Of("top_p"), + Label: ptr.Of("Top P"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.9"), + Label: ptr.Of("0.9"), + }, + }, + }, + dos: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + { + Name: "top_p", + Label: "Top P", + Value: &entity.ParamOption{ + Value: "0.9", + Label: "0.9", + }, + }, + }, + }, + { + name: "with nil elements (should be skipped)", + dtos: []*prompt.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + nil, + { + Name: ptr.Of("top_p"), + Label: ptr.Of("Top P"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.9"), + Label: ptr.Of("0.9"), + }, + }, + }, + dos: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + { + Name: "top_p", + Label: "Top P", + Value: &entity.ParamOption{ + Value: "0.9", + Label: "0.9", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.dos, BatchParamConfigValueDTO2DO(tt.dtos)) + }) + } +} + +func TestBatchParamConfigValueDO2DTO(t *testing.T) { + tests := []struct { + name string + dos []*entity.ParamConfigValue + dtos []*prompt.ParamConfigValue + }{ + { + name: "nil input", + dos: nil, + dtos: nil, + }, + { + name: "empty slice", + dos: []*entity.ParamConfigValue{}, + dtos: []*prompt.ParamConfigValue{}, + }, + { + name: "single param config value", + dos: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + }, + dtos: []*prompt.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + }, + }, + { + name: "multiple param config values", + dos: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + { + Name: "top_p", + Label: "Top P", + Value: &entity.ParamOption{ + Value: "0.9", + Label: "0.9", + }, + }, + }, + dtos: []*prompt.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + { + Name: ptr.Of("top_p"), + Label: ptr.Of("Top P"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.9"), + Label: ptr.Of("0.9"), + }, + }, + }, + }, + { + name: "with nil elements (should be skipped)", + dos: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.7", + Label: "0.7", + }, + }, + nil, + { + Name: "top_p", + Label: "Top P", + Value: &entity.ParamOption{ + Value: "0.9", + Label: "0.9", + }, + }, + }, + dtos: []*prompt.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.7"), + Label: ptr.Of("0.7"), + }, + }, + { + Name: ptr.Of("top_p"), + Label: ptr.Of("Top P"), + Value: &prompt.ParamOption{ + Value: ptr.Of("0.9"), + Label: ptr.Of("0.9"), + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.dtos, BatchParamConfigValueDO2DTO(tt.dos)) + }) + } +} diff --git a/backend/modules/prompt/infra/rpc/convertor/chat_test.go b/backend/modules/prompt/infra/rpc/convertor/chat_test.go index 02a5ad9ee..f9d9dc01f 100644 --- a/backend/modules/prompt/infra/rpc/convertor/chat_test.go +++ b/backend/modules/prompt/infra/rpc/convertor/chat_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/common" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/manage" runtimedto "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/domain/runtime" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/llm/runtime" "github.com/coze-dev/coze-loop/backend/modules/prompt/domain/component/rpc" @@ -87,6 +88,24 @@ func TestLLMCallParamConvert(t *testing.T) { PresencePenalty: nil, FrequencyPenalty: nil, JSONMode: nil, + ParamConfigValues: []*entity.ParamConfigValue{ + { + Name: "temperature", + Label: "Temperature", + Value: &entity.ParamOption{ + Value: "0.5", + Label: "0.5", + }, + }, + { + Name: "top_p", + Label: "Top P", + Value: &entity.ParamOption{ + Value: "0.1", + Label: "0.1", + }, + }, + }, }, }, want: &runtime.ChatRequest{ @@ -97,6 +116,24 @@ func TestLLMCallParamConvert(t *testing.T) { TopP: ptr.Of(0.1), Stop: nil, ToolChoice: ptr.Of(runtimedto.ToolChoiceAuto), + ParamConfigValues: []*runtimedto.ParamConfigValue{ + { + Name: ptr.Of("temperature"), + Label: ptr.Of("Temperature"), + Value: &manage.ParamOption{ + Value: ptr.Of("0.5"), + Label: ptr.Of("0.5"), + }, + }, + { + Name: ptr.Of("top_p"), + Label: ptr.Of("Top P"), + Value: &manage.ParamOption{ + Value: ptr.Of("0.1"), + Label: ptr.Of("0.1"), + }, + }, + }, }, Messages: []*runtimedto.Message{ {