Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions src/dbt_mcp/dbt_admin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dbt_mcp.dbt_admin.client import DbtAdminAPIClient
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.tools.annotations import create_tool_annotations
from dbt_mcp.tools.config import DbtMcpContext
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
Expand Down Expand Up @@ -38,17 +39,29 @@ class JobRunStatus(str, Enum):
}


def create_admin_api_tool_definitions(
admin_client: DbtAdminAPIClient, admin_api_config: AdminApiConfig
) -> list[ToolDefinition]:
def get_admin_client_and_config(
ctx: DbtMcpContext,
) -> tuple[DbtAdminAPIClient, AdminApiConfig]:
admin_api_config = ctx.get_admin_api_config()
if admin_api_config is None:
raise ValueError("admin api config is not set")
admin_api_client = ctx.get_admin_api_client()
if admin_api_client is None:
raise ValueError("admin api client is not set")
return admin_api_client, admin_api_config


def create_admin_api_tool_definitions() -> list[ToolDefinition]:
def list_jobs(
ctx: DbtMcpContext,
# TODO: add support for project_id in the future
# project_id: Optional[int] = None,
limit: int | None = None,
offset: int | None = None,
) -> list[dict[str, Any]] | str:
"""List jobs in an account."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
params = {}
# if project_id:
# params["project_id"] = project_id
Expand All @@ -65,15 +78,20 @@ def list_jobs(
)
return str(e)

def get_job_details(job_id: int) -> dict[str, Any] | str:
def get_job_details(
ctx: DbtMcpContext,
job_id: int,
) -> dict[str, Any] | str:
"""Get details for a specific job."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
return admin_client.get_job_details(admin_api_config.account_id, job_id)
except Exception as e:
logger.error(f"Error getting job {job_id}: {e}")
return str(e)

