Skip to content

Commit 76ff2fe

Browse files
committed
Allow including configured models in the built artifact for Jlama
Currently, this only works for fast-jar (the default) deployments. Relates to: #991
1 parent dd931ed commit 76ff2fe

File tree

13 files changed

+441
-185
lines changed

13 files changed

+441
-185
lines changed

model-providers/jlama/deployment/src/main/java/io/quarkiverse/langchain4j/jlama/deployment/JlamaAiProcessor.java

Lines changed: 0 additions & 103 deletions
This file was deleted.
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
package io.quarkiverse.langchain4j.jlama.deployment;
2+
3+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
4+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.EMBEDDING_MODEL;
5+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;
6+
7+
import java.io.IOException;
8+
import java.io.UncheckedIOException;
9+
import java.math.BigDecimal;
10+
import java.math.RoundingMode;
11+
import java.nio.file.Files;
12+
import java.nio.file.Path;
13+
import java.util.ArrayList;
14+
import java.util.List;
15+
import java.util.Optional;
16+
17+
import jakarta.enterprise.context.ApplicationScoped;
18+
19+
import org.apache.commons.io.file.PathUtils;
20+
import org.jboss.jandex.AnnotationInstance;
21+
import org.jboss.logging.Logger;
22+
import org.slf4j.LoggerFactory;
23+
24+
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
25+
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
26+
import com.github.tjake.jlama.util.ProgressReporter;
27+
28+
import io.quarkiverse.langchain4j.ModelName;
29+
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
30+
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
31+
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
32+
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
33+
import io.quarkiverse.langchain4j.jlama.JlamaModelRegistry;
34+
import io.quarkiverse.langchain4j.jlama.runtime.JlamaAiRecorder;
35+
import io.quarkiverse.langchain4j.jlama.runtime.config.LangChain4jJlamaConfig;
36+
import io.quarkiverse.langchain4j.jlama.runtime.config.LangChain4jJlamaFixedRuntimeConfig;
37+
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
38+
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
39+
import io.quarkus.builder.item.MultiBuildItem;
40+
import io.quarkus.deployment.IsNormal;
41+
import io.quarkus.deployment.annotations.BuildProducer;
42+
import io.quarkus.deployment.annotations.BuildStep;
43+
import io.quarkus.deployment.annotations.ExecutionTime;
44+
import io.quarkus.deployment.annotations.Produce;
45+
import io.quarkus.deployment.annotations.Record;
46+
import io.quarkus.deployment.builditem.FeatureBuildItem;
47+
import io.quarkus.deployment.builditem.LaunchModeBuildItem;
48+
import io.quarkus.deployment.builditem.ServiceStartBuildItem;
49+
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
50+
import io.quarkus.deployment.console.ConsoleInstalledBuildItem;
51+
import io.quarkus.deployment.console.StartupLogCompressor;
52+
import io.quarkus.deployment.logging.LoggingSetupBuildItem;
53+
import io.quarkus.deployment.pkg.builditem.ArtifactResultBuildItem;
54+
import io.quarkus.deployment.pkg.builditem.JarBuildItem;
55+
import io.quarkus.deployment.pkg.steps.JarResultBuildStep;
56+
57+
public class JlamaProcessor {
58+
59+
private final static Logger LOGGER = Logger.getLogger(JlamaProcessor.class);
60+
61+
private static final String FEATURE = "langchain4j-jlama";
62+
private static final String PROVIDER = "jlama";
63+
private static final org.slf4j.Logger log = LoggerFactory.getLogger(JlamaProcessor.class);
64+
65+
@BuildStep
66+
FeatureBuildItem feature() {
67+
return new FeatureBuildItem(FEATURE);
68+
}
69+
70+
@BuildStep
71+
void nativeSupport(BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer) {
72+
reflectiveClassProducer
73+
.produce(ReflectiveClassBuildItem.builder(PropertyNamingStrategies.SnakeCaseStrategy.class).build());
74+
}
75+
76+
@BuildStep
77+
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
78+
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
79+
LangChain4jJlamaBuildTimeConfig config) {
80+
if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
81+
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
82+
}
83+
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
84+
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
85+
}
86+
}
87+
88+
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
89+
@BuildStep
90+
@Record(ExecutionTime.RUNTIME_INIT)
91+
void generateBeans(JlamaAiRecorder recorder, List<SelectedChatModelProviderBuildItem> selectedChatItem,
92+
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
93+
LangChain4jJlamaConfig runtimeConfig,
94+
LangChain4jJlamaFixedRuntimeConfig fixedRuntimeConfig,
95+
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
96+
97+
for (var selected : selectedChatItem) {
98+
if (PROVIDER.equals(selected.getProvider())) {
99+
String configName = selected.getConfigName();
100+
var builder = SyntheticBeanBuildItem.configure(CHAT_MODEL).setRuntimeInit().defaultBean()
101+
.scope(ApplicationScoped.class)
102+
.supplier(recorder.chatModel(runtimeConfig, fixedRuntimeConfig, configName));
103+
addQualifierIfNecessary(builder, configName);
104+
beanProducer.produce(builder.done());
105+
106+
var streamingBuilder = SyntheticBeanBuildItem.configure(STREAMING_CHAT_MODEL).setRuntimeInit()
107+
.defaultBean().scope(ApplicationScoped.class)
108+
.supplier(recorder.streamingChatModel(runtimeConfig, fixedRuntimeConfig, configName));
109+
addQualifierIfNecessary(streamingBuilder, configName);
110+
beanProducer.produce(streamingBuilder.done());
111+
}
112+
}
113+
114+
for (var selected : selectedEmbedding) {
115+
if (PROVIDER.equals(selected.getProvider())) {
116+
String configName = selected.getConfigName();
117+
var builder = SyntheticBeanBuildItem.configure(EMBEDDING_MODEL).setRuntimeInit().defaultBean()
118+
.unremovable().scope(ApplicationScoped.class)
119+
.supplier(recorder.embeddingModel(runtimeConfig, fixedRuntimeConfig, configName));
120+
addQualifierIfNecessary(builder, configName);
121+
beanProducer.produce(builder.done());
122+
}
123+
}
124+
}
125+
126+
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) {
127+
if (!NamedConfigUtil.isDefault(configName)) {
128+
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", configName).build());
129+
}
130+
}
131+
132+
@Produce(ServiceStartBuildItem.class)
133+
@BuildStep
134+
void downloadModels(List<SelectedChatModelProviderBuildItem> selectedChatModels,
135+
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbeddingModels,
136+
LoggingSetupBuildItem loggingSetupBuildItem,
137+
Optional<ConsoleInstalledBuildItem> consoleInstalledBuildItem,
138+
LaunchModeBuildItem launchMode,
139+
LangChain4jJlamaBuildTimeConfig buildTimeConfig,
140+
LangChain4jJlamaFixedRuntimeConfig fixedRuntimeConfig,
141+
BuildProducer<ModelDownloadedBuildItem> modelDownloadedProducer) {
142+
if (!buildTimeConfig.includeModelsInArtifact()) {
143+
return;
144+
}
145+
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(fixedRuntimeConfig.modelsPath());
146+
147+
BigDecimal ONE_HUNDRED = new BigDecimal("100");
148+
149+
if (buildTimeConfig.chatModel().enabled().orElse(true) || buildTimeConfig.embeddingModel().enabled().orElse(true)) {
150+
List<String> modelsNeeded = new ArrayList<>();
151+
for (var selected : selectedChatModels) {
152+
if (PROVIDER.equals(selected.getProvider())) {
153+
String configName = selected.getConfigName();
154+
155+
String modelName = NamedConfigUtil.isDefault(configName)
156+
? fixedRuntimeConfig.defaultConfig().chatModel().modelName()
157+
: fixedRuntimeConfig.namedConfig().get(configName).chatModel().modelName();
158+
modelsNeeded.add(modelName);
159+
}
160+
}
161+
162+
for (var selected : selectedEmbeddingModels) {
163+
if (PROVIDER.equals(selected.getProvider())) {
164+
String configName = selected.getConfigName();
165+
166+
String modelName = NamedConfigUtil.isDefault(configName)
167+
? fixedRuntimeConfig.defaultConfig().embeddingModel().modelName()
168+
: fixedRuntimeConfig.namedConfig().get(configName).embeddingModel().modelName();
169+
modelsNeeded.add(modelName);
170+
}
171+
}
172+
173+
if (!modelsNeeded.isEmpty()) {
174+
StartupLogCompressor compressor = new StartupLogCompressor(
175+
(launchMode.isTest() ? "(test) " : "") + "Jlama model pull:",
176+
consoleInstalledBuildItem,
177+
loggingSetupBuildItem);
178+
179+
for (String modelName : modelsNeeded) {
180+
JlamaModelRegistry.ModelInfo modelInfo = JlamaModelRegistry.ModelInfo.from(modelName);
181+
Path pathOfModelDirOnDisk = SafeTensorSupport.constructLocalModelPath(
182+
registry.getModelCachePath().toAbsolutePath().toString(), modelInfo.owner(),
183+
modelInfo.name());
184+
// Check if the model is already downloaded
185+
// this is done automatically by download model, but we want to provide a good progress experience, so we do it again here
186+
if (Files.exists(pathOfModelDirOnDisk.resolve(".finished"))) {
187+
LOGGER.debug("Model " + modelName + "already exists in " + pathOfModelDirOnDisk);
188+
} else {
189+
// we pull one model at a time and provide progress updates to the user via logging
190+
LOGGER.info("Pulling model " + modelName);
191+
192+
try {
193+
registry.downloadModel(modelName, Optional.empty(), Optional.of(new ProgressReporter() {
194+
@Override
195+
public void update(String filename, long sizeDownloaded, long totalSize) {
196+
// Jlama downloads a bunch of files for each mode of which only the weights file is large
197+
// and makes sense to report progress on
198+
if (totalSize < 100_000) {
199+
return;
200+
}
201+
202+
BigDecimal percentage = new BigDecimal(sizeDownloaded).divide(new BigDecimal(totalSize), 4,
203+
RoundingMode.HALF_DOWN).multiply(ONE_HUNDRED);
204+
BigDecimal progress = percentage.setScale(2, RoundingMode.HALF_DOWN);
205+
if (progress.compareTo(ONE_HUNDRED) >= 0) {
206+
// avoid showing 100% for too long
207+
LOGGER.infof("Verifying and cleaning up\n", progress);
208+
} else {
209+
LOGGER.infof("Progress: %s%%\n", progress);
210+
}
211+
}
212+
}));
213+
} catch (IOException e) {
214+
compressor.closeAndDumpCaptured();
215+
throw new UncheckedIOException(e);
216+
}
217+
}
218+
219+
modelDownloadedProducer.produce(new ModelDownloadedBuildItem(modelName, pathOfModelDirOnDisk));
220+
}
221+
222+
compressor.close();
223+
}
224+
}
225+
226+
}
227+
228+
/**
229+
* When building a fast jar, we can copy the model files into the directory
230+
*
231+
*/
232+
@BuildStep(onlyIf = IsNormal.class)
233+
@Produce(ArtifactResultBuildItem.class)
234+
public void copyToFastJar(List<ModelDownloadedBuildItem> models,
235+
Optional<JarBuildItem> jarBuildItem) {
236+
if (!jarBuildItem.isPresent()) {
237+
return;
238+
}
239+
240+
Path jarPath = jarBuildItem.get().getPath();
241+
if (!JarResultBuildStep.QUARKUS_RUN_JAR.equals(jarPath.getFileName().toString())) {
242+
return;
243+
}
244+
245+
Path quarkusAppDir = jarPath.getParent();
246+
Path jlamaInQuarkusAppDir = quarkusAppDir.resolve("jlama");
247+
248+
for (ModelDownloadedBuildItem bi : models) {
249+
try {
250+
JlamaModelRegistry.ModelInfo modelInfo = JlamaModelRegistry.ModelInfo.from(bi.getModelName());
251+
Path targetDir = jlamaInQuarkusAppDir.resolve(modelInfo.toFileName());
252+
Files.createDirectories(targetDir);
253+
PathUtils.copyDirectory(bi.getDirectory(), targetDir);
254+
} catch (IOException e) {
255+
throw new UncheckedIOException(e);
256+
}
257+
}
258+
259+
}
260+
261+
public static final class ModelDownloadedBuildItem extends MultiBuildItem {
262+
263+
private final String modelName;
264+
private final Path directory;
265+
266+
public ModelDownloadedBuildItem(String modelName, Path directory) {
267+
this.modelName = modelName;
268+
this.directory = directory;
269+
}
270+
271+
public String getModelName() {
272+
return modelName;
273+
}
274+
275+
public Path getDirectory() {
276+
return directory;
277+
}
278+
}
279+
}

0 commit comments

Comments
 (0)