Skip to content

Commit 6bfae63

Browse files
authored
add support for local models back (#514)
1 parent cdb7d61 commit 6bfae63

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

patchwork/common/client/llm/openai.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def _cached_list_models_from_openai(api_key):
2626

2727

2828
class OpenAiLlmClient(LlmClient):
29-
def __init__(self, api_key: str, base_url=None):
29+
def __init__(self, api_key: str, base_url=None, **kwargs):
3030
self.api_key = api_key
3131
self.base_url = base_url
32-
self.client = OpenAI(api_key=api_key, base_url=base_url)
32+
self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)
3333

3434
def __is_not_openai_url(self):
3535
# Some providers/apis only implement the chat completion endpoint.

patchwork/steps/CallLLM/CallLLM.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def __init__(self, inputs: dict):
4141

4242
self.call_limit = int(inputs.get("max_llm_calls", -1))
4343
self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
44-
self.client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
4544
self.save_responses_to_file = inputs.get("save_responses_to_file", None)
4645
self.model = inputs.get("model", "gpt-3.5-turbo")
4746
self.allow_truncated = inputs.get("allow_truncated", False)
@@ -55,7 +54,8 @@ def __init__(self, inputs: dict):
5554

5655
openai_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
5756
if openai_key is not None:
58-
client = OpenAiLlmClient(openai_key)
57+
client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
58+
client = OpenAiLlmClient(openai_key, **client_args)
5959
clients.append(client)
6060

6161
google_key = inputs.get("google_api_key")
@@ -68,6 +68,7 @@ def __init__(self, inputs: dict):
6868
client = AnthropicLlmClient(anthropic_key)
6969
clients.append(client)
7070

71+
7172
if len(clients) == 0:
7273
raise ValueError(
7374
f"Model API key not found.\n"

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "patchwork-cli"
3-
version = "0.0.41"
3+
version = "0.0.42"
44
description = ""
55
authors = ["patched.codes"]
66
license = "AGPL"

0 commit comments

Comments
 (0)