Skip to content

Commit 5f16213

Browse files
fix: Use System Prompt from user configuration (#454) (#455)
1 parent 0dfaa12 commit 5f16213

File tree

7 files changed

+44
-38
lines changed

7 files changed

+44
-38
lines changed

src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package ee.carlrobert.codegpt.completions;
22

3+
import static ee.carlrobert.codegpt.completions.ConversationType.DEFAULT;
4+
import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERRORS;
35
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
46
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
57
import static java.lang.String.format;
@@ -57,6 +59,7 @@
5759
import java.util.Map;
5860
import java.util.NoSuchElementException;
5961
import java.util.Objects;
62+
import java.util.Set;
6063
import java.util.UUID;
6164
import java.util.stream.Collectors;
6265
import java.util.stream.Stream;
@@ -74,6 +77,8 @@ public class CompletionRequestProvider {
7477

7578
public static final String FIX_COMPILE_ERRORS_SYSTEM_PROMPT = getResourceContent(
7679
"/prompts/fix-compile-errors.txt");
80+
private static final Set<ConversationType> OPENAI_SYSTEM_CONVERSATION_TYPES = Set.of(
81+
DEFAULT, FIX_COMPILE_ERRORS);
7782

7883
private final EncodingManager encodingManager = EncodingManager.getInstance();
7984
private final Conversation conversation;
@@ -151,10 +156,8 @@ public LlamaCompletionRequest buildLlamaCompletionRequest(
151156
promptTemplate = settings.getRemoteModelPromptTemplate();
152157
}
153158

154-
var systemPrompt = COMPLETION_SYSTEM_PROMPT;
155-
if (conversationType == ConversationType.FIX_COMPILE_ERRORS) {
156-
systemPrompt = FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
157-
}
159+
var systemPrompt = conversationType == FIX_COMPILE_ERRORS
160+
? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : ConfigurationSettings.getSystemPrompt();
158161

159162
var prompt = promptTemplate.buildPrompt(
160163
systemPrompt,
@@ -257,7 +260,7 @@ public ClaudeCompletionRequest buildAnthropicChatCompletionRequest(
257260
request.setModel(settings.getModel());
258261
request.setMaxTokens(configuration.getMaxTokens());
259262
request.setStream(true);
260-
request.setSystem(COMPLETION_SYSTEM_PROMPT);
263+
request.setSystem(ConfigurationSettings.getSystemPrompt());
261264
List<ClaudeCompletionMessage> messages = conversation.getMessages().stream()
262265
.filter(prevMessage -> prevMessage.getResponse() != null
263266
&& !prevMessage.getResponse().isEmpty())
@@ -284,14 +287,10 @@ public ClaudeCompletionRequest buildAnthropicChatCompletionRequest(
284287
private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParameters) {
285288
var message = callParameters.getMessage();
286289
var messages = new ArrayList<OpenAIChatCompletionMessage>();
287-
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
288-
messages.add(new OpenAIChatCompletionStandardMessage(
289-
"system",
290-
ConfigurationSettings.getCurrentState().getSystemPrompt()));
291-
}
292-
if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) {
293-
messages.add(
294-
new OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT));
290+
if (OPENAI_SYSTEM_CONVERSATION_TYPES.contains(callParameters.getConversationType())) {
291+
String content = DEFAULT == callParameters.getConversationType()
292+
? ConfigurationSettings.getSystemPrompt() : FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
293+
messages.add(new OpenAIChatCompletionStandardMessage("system", content));
295294
}
296295

297296
for (var prevMessage : conversation.getMessages()) {

src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public void changedUpdate(DocumentEvent e) {
9494
maxTokensField.setValue(configuration.getMaxTokens());
9595

9696
systemPromptTextArea = new JTextArea();
97-
if (configuration.getSystemPrompt().isEmpty()) {
97+
if (configuration.getSystemPrompt().isBlank()) {
9898
// for backward compatibility
9999
systemPromptTextArea.setText(COMPLETION_SYSTEM_PROMPT);
100100
} else {

src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,8 @@ public static ConfigurationState getCurrentState() {
3131
public static ConfigurationSettings getInstance() {
3232
return ApplicationManager.getApplication().getService(ConfigurationSettings.class);
3333
}
34+
35+
public static String getSystemPrompt() {
36+
return getCurrentState().getSystemPrompt();
37+
}
3438
}

src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/TotalTokensDetails.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ public class TotalTokensDetails {
1212
private int referencedFilesTokens;
1313

1414
public TotalTokensDetails(EncodingManager encodingManager) {
15-
systemPromptTokens = encodingManager.countTokens(
16-
ConfigurationSettings.getCurrentState().getSystemPrompt());
15+
systemPromptTokens = encodingManager.countTokens(ConfigurationSettings.getSystemPrompt());
1716
}
1817

1918
public int getSystemPromptTokens() {

src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package ee.carlrobert.codegpt.completions
22

3+
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
34
import ee.carlrobert.codegpt.conversations.ConversationService
45
import ee.carlrobert.codegpt.conversations.message.Message
56
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
@@ -42,6 +43,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
4243
}
4344

4445
fun testChatCompletionRequestWithoutSystemPromptOverride() {
46+
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
4547
val conversation = ConversationService.getInstance().startConversation()
4648
val firstMessage = createDummyMessage(500)
4749
val secondMessage = createDummyMessage(250)
@@ -60,7 +62,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
6062
assertThat(request.messages)
6163
.extracting("role", "content")
6264
.containsExactly(
63-
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
65+
Tuple.tuple("system", COMPLETION_SYSTEM_PROMPT),
6466
Tuple.tuple("user", "TEST_PROMPT"),
6567
Tuple.tuple("assistant", firstMessage.response),
6668
Tuple.tuple("user", "TEST_PROMPT"),
@@ -69,7 +71,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
6971
}
7072

7173
fun testChatCompletionRequestRetry() {
72-
ConfigurationSettings.getCurrentState().systemPrompt = CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
74+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
7375
val conversation = ConversationService.getInstance().startConversation()
7476
val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500)
7577
val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250)
@@ -88,13 +90,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
8890
assertThat(request.messages)
8991
.extracting("role", "content")
9092
.containsExactly(
91-
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
93+
Tuple.tuple("system", "TEST_SYSTEM_PROMPT"),
9294
Tuple.tuple("user", "FIRST_TEST_PROMPT"),
9395
Tuple.tuple("assistant", firstMessage.response),
9496
Tuple.tuple("user", "SECOND_TEST_PROMPT"))
9597
}
9698

9799
fun testReducedChatCompletionRequest() {
100+
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
98101
val conversation = ConversationService.getInstance().startConversation()
99102
conversation.addMessage(createDummyMessage(50))
100103
conversation.addMessage(createDummyMessage(100))
@@ -116,7 +119,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
116119
assertThat(request.messages)
117120
.extracting("role", "content")
118121
.containsExactly(
119-
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
122+
Tuple.tuple("system", COMPLETION_SYSTEM_PROMPT),
120123
Tuple.tuple("user", "TEST_PROMPT"),
121124
Tuple.tuple("assistant", remainingMessage.response),
122125
Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))

src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package ee.carlrobert.codegpt.completions
22

33
import ee.carlrobert.codegpt.CodeGPTPlugin
4-
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
54
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
65
import ee.carlrobert.codegpt.conversations.ConversationService
76
import ee.carlrobert.codegpt.conversations.message.Message
@@ -20,6 +19,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
2019

2120
fun testOpenAIChatCompletionCall() {
2221
useOpenAIService()
22+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
2323
val message = Message("TEST_PROMPT")
2424
val conversation = ConversationService.getInstance().startConversation()
2525
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
@@ -34,7 +34,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
3434
.containsExactly(
3535
"gpt-4",
3636
listOf(
37-
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
37+
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
3838
mapOf("role" to "user", "content" to "TEST_PROMPT")))
3939
listOf(
4040
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
@@ -50,6 +50,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
5050

5151
fun testAzureChatCompletionCall() {
5252
useAzureService()
53+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
5354
val conversationService = ConversationService.getInstance()
5455
val prevMessage = Message("TEST_PREV_PROMPT")
5556
prevMessage.response = "TEST_PREV_RESPONSE"
@@ -66,7 +67,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
6667
.extracting("messages")
6768
.isEqualTo(
6869
listOf(
69-
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
70+
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
7071
mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
7172
mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
7273
mapOf("role" to "user", "content" to "TEST_PROMPT")))
@@ -138,6 +139,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
138139
fun testLlamaChatCompletionCall() {
139140
useLlamaService()
140141
ConfigurationSettings.getCurrentState().maxTokens = 99
142+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
141143
val message = Message("TEST_PROMPT")
142144
val conversation = ConversationService.getInstance().startConversation()
143145
conversation.addMessage(Message("Ping", "Pong"))
@@ -151,7 +153,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
151153
"stream")
152154
.containsExactly(
153155
LLAMA.buildPrompt(
154-
COMPLETION_SYSTEM_PROMPT,
156+
"TEST_SYSTEM_PROMPT",
155157
"TEST_PROMPT",
156158
conversation.messages),
157159
99,

src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package ee.carlrobert.codegpt.toolwindow.chat
33
import ee.carlrobert.codegpt.CodeGPTKeys
44
import ee.carlrobert.codegpt.EncodingManager
55
import ee.carlrobert.codegpt.ReferencedFile
6-
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
76
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
87
import ee.carlrobert.codegpt.completions.ConversationType
98
import ee.carlrobert.codegpt.completions.HuggingFaceModel
@@ -31,7 +30,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
3130

3231
fun testSendingOpenAIMessage() {
3332
useOpenAIService()
34-
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
33+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
3534
val message = Message("Hello!")
3635
val conversation = ConversationService.getInstance().startConversation()
3736
val panel = ChatToolWindowTabPanel(project, conversation)
@@ -46,7 +45,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
4645
.containsExactly(
4746
"gpt-4",
4847
listOf(
49-
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
48+
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
5049
mapOf("role" to "user", "content" to "Hello!")))
5150
listOf(
5251
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
@@ -68,7 +67,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
6867
"userPromptTokens",
6968
"highlightedTokens")
7069
.containsExactly(
71-
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
70+
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
7271
encodingManager.countTokens(message.prompt),
7372
0,
7473
0)
@@ -93,7 +92,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
9392
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
9493
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
9594
useOpenAIService()
96-
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
95+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
9796
val message = Message("TEST_MESSAGE")
9897
message.userMessage = "TEST_MESSAGE"
9998
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
@@ -110,7 +109,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
110109
.containsExactly(
111110
"gpt-4",
112111
listOf(
113-
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
112+
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
114113
mapOf("role" to "user", "content" to """
115114
Use the following context to answer question at the end:
116115
@@ -153,7 +152,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
153152
"userPromptTokens",
154153
"highlightedTokens")
155154
.containsExactly(
156-
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
155+
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
157156
encodingManager.countTokens(message.prompt),
158157
0,
159158
0)
@@ -180,7 +179,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
180179
val testImagePath = Objects.requireNonNull(javaClass.getResource("/images/test-image.png")).path
181180
project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath)
182181
useOpenAIService("gpt-4-vision-preview")
183-
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
182+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
184183
val message = Message("TEST_MESSAGE")
185184
val conversation = ConversationService.getInstance().startConversation()
186185
val panel = ChatToolWindowTabPanel(project, conversation)
@@ -196,7 +195,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
196195
.containsExactly(
197196
"gpt-4-vision-preview",
198197
listOf(
199-
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
198+
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
200199
mapOf("role" to "user", "content" to listOf(
201200
mapOf(
202201
"type" to "image_url",
@@ -226,7 +225,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
226225
"userPromptTokens",
227226
"highlightedTokens")
228227
.containsExactly(
229-
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
228+
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
230229
encodingManager.countTokens(message.prompt),
231230
0,
232231
0)
@@ -256,7 +255,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
256255
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
257256
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
258257
useOpenAIService()
259-
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
258+
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
260259
val message = Message("TEST_MESSAGE")
261260
message.userMessage = "TEST_MESSAGE"
262261
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
@@ -316,7 +315,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
316315
"userPromptTokens",
317316
"highlightedTokens")
318317
.containsExactly(
319-
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
318+
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
320319
encodingManager.countTokens(message.prompt),
321320
0,
322321
0)
@@ -342,7 +341,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
342341
fun testSendingLlamaMessage() {
343342
useLlamaService()
344343
val configurationState = ConfigurationSettings.getCurrentState()
345-
configurationState.systemPrompt = COMPLETION_SYSTEM_PROMPT
344+
configurationState.systemPrompt = "TEST_SYSTEM_PROMPT"
346345
configurationState.maxTokens = 1000
347346
configurationState.temperature = 0.1
348347
val llamaSettings = LlamaSettings.getCurrentState()
@@ -369,7 +368,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
369368
"repeat_penalty")
370369
.containsExactly(
371370
LLAMA.buildPrompt(
372-
COMPLETION_SYSTEM_PROMPT,
371+
"TEST_SYSTEM_PROMPT",
373372
"TEST_PROMPT",
374373
conversation.messages),
375374
configurationState.maxTokens,

0 commit comments

Comments
 (0)