Skip to content

Commit 46f7268

Browse files
PhilKescarlrobertoh
andcommitted
feat: use GoogleCompletionRequest.systemInstruction for system prompt (#989)
* feat: use GoogleCompletionRequest.systemInstruction for system prompt * fix(test): req body assertion --------- Co-authored-by: Carl-Robert Linnupuu <carlrobertoh@gmail.com>
1 parent e667cd8 commit 46f7268

File tree

2 files changed

+36
-42
lines changed

2 files changed

+36
-42
lines changed

src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt

+20-27
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
3131
.maxOutputTokens(configuration.maxTokens)
3232
.temperature(configuration.temperature.toDouble()).build()
3333
)
34+
.systemInstruction(buildSystemInstruction(params))
3435
.build()
3536
}
3637

@@ -94,33 +95,6 @@ class GoogleRequestFactory : BaseRequestFactory() {
9495
val message = params.message
9596
val messages = mutableListOf<GoogleCompletionContent>()
9697

97-
when (params.conversationType) {
98-
ConversationType.DEFAULT -> {
99-
val selectedPersona = service<PromptsSettings>().state.personas.selectedPersona
100-
if (!selectedPersona.disabled) {
101-
messages.add(
102-
GoogleCompletionContent(
103-
"user",
104-
listOf(PromptsSettings.getSelectedPersonaSystemPrompt())
105-
)
106-
)
107-
messages.add(GoogleCompletionContent("model", listOf("Understood.")))
108-
}
109-
}
110-
111-
ConversationType.FIX_COMPILE_ERRORS -> {
112-
messages.add(
113-
GoogleCompletionContent(
114-
"user",
115-
listOf(service<PromptsSettings>().state.coreActions.fixCompileErrors.instructions)
116-
)
117-
)
118-
messages.add(GoogleCompletionContent("model", listOf("Understood.")))
119-
}
120-
121-
else -> {}
122-
}
123-
12498
for (prevMessage in params.conversation.messages) {
12599
if (params.retry && prevMessage.id == message.id) {
126100
break
@@ -207,4 +181,23 @@ class GoogleRequestFactory : BaseRequestFactory() {
207181

208182
return updatedMessages.filterNotNull()
209183
}
184+
185+
private fun buildSystemInstruction(params: ChatCompletionParameters): String? {
186+
return when (params.conversationType) {
187+
ConversationType.DEFAULT -> {
188+
val selectedPersona = service<PromptsSettings>().state.personas.selectedPersona
189+
return if (!selectedPersona.disabled) {
190+
PromptsSettings.getSelectedPersonaSystemPrompt();
191+
} else {
192+
null
193+
}
194+
}
195+
196+
ConversationType.FIX_COMPILE_ERRORS -> service<PromptsSettings>().state.coreActions.fixCompileErrors.instructions
197+
198+
ConversationType.REVIEW_CHANGES -> service<PromptsSettings>().state.coreActions.reviewChanges.instructions
199+
200+
else -> null
201+
}
202+
}
210203
}

src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt

+16-15
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
55
import ee.carlrobert.codegpt.conversations.ConversationService
66
import ee.carlrobert.codegpt.conversations.message.Message
77
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
8-
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
98
import ee.carlrobert.codegpt.settings.prompts.PromptsSettings
109
import ee.carlrobert.llm.client.http.RequestEntity
1110
import ee.carlrobert.llm.client.http.exchange.NdJsonStreamHttpExchange
@@ -19,7 +18,8 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
1918

2019
fun testOpenAIChatCompletionCall() {
2120
useOpenAIService()
22-
service<PromptsSettings>().state.personas.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
21+
service<PromptsSettings>().state.personas.selectedPersona.instructions =
22+
"TEST_SYSTEM_PROMPT"
2323
val message = Message("TEST_PROMPT")
2424
val conversation = ConversationService.getInstance().startConversation()
2525
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
@@ -58,7 +58,8 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
5858

5959
fun testAzureChatCompletionCall() {
6060
useAzureService()
61-
service<PromptsSettings>().state.personas.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
61+
service<PromptsSettings>().state.personas.selectedPersona.instructions =
62+
"TEST_SYSTEM_PROMPT"
6263
val conversationService = ConversationService.getInstance()
6364
val prevMessage = Message("TEST_PREV_PROMPT")
6465
prevMessage.response = "TEST_PREV_RESPONSE"
@@ -104,7 +105,8 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
104105
fun testLlamaChatCompletionCall() {
105106
useLlamaService()
106107
service<ConfigurationSettings>().state.maxTokens = 99
107-
service<PromptsSettings>().state.personas.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
108+
service<PromptsSettings>().state.personas.selectedPersona.instructions =
109+
"TEST_SYSTEM_PROMPT"
108110
val message = Message("TEST_PROMPT")
109111
val conversation = ConversationService.getInstance().startConversation()
110112
conversation.addMessage(Message("Ping", "Pong"))
@@ -145,7 +147,8 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
145147
fun testOllamaChatCompletionCall() {
146148
useOllamaService()
147149
service<ConfigurationSettings>().state.maxTokens = 99
148-
service<PromptsSettings>().state.personas.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
150+
service<PromptsSettings>().state.personas.selectedPersona.instructions =
151+
"TEST_SYSTEM_PROMPT"
149152
val message = Message("TEST_PROMPT")
150153
val conversation = ConversationService.getInstance().startConversation()
151154
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
@@ -184,24 +187,21 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
184187

185188
fun testGoogleChatCompletionCall() {
186189
useGoogleService()
187-
service<PromptsSettings>().state.personas.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
190+
service<PromptsSettings>().state.personas.selectedPersona.instructions =
191+
"TEST_SYSTEM_PROMPT"
188192
val message = Message("TEST_PROMPT")
189193
val conversation = ConversationService.getInstance().startConversation()
190194
expectGoogle(StreamHttpExchange { request: RequestEntity ->
191195
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
192196
assertThat(request.method).isEqualTo("POST")
193197
assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse")
194198
assertThat(request.body)
195-
.extracting("contents")
196-
.isEqualTo(
199+
.extracting("contents", "systemInstruction")
200+
.containsExactly(
197201
listOf(
198-
mapOf(
199-
"parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")),
200-
"role" to "user"
201-
),
202-
mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"),
203202
mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"),
204-
)
203+
),
204+
mapOf("parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")))
205205
)
206206
listOf(
207207
jsonMapResponse(
@@ -229,7 +229,8 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
229229

230230
fun testCodeGPTServiceChatCompletionCall() {
231231
useCodeGPTService()
232-
service<PromptsSettings>().state.personas.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
232+
service<PromptsSettings>().state.personas.selectedPersona.instructions =
233+
"TEST_SYSTEM_PROMPT"
233234
val message = Message("TEST_PROMPT")
234235
val conversation = ConversationService.getInstance().startConversation()
235236
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->

0 commit comments

Comments
 (0)