2121
2222import logging
2323import time
24+ from typing import Any
2425
2526from fastmcp import Context
2627from sqlalchemy .exc import SQLAlchemyError
4647logger = logging .getLogger (__name__ )
4748
4849
50+ def _find_chart (identifier : int | str ) -> Any | None :
51+ """Find a chart by numeric ID or UUID string."""
52+ from superset .daos .chart import ChartDAO
53+
54+ if isinstance (identifier , int ) or (
55+ isinstance (identifier , str ) and identifier .isdigit ()
56+ ):
57+ chart_id = int (identifier ) if isinstance (identifier , str ) else identifier
58+ return ChartDAO .find_by_id (chart_id )
59+ return ChartDAO .find_by_id (identifier , id_column = "uuid" )
60+
61+
62+ def _build_update_payload (
63+ request : UpdateChartRequest ,
64+ chart : Any ,
65+ ) -> dict [str , Any ] | GenerateChartResponse :
66+ """Build the update payload for a chart update.
67+
68+ Returns a dict payload on success, or a GenerateChartResponse error
69+ when neither config nor chart_name is provided.
70+ """
71+ if request .config is not None :
72+ dataset_id = chart .datasource_id if chart .datasource_id else None
73+ new_form_data = map_config_to_form_data (request .config , dataset_id = dataset_id )
74+ new_form_data .pop ("_mcp_warnings" , None )
75+
76+ chart_name = (
77+ request .chart_name
78+ if request .chart_name
79+ else chart .slice_name or generate_chart_name (request .config )
80+ )
81+
82+ return {
83+ "slice_name" : chart_name ,
84+ "viz_type" : new_form_data ["viz_type" ],
85+ "params" : json .dumps (new_form_data ),
86+ }
87+
88+ # Name-only update: keep existing visualization, just rename
89+ if not request .chart_name :
90+ return GenerateChartResponse .model_validate (
91+ {
92+ "chart" : None ,
93+ "error" : {
94+ "error_type" : "ValidationError" ,
95+ "message" : ("Either 'config' or 'chart_name' must be provided." ),
96+ "details" : (
97+ "Either 'config' or 'chart_name' must be provided. "
98+ "Use config for visualization changes, chart_name "
99+ "for renaming."
100+ ),
101+ },
102+ "success" : False ,
103+ "schema_version" : "2.0" ,
104+ "api_version" : "v1" ,
105+ }
106+ )
107+ return {"slice_name" : request .chart_name }
108+
109+
49110@tool (
50111 tags = ["mutate" ],
51112 class_permission_name = "Chart" ,
@@ -105,29 +166,22 @@ async def update_chart(
105166 start_time = time .time ()
106167
107168 try :
108- # Find the existing chart
109- from superset .daos .chart import ChartDAO
110-
111169 with event_logger .log_context (action = "mcp.update_chart.chart_lookup" ):
112- chart = None
113- if isinstance (request .identifier , int ) or (
114- isinstance (request .identifier , str ) and request .identifier .isdigit ()
115- ):
116- chart_id = (
117- int (request .identifier )
118- if isinstance (request .identifier , str )
119- else request .identifier
120- )
121- chart = ChartDAO .find_by_id (chart_id )
122- else :
123- # Try UUID lookup using DAO flexible method
124- chart = ChartDAO .find_by_id (request .identifier , id_column = "uuid" )
170+ chart = _find_chart (request .identifier )
125171
126172 if not chart :
127173 return GenerateChartResponse .model_validate (
128174 {
129175 "chart" : None ,
130- "error" : f"No chart found with identifier: { request .identifier } " ,
176+ "error" : {
177+ "error_type" : "NotFound" ,
178+ "message" : (
179+ f"No chart found with identifier: { request .identifier } "
180+ ),
181+ "details" : (
182+ f"No chart found with identifier: { request .identifier } "
183+ ),
184+ },
131185 "success" : False ,
132186 "schema_version" : "2.0" ,
133187 "api_version" : "v1" ,
@@ -157,30 +211,15 @@ async def update_chart(
157211 }
158212 )
159213
160- # Map the new config to form_data format
161- # Get dataset_id from existing chart for column type checking
162- dataset_id = chart .datasource_id if chart .datasource_id else None
163- new_form_data = map_config_to_form_data (request .config , dataset_id = dataset_id )
164- new_form_data .pop ("_mcp_warnings" , None )
165-
166- # Update chart using Superset's command
214+ # Build update payload (config update or name-only rename)
167215 from superset .commands .chart .update import UpdateChartCommand
168216
169- with event_logger .log_context (action = "mcp.update_chart.db_write" ):
170- # Generate new chart name if provided, otherwise keep existing
171- chart_name = (
172- request .chart_name
173- if request .chart_name
174- else chart .slice_name or generate_chart_name (request .config )
175- )
217+ payload_or_error = _build_update_payload (request , chart )
218+ if isinstance (payload_or_error , GenerateChartResponse ):
219+ return payload_or_error
176220
177- update_payload = {
178- "slice_name" : chart_name ,
179- "viz_type" : new_form_data ["viz_type" ],
180- "params" : json .dumps (new_form_data ),
181- }
182-
183- command = UpdateChartCommand (chart .id , update_payload )
221+ with event_logger .log_context (action = "mcp.update_chart.db_write" ):
222+ command = UpdateChartCommand (chart .id , payload_or_error )
184223 updated_chart = command .run ()
185224
186225 # Generate semantic analysis
@@ -199,7 +238,11 @@ async def update_chart(
199238 chart_name = (
200239 updated_chart .slice_name
201240 if updated_chart and hasattr (updated_chart , "slice_name" )
202- else generate_chart_name (request .config )
241+ else (
242+ generate_chart_name (request .config )
243+ if request .config
244+ else "Updated chart"
245+ )
203246 )
204247 accessibility = AccessibilityMetadata (
205248 color_blind_safe = True , # Would need actual analysis
@@ -288,7 +331,11 @@ async def update_chart(
288331 return GenerateChartResponse .model_validate (
289332 {
290333 "chart" : None ,
291- "error" : f"Chart update failed: { str (e )} " ,
334+ "error" : {
335+ "error_type" : type (e ).__name__ ,
336+ "message" : f"Chart update failed: { e } " ,
337+ "details" : str (e ),
338+ },
292339 "performance" : {
293340 "query_duration_ms" : execution_time ,
294341 "cache_status" : "error" ,
0 commit comments