Commit 6e6a499

feat: Support Llama 3 model (#479)
* feat: Support Llama 3 model (#478)
* Use new InfillPrompt
* Switch to lmstudio-community
* Use new Prompt
* llama.cpp removed the BOS token (ggml-org/llama.cpp@a55d8a9)
* Add tests
* I would prefer a stream based solution
* Add 70B models
* Add tests for skipping blank system prompt
* Remove InfillPrompt for now
1 parent bcb33ae commit 6e6a499

5 files changed (+155, -7 lines)

src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java

Lines changed: 30 additions & 4 deletions
@@ -43,16 +43,31 @@ public enum HuggingFaceModel {
   WIZARD_CODER_PYTHON_13B_Q5(13, 5, "WizardCoder-Python-13B-V1.0-GGUF"),
   WIZARD_CODER_PYTHON_34B_Q3(34, 3, "WizardCoder-Python-34B-V1.0-GGUF"),
   WIZARD_CODER_PYTHON_34B_Q4(34, 4, "WizardCoder-Python-34B-V1.0-GGUF"),
-  WIZARD_CODER_PYTHON_34B_Q5(34, 5, "WizardCoder-Python-34B-V1.0-GGUF");
+  WIZARD_CODER_PYTHON_34B_Q5(34, 5, "WizardCoder-Python-34B-V1.0-GGUF"),
+
+  LLAMA_3_8B_IQ3_M(8, 3, "Meta-Llama-3-8B-Instruct-IQ3_M.gguf", "lmstudio-community"),
+  LLAMA_3_8B_Q4_K_M(8, 4, "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", "lmstudio-community"),
+  LLAMA_3_8B_Q5_K_M(8, 5, "Meta-Llama-3-8B-Instruct-Q5_K_M.gguf", "lmstudio-community"),
+  LLAMA_3_8B_Q6_K(8, 6, "Meta-Llama-3-8B-Instruct-Q6_K.gguf", "lmstudio-community"),
+  LLAMA_3_8B_Q8_0(8, 8, "Meta-Llama-3-8B-Instruct-Q8_0.gguf", "lmstudio-community"),
+  LLAMA_3_70B_IQ1(70, 1, "Meta-Llama-3-70B-Instruct-IQ1_M.gguf", "lmstudio-community"),
+  LLAMA_3_70B_IQ2_XS(70, 2, "Meta-Llama-3-70B-Instruct-IQ2_XS.gguf", "lmstudio-community"),
+  LLAMA_3_70B_Q4_K_M(70, 4, "Meta-Llama-3-70B-Instruct-Q4_K_M.gguf", "lmstudio-community");

   private final int parameterSize;
   private final int quantization;
   private final String modelName;
+  private final String user;

   HuggingFaceModel(int parameterSize, int quantization, String modelName) {
+    this(parameterSize, quantization, modelName, "TheBloke");
+  }
+
+  HuggingFaceModel(int parameterSize, int quantization, String modelName, String user) {
     this.parameterSize = parameterSize;
     this.quantization = quantization;
     this.modelName = modelName;
+    this.user = user;
   }

   public int getParameterSize() {
@@ -68,26 +83,37 @@ public String getCode() {
   }

   public String getFileName() {
-    return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization));
+    if ("TheBloke".equals(user)) {
+      return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization));
+    }
+    return modelName;
   }

   public URL getFileURL() {
     try {
       return new URL(
-          format("https://huggingface.co/TheBloke/%s/resolve/main/%s", modelName, getFileName()));
+          "https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), getFileName()));
     } catch (MalformedURLException ex) {
       throw new RuntimeException(ex);
     }
   }

   public URL getHuggingFaceURL() {
     try {
-      return new URL("https://huggingface.co/TheBloke/" + modelName);
+      return new URL("https://huggingface.co/%s/%s".formatted(user, getDirectory()));
     } catch (MalformedURLException ex) {
       throw new RuntimeException(ex);
     }
   }

+  private String getDirectory() {
+    if ("lmstudio-community".equals(user)) {
+      // Meta-Llama-3-8B-Instruct-Q4_K_M.gguf -> Meta-Llama-3-8B-Instruct-GGUF
+      return modelName.replaceFirst("-[^.-]+\\.gguf$", "-GGUF");
+    }
+    return modelName;
+  }
+
   @Override
   public String toString() {
     return format("%d-bit precision", quantization);
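
For orientation, a usage sketch (not part of the commit) of how the new user field routes file-name and URL resolution; the printed values follow directly from the code above:

import ee.carlrobert.codegpt.completions.HuggingFaceModel;

class ModelUrlSketch {
  public static void main(String[] args) {
    // lmstudio-community entries store the full .gguf file name, so
    // getFileName() returns it as-is:
    //   Meta-Llama-3-8B-Instruct-Q4_K_M.gguf
    System.out.println(HuggingFaceModel.LLAMA_3_8B_Q4_K_M.getFileName());

    // getDirectory() strips the "-Q4_K_M.gguf" suffix back to the repository
    // name, so the download URL resolves to:
    //   https://huggingface.co/lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf
    System.out.println(HuggingFaceModel.LLAMA_3_8B_Q4_K_M.getFileURL());

    // TheBloke entries keep the previous behaviour: lower-case the repository
    // name and splice in the Q%d_K_M quantization, yielding:
    //   wizardcoder-python-13b-v1.0.Q5_K_M.gguf
    System.out.println(HuggingFaceModel.WIZARD_CODER_PYTHON_13B_Q5.getFileName());
  }
}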

src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java

Lines changed: 18 additions & 1 deletion
@@ -82,7 +82,24 @@ public enum LlamaModel {
       HuggingFaceModel.WIZARD_CODER_PYTHON_13B_Q5,
       HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q3,
       HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q4,
-      HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q5));
+      HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q5)),
+  LLAMA_3(
+      "Llama 3",
+      "Llama 3 is a family of large language models (LLMs), a collection of pretrained and "
+          + "instruction tuned generative text models in 8 and 70B sizes. The Llama 3 instruction "
+          + "tuned models are optimized for dialogue use cases and outperform many of the available"
+          + " open source chat models on common industry benchmarks. Further, in developing these "
+          + "models, we took great care to optimize helpfulness and safety.",
+      PromptTemplate.LLAMA_3,
+      List.of(
+          HuggingFaceModel.LLAMA_3_8B_IQ3_M,
+          HuggingFaceModel.LLAMA_3_8B_Q4_K_M,
+          HuggingFaceModel.LLAMA_3_8B_Q5_K_M,
+          HuggingFaceModel.LLAMA_3_8B_Q6_K,
+          HuggingFaceModel.LLAMA_3_8B_Q8_0,
+          HuggingFaceModel.LLAMA_3_70B_IQ1,
+          HuggingFaceModel.LLAMA_3_70B_IQ2_XS,
+          HuggingFaceModel.LLAMA_3_70B_Q4_K_M));

   private final String label;
   private final String description;

src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java

Lines changed: 35 additions & 0 deletions
@@ -1,7 +1,11 @@
 package ee.carlrobert.codegpt.completions.llama;

+import static java.util.stream.Stream.concat;
+
 import ee.carlrobert.codegpt.conversations.message.Message;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;

 public enum PromptTemplate {

@@ -55,6 +59,26 @@ public String buildPrompt(String systemPrompt, String userPrompt, List<Message>
           .toString();
     }
   },
+  LLAMA_3("Llama 3") {
+    @Override
+    public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
+      return concat(concat(Stream.ofNullable(systemPrompt)
+              .filter(s -> !s.isBlank())
+              .flatMap(system -> Stream.of(
+                  "<|start_header_id|>system<|end_header_id|>\n\n",
+                  system,
+                  "<|eot_id|>")),
+          history.stream().flatMap(message -> mapMessage(
+              message,
+              "<|start_header_id|>user<|end_header_id|>\n\n",
+              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+              "<|eot_id|>"))), Stream.of(
+          "<|start_header_id|>user<|end_header_id|>\n\n",
+          userPrompt,
+          "<|eot_id|>"))
+          .collect(Collectors.joining());
+    }
+  },
   MIXTRAL_INSTRUCT("Mixtral Instruct") {
     @Override
     public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
@@ -171,4 +195,15 @@ public String buildPrompt(String systemPrompt, String userPrompt, List<Message>
   public String toString() {
     return label;
   }
+
+  private static Stream<String> mapMessage(Message message,
+      String prefix, String infix, String suffix) {
+    return Stream.of(
+        prefix,
+        message.getPrompt(),
+        infix,
+        message.getResponse(),
+        suffix
+    );
+  }
 }
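
For reference, a minimal sketch (not part of the commit; the example prompts are made up) of what the new template produces; the authoritative expectations are in the tests at the bottom of this diff:

import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
import java.util.List;

class Llama3PromptSketch {
  public static void main(String[] args) {
    // Non-blank system prompt, no history, produces:
    //   <|start_header_id|>system<|end_header_id|>\n\nYou are helpful.<|eot_id|>
    //   <|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>
    // (one continuous string; shown on two lines here for readability)
    System.out.println(
        PromptTemplate.LLAMA_3.buildPrompt("You are helpful.", "Hello!", List.of()));

    // A null or blank system prompt is dropped by the
    // Stream.ofNullable(...).filter(s -> !s.isBlank()) guard, so the prompt
    // starts directly at the user header:
    //   <|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>
    System.out.println(PromptTemplate.LLAMA_3.buildPrompt(" ", "Hello!", List.of()));
  }
}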

src/main/kotlin/ee/carlrobert/codegpt/codecompletions/InfillPromptTemplate.kt

Lines changed: 1 addition & 1 deletion
@@ -28,4 +28,4 @@ enum class InfillPromptTemplate(val label: String, val stopTokens: List<String>?
     override fun toString(): String {
         return label
     }
-}
\ No newline at end of file
+}

src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt

Lines changed: 71 additions & 1 deletion
@@ -3,10 +3,14 @@ package ee.carlrobert.codegpt.completions
 import ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA
 import ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML
 import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
+import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA_3
 import ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA
 import ee.carlrobert.codegpt.conversations.message.Message
 import org.assertj.core.api.Assertions.assertThat
-import org.junit.Test
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.NullAndEmptySource
+import org.junit.jupiter.params.provider.ValueSource

 class PromptTemplateTest {

@@ -34,6 +38,72 @@ class PromptTemplateTest {
         """.trimIndent())
     }

+    @Test
+    fun shouldBuildLlama3PromptWithoutHistory() {
+        val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
+
+        assertThat(prompt).isEqualTo("""
+            <|start_header_id|>system<|end_header_id|>
+
+            TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            TEST_USER_PROMPT<|eot_id|>""".trimIndent()
+        )
+    }
+
+    @ParameterizedTest
+    @NullAndEmptySource
+    @ValueSource(strings = [" ", "\t", "\n"])
+    fun shouldBuildLlama3PromptWithoutHistorySkippingBlankSystemPrompt(systemPrompt: String?) {
+        val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, listOf())
+
+        assertThat(prompt).isEqualTo("""
+            <|start_header_id|>user<|end_header_id|>
+
+            TEST_USER_PROMPT<|eot_id|>""".trimIndent()
+        )
+    }
+
+    @Test
+    fun shouldBuildLlama3PromptWithHistory() {
+        val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
+
+        assertThat(prompt).isEqualTo("""
+            <|start_header_id|>system<|end_header_id|>
+
+            TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            TEST_PREV_PROMPT_1<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+            TEST_PREV_RESPONSE_1<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            TEST_PREV_PROMPT_2<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+            TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            TEST_USER_PROMPT<|eot_id|>""".trimIndent())
+    }
+
+    @ParameterizedTest
+    @NullAndEmptySource
+    @ValueSource(strings = [" ", "\t", "\n"])
+    fun shouldBuildLlama3PromptWithHistorySkippingBlankSystemPrompt(systemPrompt: String?) {
+        val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, HISTORY)
+
+        assertThat(prompt).isEqualTo("""
+            <|start_header_id|>user<|end_header_id|>
+
+            TEST_PREV_PROMPT_1<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+            TEST_PREV_RESPONSE_1<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            TEST_PREV_PROMPT_2<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+            TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+            TEST_USER_PROMPT<|eot_id|>""".trimIndent())
+    }
+
     @Test
     fun shouldBuildAlpacaPromptWithHistory() {
         val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
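
The blank-system-prompt tests lean on JUnit 5 parameterized sources: @NullAndEmptySource supplies null and the empty string, and @ValueSource adds the three whitespace strings, so each test body runs five times. A standalone sketch of the same pattern (class and method names here are illustrative, not from the commit):

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.junit.jupiter.params.provider.ValueSource;

class BlankSourceSketchTest {

  // Runs five times: null, "", " ", "\t", "\n". Every value is blank in the
  // String.isBlank() sense, which is exactly the guard LLAMA_3 applies.
  @ParameterizedTest
  @NullAndEmptySource
  @ValueSource(strings = {" ", "\t", "\n"})
  void treatsAllFiveValuesAsBlank(String systemPrompt) {
    assertThat(systemPrompt == null || systemPrompt.isBlank()).isTrue();
  }
}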
