Skip to content

Commit 8c62f24

Browse files
caijialin0626kasarolzzw
authored andcommitted
[feat][prompt] prompt support model config extra field (#260)
1 parent 2f77e95 commit 8c62f24

File tree

8 files changed

+214
-3
lines changed

8 files changed

+214
-3
lines changed

backend/kitex_gen/coze/loop/prompt/domain/prompt/k-prompt.go

Lines changed: 56 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go

Lines changed: 77 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/modules/prompt/application/convertor/prompt.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ func ModelConfigDTO2DO(dto *prompt.ModelConfig) *entity.ModelConfig {
411411
PresencePenalty: dto.PresencePenalty,
412412
FrequencyPenalty: dto.FrequencyPenalty,
413413
JSONMode: dto.JSONMode,
414+
Extra: dto.Extra,
414415
}
415416
}
416417

@@ -805,6 +806,7 @@ func ModelConfigDO2DTO(do *entity.ModelConfig) *prompt.ModelConfig {
805806
PresencePenalty: do.PresencePenalty,
806807
FrequencyPenalty: do.FrequencyPenalty,
807808
JSONMode: do.JSONMode,
809+
Extra: do.Extra,
808810
}
809811
}
810812

backend/modules/prompt/application/convertor/prompt_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,3 +646,18 @@ func TestMessageDO2DTO(t *testing.T) {
646646
})
647647
}
648648
}
649+
650+
func TestModelConfigExtraConversion(t *testing.T) {
651+
extra := ptr.Of(`{"foo":"bar"}`)
652+
dto := &prompt.ModelConfig{
653+
Extra: extra,
654+
}
655+
656+
do := ModelConfigDTO2DO(dto)
657+
assert.NotNil(t, do)
658+
assert.Equal(t, extra, do.Extra)
659+
660+
dtoBack := ModelConfigDO2DTO(do)
661+
assert.NotNil(t, dtoBack)
662+
assert.Equal(t, extra, dtoBack.Extra)
663+
}

backend/modules/prompt/application/execute_test.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ func TestOverridePromptParams(t *testing.T) {
519519
ModelConfig: &entity.ModelConfig{
520520
ModelID: 456,
521521
Temperature: ptr.Of(0.7),
522+
Extra: ptr.Of(`{"source":"base"}`),
522523
},
523524
},
524525
},
@@ -586,6 +587,7 @@ func TestOverridePromptParams(t *testing.T) {
586587
ModelID: ptr.Of(int64(789)),
587588
Temperature: ptr.Of(0.9),
588589
MaxTokens: ptr.Of(int32(2000)),
590+
Extra: ptr.Of(`{"source":"override"}`),
589591
},
590592
},
591593
},
@@ -598,6 +600,7 @@ func TestOverridePromptParams(t *testing.T) {
598600
ModelID: 789,
599601
Temperature: ptr.Of(0.9),
600602
MaxTokens: ptr.Of(int32(2000)),
603+
Extra: ptr.Of(`{"source":"override"}`),
601604
},
602605
},
603606
}
@@ -651,10 +654,17 @@ func TestOverridePromptParams(t *testing.T) {
651654
if tt.args.promptDO.PromptCommit.PromptDetail != nil {
652655
promptCopy.PromptCommit.PromptDetail = &entity.PromptDetail{}
653656
if tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig != nil {
657+
orig := tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig
654658
promptCopy.PromptCommit.PromptDetail.ModelConfig = &entity.ModelConfig{
655-
ModelID: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.ModelID,
656-
Temperature: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.Temperature,
657-
MaxTokens: tt.args.promptDO.PromptCommit.PromptDetail.ModelConfig.MaxTokens,
659+
ModelID: orig.ModelID,
660+
MaxTokens: orig.MaxTokens,
661+
Temperature: orig.Temperature,
662+
TopK: orig.TopK,
663+
TopP: orig.TopP,
664+
PresencePenalty: orig.PresencePenalty,
665+
FrequencyPenalty: orig.FrequencyPenalty,
666+
JSONMode: orig.JSONMode,
667+
Extra: orig.Extra,
658668
}
659669
}
660670
}

backend/modules/prompt/domain/entity/prompt_detail.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ type ModelConfig struct {
177177
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
178178
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
179179
JSONMode *bool `json:"json_mode,omitempty"`
180+
Extra *string `json:"extra,omitempty"`
180181
}
181182

182183
func (pt *PromptTemplate) formatMessages(messages []*Message, variableVals []*VariableVal) ([]*Message, error) {

backend/modules/prompt/domain/service/execute_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,3 +854,52 @@ func TestPromptServiceImpl_Execute(t *testing.T) {
854854
})
855855
}
856856
}
857+
858+
func TestPromptServiceImpl_prepareLLMCallParam_PreservesExtra(t *testing.T) {
859+
t.Parallel()
860+
extra := ptr.Of(`{"foo":"bar"}`)
861+
prompt := &entity.Prompt{
862+
ID: 1,
863+
SpaceID: 42,
864+
PromptKey: "test_prompt",
865+
PromptCommit: &entity.PromptCommit{
866+
CommitInfo: &entity.CommitInfo{
867+
Version: "v1",
868+
},
869+
PromptDetail: &entity.PromptDetail{
870+
ModelConfig: &entity.ModelConfig{
871+
ModelID: 99,
872+
Extra: extra,
873+
JSONMode: ptr.Of(true),
874+
},
875+
PromptTemplate: &entity.PromptTemplate{
876+
TemplateType: entity.TemplateTypeNormal,
877+
Messages: []*entity.Message{
878+
{
879+
Role: entity.RoleSystem,
880+
Content: ptr.Of("System prompt"),
881+
},
882+
},
883+
},
884+
},
885+
},
886+
}
887+
svc := &PromptServiceImpl{}
888+
param := ExecuteParam{
889+
Prompt: prompt,
890+
Messages: []*entity.Message{
891+
{
892+
Role: entity.RoleUser,
893+
Content: ptr.Of("Hi"),
894+
},
895+
},
896+
VariableVals: nil,
897+
Scenario: entity.ScenarioPromptDebug,
898+
}
899+
got, err := svc.prepareLLMCallParam(context.Background(), param)
900+
assert.NoError(t, err)
901+
if assert.NotNil(t, got.ModelConfig) {
902+
assert.Equal(t, extra, got.ModelConfig.Extra)
903+
assert.Equal(t, prompt.PromptCommit.PromptDetail.ModelConfig.Extra, got.ModelConfig.Extra)
904+
}
905+
}

0 commit comments

Comments
 (0)