Skip to content

Commit 0ce8a04

Browse files
Implements cancellation tokens (#1281)
1 parent f18e789 commit 0ce8a04

File tree

66 files changed

+617
-532
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+617
-532
lines changed

DevProxy.Abstractions/LanguageModel/BaseLanguageModelClient.cs

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,22 @@
1010

1111
namespace DevProxy.Abstractions.LanguageModel;
1212

13-
public abstract class BaseLanguageModelClient(ILogger logger) : ILanguageModelClient
13+
public abstract class BaseLanguageModelClient(LanguageModelConfiguration configuration, ILogger logger) : ILanguageModelClient
1414
{
15-
private readonly ILogger _logger = logger;
15+
protected LanguageModelConfiguration Configuration { get; } = configuration;
16+
protected ILogger Logger { get; } = logger;
17+
18+
private bool? _lmAvailable;
19+
1620
private readonly ConcurrentDictionary<string, (IEnumerable<ILanguageModelChatCompletionMessage>?, CompletionOptions?)> _promptCache = new();
1721

18-
public virtual async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(string promptFileName, Dictionary<string, object> parameters)
22+
public virtual async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(string promptFileName, Dictionary<string, object> parameters, CancellationToken cancellationToken)
1923
{
2024
ArgumentNullException.ThrowIfNull(promptFileName, nameof(promptFileName));
2125

2226
if (!promptFileName.EndsWith(".prompty", StringComparison.OrdinalIgnoreCase))
2327
{
24-
_logger.LogDebug("Prompt file name '{PromptFileName}' does not end with '.prompty'. Appending the extension.", promptFileName);
28+
Logger.LogDebug("Prompt file name '{PromptFileName}' does not end with '.prompty'. Appending the extension.", promptFileName);
2529
promptFileName += ".prompty";
2630
}
2731

@@ -34,35 +38,78 @@ public abstract class BaseLanguageModelClient(ILogger logger) : ILanguageModelCl
3438
return null;
3539
}
3640

37-
return await GenerateChatCompletionAsync(messages, options);
41+
return await GenerateChatCompletionAsync(messages, options, cancellationToken);
42+
}
43+
44+
public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(IEnumerable<ILanguageModelChatCompletionMessage> messages, CompletionOptions? options, CancellationToken cancellationToken)
45+
{
46+
if (Configuration is null)
47+
{
48+
return null;
49+
}
50+
51+
if (!await IsEnabledAsync(cancellationToken))
52+
{
53+
Logger.LogDebug("Language model is not available.");
54+
return null;
55+
}
56+
57+
return await GenerateChatCompletionCoreAsync(messages, options, cancellationToken);
58+
}
59+
60+
public async Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options, CancellationToken cancellationToken)
61+
{
62+
if (Configuration is null)
63+
{
64+
return null;
65+
}
66+
67+
if (!await IsEnabledAsync(cancellationToken))
68+
{
69+
Logger.LogDebug("Language model is not available.");
70+
return null;
71+
}
72+
73+
return await GenerateCompletionCoreAsync(prompt, options, cancellationToken);
74+
}
75+
76+
public async Task<bool> IsEnabledAsync(CancellationToken cancellationToken)
77+
{
78+
if (_lmAvailable.HasValue)
79+
{
80+
return _lmAvailable.Value;
81+
}
82+
83+
_lmAvailable = await IsEnabledCoreAsync(cancellationToken);
84+
return _lmAvailable.Value;
3885
}
3986

40-
public virtual Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(IEnumerable<ILanguageModelChatCompletionMessage> messages, CompletionOptions? options = null) => throw new NotImplementedException();
87+
protected abstract IEnumerable<ILanguageModelChatCompletionMessage> ConvertMessages(ChatMessage[] messages);
4188

42-
public virtual Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options = null) => throw new NotImplementedException();
89+
protected abstract Task<ILanguageModelCompletionResponse?> GenerateChatCompletionCoreAsync(IEnumerable<ILanguageModelChatCompletionMessage> messages, CompletionOptions? options, CancellationToken cancellationToken);
4390

44-
public virtual Task<bool> IsEnabledAsync() => throw new NotImplementedException();
91+
protected abstract Task<ILanguageModelCompletionResponse?> GenerateCompletionCoreAsync(string prompt, CompletionOptions? options, CancellationToken cancellationToken);
4592

46-
protected virtual IEnumerable<ILanguageModelChatCompletionMessage> ConvertMessages(ChatMessage[] messages) => throw new NotImplementedException();
93+
protected abstract Task<bool> IsEnabledCoreAsync(CancellationToken cancellationToken);
4794

4895
private (IEnumerable<ILanguageModelChatCompletionMessage>?, CompletionOptions?) LoadPrompt(string promptFileName, Dictionary<string, object> parameters)
4996
{
50-
_logger.LogDebug("Prompt file {PromptFileName} not in the cache. Loading...", promptFileName);
97+
Logger.LogDebug("Prompt file {PromptFileName} not in the cache. Loading...", promptFileName);
5198

5299
var filePath = Path.Combine(ProxyUtils.AppFolder!, "prompts", promptFileName);
53100
if (!File.Exists(filePath))
54101
{
55102
throw new FileNotFoundException($"Prompt file '{filePath}' not found.");
56103
}
57104

58-
_logger.LogDebug("Loading prompt file: {FilePath}", filePath);
105+
Logger.LogDebug("Loading prompt file: {FilePath}", filePath);
59106
var promptContents = File.ReadAllText(filePath);
60107

61108
var prompty = PromptyCore.Prompty.Load(promptContents, []);
62109
if (prompty.Prepare(parameters) is not ChatMessage[] promptyMessages ||
63110
promptyMessages.Length == 0)
64111
{
65-
_logger.LogError("No messages found in the prompt file: {FilePath}", filePath);
112+
Logger.LogError("No messages found in the prompt file: {FilePath}", filePath);
66113
return (null, null);
67114
}
68115

DevProxy.Abstractions/LanguageModel/ILanguageModelClient.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ namespace DevProxy.Abstractions.LanguageModel;
66

77
public interface ILanguageModelClient
88
{
9-
Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(string promptFileName, Dictionary<string, object> parameters);
10-
Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(IEnumerable<ILanguageModelChatCompletionMessage> messages, CompletionOptions? options = null);
11-
Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options = null);
12-
Task<bool> IsEnabledAsync();
9+
Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(string promptFileName, Dictionary<string, object> parameters, CancellationToken cancellationToken);
10+
Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(IEnumerable<ILanguageModelChatCompletionMessage> messages, CompletionOptions? options, CancellationToken cancellationToken);
11+
Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options, CancellationToken cancellationToken);
12+
Task<bool> IsEnabledAsync(CancellationToken cancellationToken);
1313
}

0 commit comments

Comments
 (0)