Skip to content

Commit b48a209

Browse files
[Feature Request] Integrate Grok image (#3040)
Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com>
1 parent dec651e commit b48a209

File tree

7 files changed

+161
-39
lines changed

7 files changed

+161
-39
lines changed

.env

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,7 @@
130130
# E2B
131131
# E2B_API_KEY="Fill your e2b or e2b-compatible sandbox provider API Key here"
132132
# E2B_DOMAIN="Fill your custom e2b domain here"
133+
134+
# Grok API key
135+
# XAI_API_KEY="Fill your Grok API Key here"
136+
# XAI_API_BASE_URL="Fill your Grok API Base URL here"

camel/toolkits/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .math_toolkit import MathToolkit
2424
from .search_toolkit import SearchToolkit
2525
from .weather_toolkit import WeatherToolkit
26-
from .openai_image_toolkit import OpenAIImageToolkit
26+
from .image_generation_toolkit import ImageGenToolkit, OpenAIImageToolkit
2727
from .ask_news_toolkit import AskNewsToolkit, AsyncAskNewsToolkit
2828
from .linkedin_toolkit import LinkedInToolkit
2929
from .reddit_toolkit import RedditToolkit
@@ -102,7 +102,7 @@
102102
'SearchToolkit',
103103
'SlackToolkit',
104104
'WhatsAppToolkit',
105-
'OpenAIImageToolkit',
105+
'ImageGenToolkit',
106106
'TwitterToolkit',
107107
'WeatherToolkit',
108108
'RetrievalToolkit',
@@ -151,7 +151,7 @@
151151
'PlaywrightMCPToolkit',
152152
'WolframAlphaToolkit',
153153
'BohriumToolkit',
154-
'OpenAIImageToolkit',
154+
'OpenAIImageToolkit', # Backward compatibility
155155
'TaskPlanningToolkit',
156156
'HybridBrowserToolkit',
157157
'EdgeOnePagesMCPToolkit',

camel/toolkits/openai_image_toolkit.py renamed to camel/toolkits/image_generation_toolkit.py

Lines changed: 98 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import base64
1616
import os
1717
from io import BytesIO
18-
from typing import List, Literal, Optional, Union
18+
from typing import ClassVar, List, Literal, Optional, Tuple, Union
1919

2020
from openai import OpenAI
2121
from PIL import Image
@@ -29,21 +29,32 @@
2929

3030

3131
@MCPServer()
32-
class OpenAIImageToolkit(BaseToolkit):
33-
r"""A class toolkit for image generation using OpenAI's
34-
Image Generation API.
35-
"""
36-
37-
@api_keys_required(
38-
[
39-
("api_key", "OPENAI_API_KEY"),
40-
]
41-
)
32+
class ImageGenToolkit(BaseToolkit):
33+
r"""A class toolkit for image generation using Grok and OpenAI models."""
34+
35+
GROK_MODELS: ClassVar[List[str]] = [
36+
"grok-2-image",
37+
"grok-2-image-latest",
38+
"grok-2-image-1212",
39+
]
40+
OPENAI_MODELS: ClassVar[List[str]] = [
41+
"gpt-image-1",
42+
"dall-e-3",
43+
"dall-e-2",
44+
]
45+
4246
def __init__(
4347
self,
4448
model: Optional[
45-
Literal["gpt-image-1", "dall-e-3", "dall-e-2"]
46-
] = "gpt-image-1",
49+
Literal[
50+
"gpt-image-1",
51+
"dall-e-3",
52+
"dall-e-2",
53+
"grok-2-image",
54+
"grok-2-image-latest",
55+
"grok-2-image-1212",
56+
]
57+
] = "dall-e-3",
4758
timeout: Optional[float] = None,
4859
api_key: Optional[str] = None,
4960
url: Optional[str] = None,
@@ -72,12 +83,12 @@ def __init__(
7283
# NOTE: Some arguments are set in the constructor to prevent the agent
7384
# from making invalid API calls with model-specific parameters. For
7485
# example, the 'style' argument is only supported by 'dall-e-3'.
75-
r"""Initializes a new instance of the OpenAIImageToolkit class.
86+
r"""Initializes a new instance of the ImageGenToolkit class.
7687
7788
Args:
7889
api_key (Optional[str]): The API key for authenticating
79-
with the OpenAI service. (default: :obj:`None`)
80-
url (Optional[str]): The url to the OpenAI service.
90+
with the image model service. (default: :obj:`None`)
91+
url (Optional[str]): The url to the image model service.
8192
(default: :obj:`None`)
8293
model (Optional[str]): The model to use.
8394
(default: :obj:`"dall-e-3"`)
@@ -103,9 +114,23 @@ def __init__(
103114
image.(default: :obj:`"image_save"`)
104115
"""
105116
super().__init__(timeout=timeout)
106-
api_key = api_key or os.environ.get("OPENAI_API_KEY")
107-
url = url or os.environ.get("OPENAI_API_BASE_URL")
108-
self.client = OpenAI(api_key=api_key, base_url=url)
117+
if model not in self.GROK_MODELS + self.OPENAI_MODELS:
118+
available_models = sorted(self.OPENAI_MODELS + self.GROK_MODELS)
119+
raise ValueError(
120+
f"Unsupported model: {model}. "
121+
f"Supported models are: {available_models}"
122+
)
123+
124+
# Set default url for Grok models
125+
url = "https://api.x.ai/v1" if model in self.GROK_MODELS else url
126+
127+
api_key, base_url = (
128+
self.get_openai_credentials(url, api_key)
129+
if model in self.OPENAI_MODELS
130+
else self.get_grok_credentials(url, api_key)
131+
)
132+
133+
self.client = OpenAI(api_key=api_key, base_url=base_url)
109134
self.model = model
110135
self.size = size
111136
self.quality = quality
@@ -139,7 +164,7 @@ def base64_to_image(self, base64_string: str) -> Optional[Image.Image]:
139164
return None
140165

141166
def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
142-
r"""Build base parameters dict for OpenAI API calls.
167+
r"""Build base parameters dict for Image Model API calls.
143168
144169
Args:
145170
prompt (str): The text prompt for the image operation.
@@ -153,6 +178,10 @@ def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
153178
# basic parameters supported by all models
154179
if n is not None:
155180
params["n"] = n # type: ignore[assignment]
181+
182+
if self.model in self.GROK_MODELS:
183+
return params
184+
156185
if self.size is not None:
157186
params["size"] = self.size
158187

@@ -179,16 +208,18 @@ def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
179208
params["quality"] = self.quality
180209
if self.background is not None:
181210
params["background"] = self.background
182-
183211
return params
184212

185213
def _handle_api_response(
186-
self, response, image_name: Union[str, List[str]], operation: str
214+
self,
215+
response,
216+
image_name: Union[str, List[str]],
217+
operation: str,
187218
) -> str:
188-
r"""Handle API response from OpenAI image operations.
219+
r"""Handle API response from image operations.
189220
190221
Args:
191-
response: The response object from OpenAI API.
222+
response: The response object from image model API.
192223
image_name (Union[str, List[str]]): Name(s) for the saved image
193224
file(s). If str, the same name is used for all images (will
194225
cause error for multiple images). If list, must have exactly
@@ -198,8 +229,9 @@ def _handle_api_response(
198229
Returns:
199230
str: Success message with image path/URL or error message.
200231
"""
232+
source = "Grok" if self.model in self.GROK_MODELS else "OpenAI"
201233
if response.data is None or len(response.data) == 0:
202-
error_msg = "No image data returned from OpenAI API."
234+
error_msg = f"No image data returned from {source} API."
203235
logger.error(error_msg)
204236
return error_msg
205237

@@ -283,7 +315,7 @@ def generate_image(
283315
image_name: Union[str, List[str]] = "image.png",
284316
n: int = 1,
285317
) -> str:
286-
r"""Generate an image using OpenAI's Image Generation models.
318+
r"""Generate an image using image models.
287319
The generated image will be saved locally (for ``b64_json`` response
288320
formats) or an image URL will be returned (for ``url`` response
289321
formats).
@@ -309,15 +341,50 @@ def generate_image(
309341
logger.error(error_msg)
310342
return error_msg
311343

344+
@api_keys_required([("api_key", "XAI_API_KEY")])
def get_grok_credentials(
    self, url: Optional[str], api_key: Optional[str]
) -> Tuple[str, str]:
    r"""Resolve the API key and base URL used for Grok (xAI) models.

    Explicitly passed values take precedence over environment variables.
    The base URL falls back to ``XAI_API_BASE_URL`` and then to the public
    xAI endpoint, mirroring how :meth:`get_openai_credentials` honours
    ``OPENAI_API_BASE_URL``.

    Args:
        url (Optional[str]): The base URL for the Grok API. If falsy,
            the ``XAI_API_BASE_URL`` environment variable is used, then
            ``"https://api.x.ai/v1"``.
        api_key (Optional[str]): The API key for the Grok API. If falsy,
            the ``XAI_API_KEY`` environment variable is used.

    Returns:
        Tuple[str, str]: (api_key, base_url)
    """
    api_key = api_key or os.getenv("XAI_API_KEY")
    base_url = (
        url or os.getenv("XAI_API_BASE_URL") or "https://api.x.ai/v1"
    )
    # NOTE(review): the `api_keys_required` decorator presumably rejects
    # the call when no key is available, so `api_key` should be a `str`
    # here even though the type checker cannot prove it — confirm.
    return api_key, base_url  # type: ignore[return-value]
359+
360+
@api_keys_required([("api_key", "OPENAI_API_KEY")])
def get_openai_credentials(
    self, url: Optional[str], api_key: Optional[str]
) -> Tuple[str, Optional[str]]:
    r"""Resolve the API key and base URL used for OpenAI image models.

    Explicitly passed values take precedence over environment variables.

    Args:
        url (Optional[str]): The base URL for the OpenAI API. If falsy,
            the ``OPENAI_API_BASE_URL`` environment variable is used; the
            result may still be :obj:`None`.
        api_key (Optional[str]): The API key for the OpenAI API. If
            falsy, the ``OPENAI_API_KEY`` environment variable is used.

    Returns:
        Tuple[str, Optional[str]]: (api_key, base_url)
    """
    # `Optional[...]` is used instead of `str | None` because annotations
    # are evaluated at definition time and the `|` form fails on
    # interpreters older than Python 3.10; it also matches the `typing`
    # style used by the rest of this module.
    api_key = api_key or os.getenv("OPENAI_API_KEY")
    base_url = url or os.getenv("OPENAI_API_BASE_URL")
    # NOTE(review): the `api_keys_required` decorator presumably rejects
    # the call when no key is available, so `api_key` should be a `str`
    # here even though the type checker cannot prove it — confirm.
    return api_key, base_url  # type: ignore[return-value]
375+
312376
def get_tools(self) -> List[FunctionTool]:
313-
r"""Returns a list of FunctionTool objects representing the
314-
functions in the toolkit.
377+
r"""Returns a list of FunctionTool objects representing the functions
378+
in the toolkit.
315379
316380
Returns:
317-
List[FunctionTool]: A list of FunctionTool objects
318-
representing the functions in the toolkit.
381+
List[FunctionTool]: A list of FunctionTool objects representing the
382+
functions in the toolkit.
319383
"""
320384
return [
321385
FunctionTool(self.generate_image),
322-
# could add edit_image function later
323386
]
387+
388+
389+
# Backward compatibility alias
390+
OpenAIImageToolkit = ImageGenToolkit

camel/utils/commons.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
354354
key_way = "https://www.zhipuai.cn/"
355355
elif env_var_name == 'KLAVIS_API_KEY':
356356
key_way = "https://www.klavis.ai/docs"
357+
elif env_var_name == 'XAI_API_KEY':
358+
key_way = "https://api.x.ai/v1"
357359

358360
if missing_keys:
359361
raise ValueError(
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
from camel.agents import ChatAgent
15+
from camel.models import ModelFactory
16+
from camel.toolkits import ImageGenToolkit
17+
from camel.types import ModelPlatformType, ModelType
18+
19+
# System prompt handed to the agent
sys_msg = "You are a helpful assistant that can generate images."

# Expose the Grok-2 image model (base64 responses) as agent tools
tools = list(
    ImageGenToolkit(
        response_format="b64_json",
        model="grok-2-image-1212",
    ).get_tools()
)

# Backing chat model that drives the agent
model = ModelFactory.create(
    model_type=ModelType.DEFAULT,
    model_platform=ModelPlatformType.DEFAULT,
)

# Wire the prompt, model and tools together
camel_agent = ChatAgent(tools=tools, model=model, system_message=sys_msg)

# Request sent to the agent
usr_msg = "Generate 1 image of a camel working out in a gym."

# Run a single step and report what happened
response = camel_agent.step(usr_msg)

tool_call_count = len(response.info['tool_calls'])
print(f"Tool calls made: {tool_call_count}")
print(f"\nAgent response: {response.msg}")

examples/toolkits/openai_image_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
1414
from camel.agents import ChatAgent
1515
from camel.models import ModelFactory
16-
from camel.toolkits import OpenAIImageToolkit
16+
from camel.toolkits import ImageGenToolkit
1717
from camel.types import ModelPlatformType, ModelType
1818

1919
# Define system message
2020
sys_msg = "You are a helpful assistant that can generate images."
2121

22-
# Create OpenAI Image Toolkit with DALL-E 3 model and base64 response format
22+
# Create Image Generation Toolkit with DALL-E 3 and base64 response format
2323
tools = [
24-
*OpenAIImageToolkit(
24+
*ImageGenToolkit(
2525
model="dall-e-3",
2626
response_format="b64_json",
2727
size="1024x1024",

test/toolkits/test_openai_image.py renamed to test/toolkits/test_image_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import pytest
1515
from PIL import Image
1616

17-
from camel.toolkits import OpenAIImageToolkit
17+
from camel.toolkits import ImageGenToolkit
1818

1919

2020
@pytest.fixture
2121
def image_toolkit():
22-
return OpenAIImageToolkit()
22+
return ImageGenToolkit()
2323

2424

2525
def test_base64_to_image_valid(image_toolkit):

0 commit comments

Comments
 (0)