def trigger_job_run(
ctx: DbtMcpContext,
job_id: int,
cause: str = "Triggered by dbt MCP",
git_branch: str | None = None,
Expand All @@ -82,6 +100,7 @@ def trigger_job_run(
) -> dict[str, Any] | str:
"""Trigger a job run."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
kwargs = {}
if git_branch:
kwargs["git_branch"] = git_branch
Expand All @@ -97,6 +116,7 @@ def trigger_job_run(
return str(e)

def list_jobs_runs(
ctx: DbtMcpContext,
job_id: int | None = None,
status: JobRunStatus | None = None,
limit: int | None = None,
Expand All @@ -105,6 +125,7 @@ def list_jobs_runs(
) -> list[dict[str, Any]] | str:
"""List runs in an account."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
params: dict[str, Any] = {}
if job_id:
params["job_definition_id"] = job_id
Expand All @@ -125,6 +146,7 @@ def list_jobs_runs(
return str(e)

def get_job_run_details(
ctx: DbtMcpContext,
run_id: int,
debug: bool = Field(
default=False,
Expand All @@ -133,32 +155,45 @@ def get_job_run_details(
) -> dict[str, Any] | str:
"""Get details for a specific job run."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
return admin_client.get_job_run_details(
admin_api_config.account_id, run_id, debug=debug
)
except Exception as e:
logger.error(f"Error getting run {run_id}: {e}")
return str(e)

def cancel_job_run(run_id: int) -> dict[str, Any] | str:
def cancel_job_run(
ctx: DbtMcpContext,
run_id: int,
) -> dict[str, Any] | str:
"""Cancel a job run."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
return admin_client.cancel_job_run(admin_api_config.account_id, run_id)
except Exception as e:
logger.error(f"Error cancelling run {run_id}: {e}")
return str(e)

def retry_job_run(run_id: int) -> dict[str, Any] | str:
def retry_job_run(
ctx: DbtMcpContext,
run_id: int,
) -> dict[str, Any] | str:
"""Retry a failed job run."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
return admin_client.retry_job_run(admin_api_config.account_id, run_id)
except Exception as e:
logger.error(f"Error retrying run {run_id}: {e}")
return str(e)

def list_job_run_artifacts(run_id: int) -> list[str] | str:
def list_job_run_artifacts(
ctx: DbtMcpContext,
run_id: int,
) -> list[str] | str:
"""List artifacts for a job run."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
return admin_client.list_job_run_artifacts(
admin_api_config.account_id, run_id
)
Expand All @@ -167,10 +202,14 @@ def list_job_run_artifacts(run_id: int) -> list[str] | str:
return str(e)

def get_job_run_artifact(
run_id: int, artifact_path: str, step: int | None = None
ctx: DbtMcpContext,
run_id: int,
artifact_path: str,
step: int | None = None,
) -> Any | str:
"""Get a specific job run artifact."""
try:
admin_client, admin_api_config = get_admin_client_and_config(ctx)
return admin_client.get_job_run_artifact(
admin_api_config.account_id, run_id, artifact_path, step
)
Expand Down Expand Up @@ -276,13 +315,11 @@ def get_job_run_artifact(

def register_admin_api_tools(
dbt_mcp: FastMCP,
admin_config: AdminApiConfig,
exclude_tools: Sequence[ToolName] = [],
) -> None:
"""Register dbt Admin API tools."""
admin_client = DbtAdminAPIClient(admin_config)
register_tools(
dbt_mcp,
create_admin_api_tool_definitions(admin_client, admin_config),
create_admin_api_tool_definitions(),
exclude_tools,
)
41 changes: 29 additions & 12 deletions src/dbt_mcp/dbt_cli/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,31 @@

from dbt_mcp.config.config import DbtCliConfig
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.tools.annotations import create_tool_annotations
from dbt_mcp.tools.config import DbtMcpContext
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
from dbt_mcp.tools.annotations import create_tool_annotations


def create_dbt_cli_tool_definitions(config: DbtCliConfig) -> list[ToolDefinition]:
def get_cli_config(ctx: DbtMcpContext) -> DbtCliConfig:
dbt_cli_config = ctx.get_dbt_cli_config()
if dbt_cli_config is None:
raise ValueError("dbt cli config is not set")
return dbt_cli_config


def create_dbt_cli_tool_definitions() -> list[ToolDefinition]:
def _run_dbt_command(
ctx: DbtMcpContext,
command: list[str],
selector: str | None = None,
resource_type: list[str] | None = None,
is_selectable: bool = False,
is_full_refresh: bool | None = False,
vars: str | None = None,
) -> str:
config = ctx.get_dbt_cli_config()
try:
# Commands that should always be quiet to reduce output verbosity
verbose_commands = [
Expand Down Expand Up @@ -79,6 +89,7 @@ def _run_dbt_command(
return str(e)

def build(
ctx: DbtMcpContext,
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
Expand All @@ -90,20 +101,22 @@ def build(
),
) -> str:
return _run_dbt_command(
ctx,
["build"],
selector,
is_selectable=True,
is_full_refresh=is_full_refresh,
vars=vars,
)

def compile() -> str:
return _run_dbt_command(["compile"])
def compile(ctx: DbtMcpContext) -> str:
return _run_dbt_command(ctx, ["compile"])

def docs() -> str:
return _run_dbt_command(["docs", "generate"])
def docs(ctx: DbtMcpContext) -> str:
return _run_dbt_command(ctx, ["docs", "generate"])

def ls(
ctx: DbtMcpContext,
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
Expand All @@ -113,16 +126,18 @@ def ls(
),
) -> str:
return _run_dbt_command(
ctx,
["list"],
selector,
resource_type=resource_type,
is_selectable=True,
)

def parse() -> str:
return _run_dbt_command(["parse"])
def parse(ctx: DbtMcpContext) -> str:
return _run_dbt_command(ctx, ["parse"])

def run(
ctx: DbtMcpContext,
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
Expand All @@ -134,6 +149,7 @@ def run(
),
) -> str:
return _run_dbt_command(
ctx,
["run"],
selector,
is_selectable=True,
Expand All @@ -142,16 +158,18 @@ def run(
)

def test(
ctx: DbtMcpContext,
selector: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/selectors")
),
vars: str | None = Field(
default=None, description=get_prompt("dbt_cli/args/vars")
),
) -> str:
return _run_dbt_command(["test"], selector, is_selectable=True, vars=vars)
return _run_dbt_command(ctx, ["test"], selector, is_selectable=True, vars=vars)

def show(
ctx: DbtMcpContext,
sql_query: str = Field(description=get_prompt("dbt_cli/args/sql_query")),
limit: int = Field(default=5, description=get_prompt("dbt_cli/args/limit")),
) -> str:
Expand All @@ -171,7 +189,7 @@ def show(
if cli_limit is not None:
args.extend(["--limit", str(cli_limit)])
args.extend(["--output", "json"])
return _run_dbt_command(args)
return _run_dbt_command(ctx, args)

return [
ToolDefinition(
Expand Down Expand Up @@ -260,11 +278,10 @@ def show(

def register_dbt_cli_tools(
dbt_mcp: FastMCP,
config: DbtCliConfig,
exclude_tools: Sequence[ToolName] = [],
) -> None:
register_tools(
dbt_mcp,
create_dbt_cli_tool_definitions(config),
create_dbt_cli_tool_definitions(),
exclude_tools,
)
Loading
Loading