Commit dfd6773

Add Mistral Provider Support (#11)
1 parent d939756 commit dfd6773

7 files changed: +65 -7 lines changed


README.md

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,7 @@ The providers are imported from [providers.py](/os_computer_use/providers.py) an
 - HuggingFace Spaces:
 - OS-Atlas (grounding)
 - ShowUI (grounding)
+- Mistral AI (Pixtral for vision, Mistral Large for actions)

 If you add a new model or provider, please [make a PR](../../pulls) to this repository with the updated providers.py!


@@ -101,6 +102,8 @@ GROQ_API_KEY=...
 GEMINI_API_KEY=...
 OPENAI_API_KEY=...
 ANTHROPIC_API_KEY=...
+# Required: Provide your Hugging Face token to bypass Gradio rate limits.
+HF_TOKEN=...
 ```

 ### 4. Start the web interface
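Besides the HF_TOKEN line added to the README, this commit's providers.py reads MISTRAL_API_KEY from the environment, so that key also has to be present before the Mistral provider will authenticate. A small, illustrative check (not part of the commit) that both variables are set:

```python
# Illustrative sketch only: confirm the environment variables this change relies on.
# MISTRAL_API_KEY is read in os_computer_use/providers.py; HF_TOKEN is the new
# variable documented in the README for the OS-Atlas Gradio client.
import os
from dotenv import load_dotenv

load_dotenv()  # pick up values from the project's .env file
for key in ("MISTRAL_API_KEY", "HF_TOKEN"):
    if not os.getenv(key):
        print(f"Warning: {key} is not set")
```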

os_computer_use/config.py

Lines changed: 5 additions & 2 deletions

@@ -5,12 +5,15 @@
 grounding_model = providers.OSAtlasProvider()
 # grounding_model = providers.ShowUIProvider()

-# vision_model = providers.FireworksProvider("llama3.2")
+#vision_model = providers.FireworksProvider("llama3.2")
 # vision_model = providers.OpenAIProvider("gpt-4o")
 # vision_model = providers.AnthropicProvider("claude-3.5-sonnet")
 vision_model = providers.GroqProvider("llama3.2")
+#vision_model = providers.MistralProvider("pixtral") # pixtral-large-latest has vision capabilities

-# action_model = providers.FireworksProvider("llama3.3")
+
+#action_model = providers.FireworksProvider("llama3.3")
 # action_model = providers.OpenAIProvider("gpt-4o")
 # action_model = providers.AnthropicProvider("claude-3.5-sonnet")
 action_model = providers.GroqProvider("llama3.3")
+#action_model = providers.MistralProvider("large") # mistral-large-latest for non-vision tasks
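The commented-out lines above show how to switch the agent to Mistral. A sketch of config.py with those options enabled (the `providers` import sits above this hunk and is assumed here):

```python
# Sketch of config.py with the Mistral options enabled. The import below is an
# assumption; the real import statement lies outside the hunk shown above.
from os_computer_use import providers

grounding_model = providers.OSAtlasProvider()
vision_model = providers.MistralProvider("pixtral")  # pixtral-large-latest (vision)
action_model = providers.MistralProvider("large")    # mistral-large-latest (actions)
```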

os_computer_use/llm_provider.py

Lines changed: 17 additions & 1 deletion

@@ -116,7 +116,6 @@ def create_image_block(self, base64_image):
         }

     def call(self, messages, functions=None):
-
         # If functions are provided, only return actions
         tools = self.create_function_schema(functions) if functions else None
         completion = self.completion(messages, tools=tools)

@@ -204,3 +203,20 @@ def call(self, messages, functions=None):
         # Only return response text
         else:
             return text
+
+
+class MistralBaseProvider(OpenAIBaseProvider):
+    def create_function_def(self, name, details, properties, required):
+        # If description is wrapped in a dict, extract the inner string
+        if isinstance(details.get("description"), dict):
+            details["description"] = details["description"].get("description", "")
+        return super().create_function_def(name, details, properties, required)
+
+    def call(self, messages, functions=None):
+        if messages and messages[-1].get("role") == "assistant":
+            prefix = messages.pop()["content"]
+            if messages and messages[-1].get("role") == "user":
+                messages[-1]["content"] = prefix + "\n" + messages[-1].get("content", "")
+            else:
+                messages.append({"role": "user", "content": prefix})
+        return super().call(messages, functions)
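The new `call` override folds a trailing assistant message into the last user message before delegating to `OpenAIBaseProvider.call`, presumably because Mistral's chat endpoint rejects conversations that end on an assistant turn unless prefix mode is requested. A standalone illustration of that rewrite (not the library code itself):

```python
# Standalone illustration of the message rewrite performed by MistralBaseProvider.call.
def fold_assistant_prefix(messages):
    if messages and messages[-1].get("role") == "assistant":
        prefix = messages.pop()["content"]
        if messages and messages[-1].get("role") == "user":
            messages[-1]["content"] = prefix + "\n" + messages[-1].get("content", "")
        else:
            messages.append({"role": "user", "content": prefix})
    return messages

messages = [
    {"role": "user", "content": "Click on the submit button"},
    {"role": "assistant", "content": "I will use the click tool."},
]
print(fold_assistant_prefix(messages))
# [{'role': 'user', 'content': 'I will use the click tool.\nClick on the submit button'}]
```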

os_computer_use/osatlas_provider.py

Lines changed: 6 additions & 2 deletions

@@ -1,20 +1,24 @@
 from gradio_client import Client, handle_file
 from os_computer_use.logging import logger
-
 from os_computer_use.grounding import extract_bbox_midpoint

+import os
+
+
 OSATLAS_HUGGINGFACE_SOURCE = "maxiw/OS-ATLAS"
 OSATLAS_HUGGINGFACE_MODEL = "OS-Copilot/OS-Atlas-Base-7B"
 OSATLAS_HUGGINGFACE_API = "/run_example"

+HF_TOKEN = os.getenv("HF_TOKEN")
+

 class OSAtlasProvider:
     """
     The OS-Atlas provider is used to make calls to OS-Atlas.
     """

     def __init__(self):
-        self.client = Client(OSATLAS_HUGGINGFACE_SOURCE)
+        self.client = Client(OSATLAS_HUGGINGFACE_SOURCE, hf_token=HF_TOKEN)

     def call(self, prompt, image_data):
         result = self.client.predict(
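Because HF_TOKEN is now read at module import time, it has to be in the environment before os_computer_use.osatlas_provider is imported. A minimal sketch of the assumed flow, loading the token from a .env file as the README describes:

```python
# Minimal sketch (assumed flow, not part of the commit): load .env before the
# module-level os.getenv("HF_TOKEN") in osatlas_provider.py runs.
from dotenv import load_dotenv

load_dotenv()  # providers.py imports load_dotenv for the same purpose

from os_computer_use.osatlas_provider import OSAtlasProvider

provider = OSAtlasProvider()  # the Gradio Client is created with hf_token=HF_TOKEN
```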

os_computer_use/providers.py

Lines changed: 11 additions & 1 deletion

@@ -1,6 +1,6 @@
 import os
 from dotenv import load_dotenv
-from os_computer_use.llm_provider import OpenAIBaseProvider, AnthropicBaseProvider
+from os_computer_use.llm_provider import OpenAIBaseProvider, AnthropicBaseProvider, MistralBaseProvider
 from os_computer_use.osatlas_provider import OSAtlasProvider
 from os_computer_use.showui_provider import ShowUIProvider


@@ -62,3 +62,13 @@ class GroqProvider(OpenAIBaseProvider):
         "llama3.2": "llama-3.2-90b-vision-preview",
         "llama3.3": "llama-3.3-70b-versatile",
     }
+
+class MistralProvider(MistralBaseProvider):
+    base_url = "https://api.mistral.ai/v1"
+    api_key = os.getenv("MISTRAL_API_KEY")
+    aliases = {
+        "small": "mistral-small-latest",
+        "medium": "mistral-medium-latest",
+        "large": "mistral-large-latest",
+        "pixtral": "pixtral-large-latest"
+    }
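The aliases map short names to concrete model IDs, so MistralProvider("large") resolves to mistral-large-latest. A quick usage sketch, assuming MISTRAL_API_KEY is set (the test file below exercises the same path):

```python
# Usage sketch; assumes MISTRAL_API_KEY is set in the environment or .env.
from os_computer_use.providers import MistralProvider
from os_computer_use.llm_provider import Message

mistral = MistralProvider("large")  # "large" resolves to mistral-large-latest
print(mistral.call([Message("What is the capital of France?", role="user")]))
```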

os_computer_use/sandbox_agent.py

Lines changed: 1 addition & 1 deletion

@@ -219,4 +219,4 @@ def run(self, instruction):

         self.messages.append(
             Message(logger.log(f"OBSERVATION: {result}", "yellow"))
-        )
+        )

tests/llm_provider.py

Lines changed: 22 additions & 0 deletions

@@ -3,6 +3,7 @@
     OpenAIProvider,
     GroqProvider,
     FireworksProvider,
+    MistralProvider,
 )
 from os_computer_use.llm_provider import Message


@@ -15,6 +16,7 @@
 }


+
 # Function to simulate taking a screenshot
 def take_screenshot():
     with open("./tests/test_screenshot.png", "rb") as f:

@@ -62,3 +64,23 @@ def take_screenshot():
 fireworks = FireworksProvider("llama3.2")
 print(fireworks.call(toolcall_messages, tools)[1])
 print(fireworks.call(messages))
+
+
+
+# Pixtral
+mistral = MistralProvider("pixtral")
+print("\nTesting Mistral :")
+print(mistral.call(toolcall_messages, tools)[1])
+print(mistral.call(messages))
+
+
+# Mistral Large (non-vision) using text-only messages
+mistral_large = MistralProvider("large") # Using mistral-large-latest for non-vision tasks
+text_messages = [Message("What is the capital of France?", role="user")]
+print("\nTesting Mistral Large with text-only:")
+print(mistral_large.call(text_messages))
+
+# Test tool calls for Mistral Large using text-only messages (no image data)
+text_tool_messages = [Message("Click on the submit button", role="user")]
+print("\nTesting Mistral Large Tool Calls with text:")
+print(mistral_large.call(text_tool_messages, tools)[1])
