From 7aeb196b825e276ead4df21bbe44c6bb348b6d72 Mon Sep 17 00:00:00 2001
From: cryptoc1
Date: Fri, 26 Sep 2025 12:00:53 -0400
Subject: [PATCH 1/3] General Updates:

- [NEW] support cancellation of `ChatSession.InitializeSessionFromHistoryAsync`
- [NEW] improve usage of `CancellationToken`s in `LlamaExecutorBase`
- [FIX] `CS1998` warnings
---
 LLama.Examples/Examples/QuantizeModel.cs |  4 +-
 LLama/ChatSession.cs                     | 23 ++++----
 LLama/LLamaContext.cs                    | 19 +++----
 LLama/LLamaExecutorBase.cs               | 38 ++++++++------
 LLama/LLamaInstructExecutor.cs           | 36 +++++++------
 LLama/LLamaInteractExecutor.cs           | 67 +++++++++++++-----------
 6 files changed, 104 insertions(+), 83 deletions(-)

diff --git a/LLama.Examples/Examples/QuantizeModel.cs b/LLama.Examples/Examples/QuantizeModel.cs
index a1f7ca1bd..863bb0c3a 100644
--- a/LLama.Examples/Examples/QuantizeModel.cs
+++ b/LLama.Examples/Examples/QuantizeModel.cs
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
 {
     public class QuantizeModel
     {
-        public static async Task Run()
+        public static Task Run()
         {
             string inputPath = UserSettings.GetModelPath();
 
@@ -20,6 +20,8 @@ public static async Task Run()
             {
                 Console.WriteLine("Quantization failed!");
             }
+
+            return Task.CompletedTask;
         }
     }
 }
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 90119d4fe..bda7472d5 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -76,9 +76,10 @@ public class ChatSession
     /// <param name="executor">The executor for this session</param>
     /// <param name="history">History for this session</param>
     /// <param name="transform">History Transform for this session</param>
+    /// <param name="cancellationToken">A token that cancels the operation</param>
     /// <returns>A new chat session.</returns>
     public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
-        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
+        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
     {
         if (executor is not StatefulExecutorBase statefulExecutor)
         {
@@ -90,7 +91,7 @@ public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
             session = session.WithHistoryTransform(transform);
         }
 
-        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
+        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
         return session;
     }
@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
     /// Compute KV cache for the message and add it to the chat history.
     /// </summary>
     /// <param name="message"></param>
+    /// <param name="cancellationToken"></param>
     /// <returns></returns>
-    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
+    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
     {
         if (Executor is not StatefulExecutorBase statefulExecutor)
         {
             throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
         }
+
         AddMessage(message);
         var content = message.Content;
         if (message.AuthorRole != AuthorRole.Assistant)
@@ -328,27 +331,27 @@ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
             }
         }
 
-        await statefulExecutor.PrefillPromptAsync(content);
+        await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
         return this;
     }
 
     /// <summary>
     /// Compute KV cache for the system message and add it to the chat history.
     /// </summary>
-    public Task AddAndProcessSystemMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
+    public Task AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the user message and add it to the chat history.
     /// </summary>
-    public Task AddAndProcessUserMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
+    public Task AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the assistant message and add it to the chat history.
     /// </summary>
-    public Task AddAndProcessAssistantMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+    public Task AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);
 
     /// <summary>
     /// Replace a user message with a new message and remove all messages after the new message.
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 42d76c514..4188f9e5f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -1,14 +1,14 @@
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Text;
 using System.IO;
 using System.IO.MemoryMappedFiles;
+using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
+using LLama.Native;
 using Microsoft.Extensions.Logging;
-using System.Threading;
 
 namespace LLama
 {
@@ -73,7 +73,7 @@ public int BatchThreads
         /// Get the special tokens for the model associated with this context
         /// </summary>
         public SafeLlamaModelHandle.Vocabulary Vocab { get; }
-        
+
         /// <summary>
         /// Create a new LLamaContext for the given LLamaWeights
         /// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
         {
             return Task.Run(() => Decode(batch), cancellationToken);
         }
-        
+
         /// <summary>
         ///
         /// </summary>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
                 return 0;
             if (batch.EmbeddingsCount > BatchSize)
                 throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
-            
+
             return (DecodeResult)NativeHandle.Decode(batch);
         }
-        
+
         /// <summary>
         ///
         /// </summary>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
         /// <param name="id"></param>
         /// <param name="batch"></param>
         /// <param name="n_past"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns>A tuple, containing the decode result, the number of tokens that have not been decoded yet and the total number of tokens that have been decoded.</returns>
-        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
         {
             return Task.Run(() =>
             {
                 var past = n_past;
                 var res = NativeHandle.Decode(tokens, id, batch, ref past);
                 return (res.Item1, res.Item2, past);
-            });
+            }, cancellationToken);
         }
 #endregion
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 36989006e..0e8d5f115 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -239,36 +239,41 @@ protected virtual void TryReuseMatchingPrefix()
         /// Decide whether to continue the loop.
         /// </summary>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
+        protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Preprocess the inputs before the inference.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="args"></param>
-        protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
+        /// <param name="cancellationToken"></param>
+        protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Do some post processing after the inference.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
+        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+        /// <param name="cancellationToken"></param>
+        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract Task SaveState(string filename);
+        /// <param name="cancellationToken"></param>
+        public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Get the current state data.
@@ -280,13 +285,15 @@
         /// Load the state from data.
         /// </summary>
         /// <param name="data"></param>
-        public abstract Task LoadState(ExecutorBaseState data);
+        /// <param name="cancellationToken"></param>
+        public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Load the state from a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract Task LoadState(string filename);
+        /// <param name="cancellationToken"></param>
+        public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);
 
 
         /// <summary>
@@ -310,15 +317,15 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
                 NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
             };
 
-            await PreprocessInputs(text, args);
+            await PreprocessInputs(text, args, cancellationToken);
 
-            while (await GetLoopCondition(args))
+            while (await GetLoopCondition(args, cancellationToken))
             {
                 if (cancellationToken.IsCancellationRequested)
                 {
                     break;
                 }
-                await InferInternal(inferenceParams, args);
+                await InferInternal(inferenceParams, args, cancellationToken);
 
                 if (args.ReturnValue)
                 {
@@ -326,7 +333,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
                     yield return _decoder.Read();
                 }
 
-                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
                 if (extraOutputs is { Count: > 0 })
                 {
                     foreach (var item in extraOutputs)
@@ -346,8 +353,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
         /// It could reduce the latency of the first time response if the first input from the user is not immediate.
         /// </summary>
         /// <param name="prompt">Prompt to process</param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual async Task PrefillPromptAsync(string prompt)
+        public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
         {
             var inferenceParams = new InferenceParams
             {
@@ -362,11 +370,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
                 NeedToSaveSession = false
             };
 
-            await PreprocessInputs(prompt, args);
+            await PreprocessInputs(prompt, args, cancellationToken);
             // First run adds the prompt to the _embeds
-            await InferInternal(inferenceParams, args);
+            await InferInternal(inferenceParams, args, cancellationToken);
             // Second run puts it through decode
-            await InferInternal(inferenceParams, args);
+            await InferInternal(inferenceParams, args, cancellationToken);
         }
 
         /// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 331591fba..c65aecc3d 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Abstractions;
-using LLama.Common;
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Exceptions;
+using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -65,9 +66,9 @@ public override ExecutorBaseState GetStateData()
             return state;
         }
         /// <inheritdoc />
-        public override Task LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
         {
-            if(data is InstructExecutorState state)
+            if (data is InstructExecutorState state)
             {
                 _n_session_consumed = state.ConsumedSessionCount;
                 _embed_inps = state.EmbedInps!.ToList();
@@ -91,34 +92,34 @@ public override Task LoadState(ExecutorBaseState data)
         }
 
         /// <inheritdoc />
-        public override async Task SaveState(string filename)
+        public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
         {
             var state = (InstructExecutorState)GetStateData();
             using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                await JsonSerializer.SerializeAsync(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
             }
         }
         /// <inheritdoc />
-        public override async Task LoadState(string filename)
+        public override async Task LoadState(string filename, CancellationToken cancellationToken)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
                 var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
-                await LoadState(state!);
+                await LoadState(state!, cancellationToken);
             }
         }
 
         /// <inheritdoc />
-        protected override Task<bool> GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
        {
             return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string? text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
         {
-            args.Antiprompts ??= [ ];
+            args.Antiprompts ??= [];
             if (!args.Antiprompts.Contains(_instructionPrefix))
                 args.Antiprompts.Add(_instructionPrefix);
@@ -154,19 +155,19 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
                     args.WaitForInput = true;
-                    return (true, Array.Empty<string>());
+                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
                 }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
                 {
-                    return (true, new[] { "\n> " });
+                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
                 }
             }
@@ -180,11 +181,12 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return (false, Array.Empty<string>());
+
+            return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             var batch = new LLamaBatch();
 
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 7c9558ee3..0029b3e2d 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Common;
-using LLama.Native;
-using LLama.Abstractions;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Exceptions;
+using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -21,7 +22,7 @@ namespace LLama
     public class InteractiveExecutor : StatefulExecutorBase
     {
         private bool _is_prompt_run = true;
-        
+
         // LLava
         private int _EmbedImagePosition = -1;
         private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
@@ -36,7 +37,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
             : base(context, logger)
         {
         }
-        
+
         /// <summary>
         /// </summary>
         /// <param name="context"></param>
         /// <param name="clipModel"></param>
         /// <param name="logger"></param>
         public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
             : base(context, clipModel, logger)
         {
-        }        
+        }
 
         /// <inheritdoc />
         public override ExecutorBaseState GetStateData()
@@ -68,7 +69,7 @@ public override ExecutorBaseState GetStateData()
             return state;
         }
         /// <inheritdoc />
-        public override Task LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken)
         {
             if (data is InteractiveExecutorState state)
             {
@@ -88,22 +89,24 @@ public override Task LoadState(ExecutorBaseState data)
             return Task.CompletedTask;
         }
 
+
         /// <inheritdoc />
-        public override async Task SaveState(string filename)
+        public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
         {
             var state = (InteractiveExecutorState)GetStateData();
-            using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                await JsonSerializer.SerializeAsync(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
             }
         }
+
         /// <inheritdoc />
-        public override async Task LoadState(string filename)
+        public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
                 var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
-                await LoadState(state!);
+                await LoadState(state!, cancellationToken);
             }
         }
 
@@ -111,13 +114,13 @@ public override async Task LoadState(string filename)
         /// Define whether to continue the loop to generate responses.
         /// </summary>
         /// <returns></returns>
-        protected override Task<bool> GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
         {
             return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string? text, InferStateArgs args)
+        protected override async Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_is_prompt_run)
             {
@@ -129,7 +132,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                 }
                 else
                 {
-                    PreprocessLlava(text, args, true);
+                    await PreprocessLlava(text, args, true);
                 }
             }
             else
@@ -150,17 +153,15 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                 }
                 else
                 {
-                    PreprocessLlava(text, args, false);
+                    await PreprocessLlava(text, args, false);
                 }
             }
-
-            return Task.CompletedTask;
         }
 
         /// <inheritdoc />
-        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
-        {
+        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
+        {
             // If the prompt contains the tag <image> extract this.
             _imageInPrompt = text.Contains("<image>");
             if (_imageInPrompt && IsMultiModal)
             {
@@ -191,7 +192,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
                 {
                     var line_inp = Context.Tokenize(text, false, true);
                     _embed_inps.AddRange(line_inp);
-                    args.RemainedTokens -= line_inp.Length;                
+                    args.RemainedTokens -= line_inp.Length;
                 }
             }
             return Task.CompletedTask;
@@ -203,20 +204,24 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+                {
                     args.WaitForInput = true;
+                }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
-                    return (true, Array.Empty<string>());
+                {
+                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
+                }
             }
 
             if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab))
             {
-                return (true, Array.Empty<string>());
+                return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
             }
 
             if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
             {
@@ -225,11 +230,11 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
 
-            return (false, Array.Empty<string>());
+            return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             var batch = new LLamaBatch();
@@ -258,18 +263,18 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
             // Changes to support Multi-Modal LLMs.
             //
             (DecodeResult, int, int) header, end, result;
-            if (IsMultiModal && _EmbedImagePosition > 0)            
+            if (IsMultiModal && _EmbedImagePosition > 0)
             {
                 // Tokens previous to the images
                 header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
                 _pastTokensCount = header.Item3;
 
                 if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
-                
+
                 // Images
-                foreach( var image in _imageEmbedHandles )
+                foreach (var image in _imageEmbedHandles)
                     ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
-                
+
                 // Post-image Tokens
                 end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
                 _pastTokensCount = end.Item3;
@@ -285,7 +290,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
 
                 if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
             }
-            
+
             if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
             {

From 3c272fca651f9304cf781beff299ab2d293c12b5 Mon Sep 17 00:00:00 2001
From: cryptoc1
Date: Fri, 26 Sep 2025 12:23:27 -0400
Subject: [PATCH 2/3] General Update: [FIX] ensure a `default` value is
 specified for optional parameters defined via interface

---
 LLama/LLamaInstructExecutor.cs | 2 +-
 LLama/LLamaInteractExecutor.cs | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index c65aecc3d..d7a8c4a94 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -101,7 +101,7 @@ public override async Task SaveState(string filename, CancellationToken cancella
             }
         }
         /// <inheritdoc />
-        public override async Task LoadState(string filename, CancellationToken cancellationToken)
+        public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 0029b3e2d..b5a75966d 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -69,7 +69,7 @@ public override ExecutorBaseState GetStateData()
             return state;
         }
         /// <inheritdoc />
-        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken)
+        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
         {
             if (data is InteractiveExecutorState state)
             {
@@ -203,6 +203,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
         protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {

From a7abbaebd86c3d3f6e9ab3774d3cdc655ada02a9 Mon Sep 17 00:00:00 2001
From: cryptoc1
Date: Fri, 26 Sep 2025 12:40:39 -0400
Subject: [PATCH 3/3] [FIX] update `LlamaInteractExector.PreprocessLlava` to
 return `void`

---
 LLama/LLamaInteractExecutor.cs | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index b5a75966d..f05b1c974 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -120,7 +120,7 @@ protected override Task<bool> GetLoopCondition(InferStateArgs args, Cancellation
         }
 
         /// <inheritdoc />
-        protected override async Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_is_prompt_run)
             {
@@ -132,7 +132,7 @@ protected override async Task PreprocessInputs(string? text, InferStateArgs args
                 }
                 else
                 {
-                    await PreprocessLlava(text, args, true);
+                    PreprocessLlava(text, args, true);
                 }
             }
             else
@@ -153,14 +153,16 @@ protected override async Task PreprocessInputs(string? text, InferStateArgs args
                 }
                 else
                 {
-                    await PreprocessLlava(text, args, false);
+                    PreprocessLlava(text, args, false);
                 }
             }
+
+            return Task.CompletedTask;
         }
 
         /// <inheritdoc />
-        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
+        private void PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
         {
             // If the prompt contains the tag <image> extract this.
             _imageInPrompt = text.Contains("<image>");
@@ -195,7 +197,6 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
                 args.RemainedTokens -= line_inp.Length;
             }
         }
-            return Task.CompletedTask;
     }
 
         /// <summary>
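
With all three patches applied, every stage of the stateful executors' pipeline accepts a `CancellationToken`. The sketch below shows how a caller might exercise the new overloads; it is illustrative only, not part of the patch series: the model path, prompt text, state filename, and timeout are placeholder values.

// Illustrative usage of the cancellable APIs added above.
// Model path, prompt, and filename are placeholders.
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class CancellableSessionExample
{
    public static async Task Main()
    {
        // Cancel everything below if it takes longer than 30 seconds.
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

        var parameters = new ModelParams("path/to/model.gguf");
        using var weights = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token);
        using var context = weights.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        var history = new ChatHistory();
        history.AddMessage(AuthorRole.System, "You are a helpful assistant.");

        // New in patch 1: the token is forwarded through PrefillPromptAsync,
        // so a long prefill of the history can now be cancelled.
        var session = await ChatSession.InitializeSessionFromHistoryAsync(
            executor, history, cancellationToken: cts.Token);

        // State persistence also accepts an optional token now.
        await executor.SaveState("session-state.json", cts.Token);
    }
}

Because the added token parameters default to `default` (the gap patch 2 closes for the overrides that had omitted it), existing call sites compile unchanged.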