15
15
import base64
16
16
import os
17
17
from io import BytesIO
18
- from typing import List , Literal , Optional , Union
18
+ from typing import ClassVar , List , Literal , Optional , Tuple , Union
19
19
20
20
from openai import OpenAI
21
21
from PIL import Image
29
29
30
30
31
31
@MCPServer ()
32
- class OpenAIImageToolkit (BaseToolkit ):
33
- r"""A class toolkit for image generation using OpenAI's
34
- Image Generation API.
35
- """
36
-
37
- @api_keys_required (
38
- [
39
- ("api_key" , "OPENAI_API_KEY" ),
40
- ]
41
- )
32
+ class ImageGenToolkit (BaseToolkit ):
33
+ r"""A class toolkit for image generation using Grok and OpenAI models."""
34
+
35
+ GROK_MODELS : ClassVar [List [str ]] = [
36
+ "grok-2-image" ,
37
+ "grok-2-image-latest" ,
38
+ "grok-2-image-1212" ,
39
+ ]
40
+ OPENAI_MODELS : ClassVar [List [str ]] = [
41
+ "gpt-image-1" ,
42
+ "dall-e-3" ,
43
+ "dall-e-2" ,
44
+ ]
45
+
42
46
def __init__ (
43
47
self ,
44
48
model : Optional [
45
- Literal ["gpt-image-1" , "dall-e-3" , "dall-e-2" ]
46
- ] = "gpt-image-1" ,
49
+ Literal [
50
+ "gpt-image-1" ,
51
+ "dall-e-3" ,
52
+ "dall-e-2" ,
53
+ "grok-2-image" ,
54
+ "grok-2-image-latest" ,
55
+ "grok-2-image-1212" ,
56
+ ]
57
+ ] = "dall-e-3" ,
47
58
timeout : Optional [float ] = None ,
48
59
api_key : Optional [str ] = None ,
49
60
url : Optional [str ] = None ,
@@ -72,12 +83,12 @@ def __init__(
72
83
# NOTE: Some arguments are set in the constructor to prevent the agent
73
84
# from making invalid API calls with model-specific parameters. For
74
85
# example, the 'style' argument is only supported by 'dall-e-3'.
75
- r"""Initializes a new instance of the OpenAIImageToolkit class.
86
+ r"""Initializes a new instance of the ImageGenToolkit class.
76
87
77
88
Args:
78
89
api_key (Optional[str]): The API key for authenticating
79
- with the OpenAI service. (default: :obj:`None`)
80
- url (Optional[str]): The url to the OpenAI service.
90
+ with the image model service. (default: :obj:`None`)
91
+ url (Optional[str]): The url to the image model service.
81
92
(default: :obj:`None`)
82
93
model (Optional[str]): The model to use.
83
94
(default: :obj:`"dall-e-3"`)
@@ -103,9 +114,23 @@ def __init__(
103
114
image.(default: :obj:`"image_save"`)
104
115
"""
105
116
super ().__init__ (timeout = timeout )
106
- api_key = api_key or os .environ .get ("OPENAI_API_KEY" )
107
- url = url or os .environ .get ("OPENAI_API_BASE_URL" )
108
- self .client = OpenAI (api_key = api_key , base_url = url )
117
+ if model not in self .GROK_MODELS + self .OPENAI_MODELS :
118
+ available_models = sorted (self .OPENAI_MODELS + self .GROK_MODELS )
119
+ raise ValueError (
120
+ f"Unsupported model: { model } . "
121
+ f"Supported models are: { available_models } "
122
+ )
123
+
124
+ # Set default url for Grok models
125
+ url = "https://api.x.ai/v1" if model in self .GROK_MODELS else url
126
+
127
+ api_key , base_url = (
128
+ self .get_openai_credentials (url , api_key )
129
+ if model in self .OPENAI_MODELS
130
+ else self .get_grok_credentials (url , api_key )
131
+ )
132
+
133
+ self .client = OpenAI (api_key = api_key , base_url = base_url )
109
134
self .model = model
110
135
self .size = size
111
136
self .quality = quality
@@ -139,7 +164,7 @@ def base64_to_image(self, base64_string: str) -> Optional[Image.Image]:
139
164
return None
140
165
141
166
def _build_base_params (self , prompt : str , n : Optional [int ] = None ) -> dict :
142
- r"""Build base parameters dict for OpenAI API calls.
167
+ r"""Build base parameters dict for Image Model API calls.
143
168
144
169
Args:
145
170
prompt (str): The text prompt for the image operation.
@@ -153,6 +178,10 @@ def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
153
178
# basic parameters supported by all models
154
179
if n is not None :
155
180
params ["n" ] = n # type: ignore[assignment]
181
+
182
+ if self .model in self .GROK_MODELS :
183
+ return params
184
+
156
185
if self .size is not None :
157
186
params ["size" ] = self .size
158
187
@@ -179,16 +208,18 @@ def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
179
208
params ["quality" ] = self .quality
180
209
if self .background is not None :
181
210
params ["background" ] = self .background
182
-
183
211
return params
184
212
185
213
def _handle_api_response (
186
- self , response , image_name : Union [str , List [str ]], operation : str
214
+ self ,
215
+ response ,
216
+ image_name : Union [str , List [str ]],
217
+ operation : str ,
187
218
) -> str :
188
- r"""Handle API response from OpenAI image operations.
219
+ r"""Handle API response from image operations.
189
220
190
221
Args:
191
- response: The response object from OpenAI API.
222
+ response: The response object from image model API.
192
223
image_name (Union[str, List[str]]): Name(s) for the saved image
193
224
file(s). If str, the same name is used for all images (will
194
225
cause error for multiple images). If list, must have exactly
@@ -198,8 +229,9 @@ def _handle_api_response(
198
229
Returns:
199
230
str: Success message with image path/URL or error message.
200
231
"""
232
+ source = "Grok" if self .model in self .GROK_MODELS else "OpenAI"
201
233
if response .data is None or len (response .data ) == 0 :
202
- error_msg = "No image data returned from OpenAI API."
234
+ error_msg = f "No image data returned from { source } API."
203
235
logger .error (error_msg )
204
236
return error_msg
205
237
@@ -283,7 +315,7 @@ def generate_image(
283
315
image_name : Union [str , List [str ]] = "image.png" ,
284
316
n : int = 1 ,
285
317
) -> str :
286
- r"""Generate an image using OpenAI's Image Generation models.
318
+ r"""Generate an image using image models.
287
319
The generated image will be saved locally (for ``b64_json`` response
288
320
formats) or an image URL will be returned (for ``url`` response
289
321
formats).
@@ -309,15 +341,50 @@ def generate_image(
309
341
logger .error (error_msg )
310
342
return error_msg
311
343
344
+ @api_keys_required ([("api_key" , "XAI_API_KEY" )])
345
+ def get_grok_credentials (self , url , api_key ) -> Tuple [str , str ]: # type: ignore[return-value]
346
+ r"""Get API credentials for the specified Grok model.
347
+
348
+ Args:
349
+ url (str): The base URL for the Grok API.
350
+ api_key (str): The API key for the Grok API.
351
+
352
+ Returns:
353
+ tuple: (api_key, base_url)
354
+ """
355
+
356
+ # Get credentials based on model type
357
+ api_key = api_key or os .getenv ("XAI_API_KEY" )
358
+ return api_key , url
359
+
360
+ @api_keys_required ([("api_key" , "OPENAI_API_KEY" )])
361
+ def get_openai_credentials (self , url , api_key ) -> Tuple [str , str | None ]: # type: ignore[return-value]
362
+ r"""Get API credentials for the specified OpenAI model.
363
+
364
+ Args:
365
+ url (str): The base URL for the OpenAI API.
366
+ api_key (str): The API key for the OpenAI API.
367
+
368
+ Returns:
369
+ Tuple[str, str | None]: (api_key, base_url)
370
+ """
371
+
372
+ api_key = api_key or os .getenv ("OPENAI_API_KEY" )
373
+ base_url = url or os .getenv ("OPENAI_API_BASE_URL" )
374
+ return api_key , base_url
375
+
312
376
def get_tools (self ) -> List [FunctionTool ]:
313
- r"""Returns a list of FunctionTool objects representing the
314
- functions in the toolkit.
377
+ r"""Returns a list of FunctionTool objects representing the functions
378
+ in the toolkit.
315
379
316
380
Returns:
317
- List[FunctionTool]: A list of FunctionTool objects
318
- representing the functions in the toolkit.
381
+ List[FunctionTool]: A list of FunctionTool objects representing the
382
+ functions in the toolkit.
319
383
"""
320
384
return [
321
385
FunctionTool (self .generate_image ),
322
- # could add edit_image function later
323
386
]
387
+
388
+
389
+ # Backward compatibility alias
390
+ OpenAIImageToolkit = ImageGenToolkit
0 commit comments