Skip to content
This repository was archived by the owner on Sep 16, 2025. It is now read-only.

Commit d72aeb3

Browse files
authored
Support for existing_cluster_id in DatabricksNotebookOperator (#73)
When tasks are launched with `DatabricksNotebookOperators` from within a TaskGroup using the `DatabricksWorkflowTaskGroup`, currently we do not support using `existing_cluster_id` for those Notebook tasks. The PR addresses this issue by allowing to support `existing_cluster_id` in such cases and additionally also keeps supporting the current `job_cluster_key` approach allowing users to use a combination of both for a workflow. closes: #70
1 parent f7d4e6a commit d72aeb3

2 files changed

Lines changed: 69 additions & 1 deletion

File tree

src/astro_databricks/operators/notebook.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,17 @@ def convert_to_databricks_workflow_task(
207207
for t in self.upstream_task_ids
208208
if t in relevant_upstreams
209209
],
210-
"job_cluster_key": self.job_cluster_key,
211210
**base_task_json,
212211
}
212+
213+
if self.existing_cluster_id and self.job_cluster_key:
214+
raise ValueError ("Both existing_cluster_id and job_cluster_key are set. Only one cluster can be set per task.")
215+
216+
if self.existing_cluster_id:
217+
result['existing_cluster_id'] = self.existing_cluster_id
218+
elif self.job_cluster_key:
219+
result['job_cluster_key'] = self.job_cluster_key
220+
213221
return result
214222

215223
def _get_databricks_task_id(self, task_id: str):

tests/databricks/test_workflow.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest import mock
55

66
import pytest
7+
import copy
78
from airflow.exceptions import AirflowException
89
from airflow.utils.task_group import TaskGroup
910
from astro_databricks.operators.notebook import DatabricksNotebookOperator
@@ -50,6 +51,10 @@
5051
"timeout_seconds": 0,
5152
}
5253

54+
expected_workflow_json_existing_cluster_id = copy.deepcopy(expected_workflow_json)
55+
# remove job_cluster_key and add existing_cluster_id
56+
expected_workflow_json_existing_cluster_id['tasks'][1].pop('job_cluster_key')
57+
expected_workflow_json_existing_cluster_id['tasks'][1]['existing_cluster_id'] = 'foo'
5358

5459
@mock.patch("astro_databricks.operators.workflow.DatabricksHook")
5560
@mock.patch("astro_databricks.operators.workflow.ApiClient")
@@ -374,3 +379,58 @@ def test_create_workflow_with_nested_task_groups(
374379
== "unit_test_dag__test_workflow__middle_task_group__inner_task_group__inner_notebook"
375380
)
376381
assert outer_notebook_json["libraries"] == [{"pypi": {"package": "mlflow==2.4.0"}}]
382+
383+
@mock.patch("astro_databricks.operators.workflow.DatabricksHook")
384+
@mock.patch("astro_databricks.operators.workflow.ApiClient")
385+
@mock.patch("astro_databricks.operators.workflow.JobsApi")
386+
@mock.patch(
387+
"astro_databricks.operators.workflow.RunsApi.get_run",
388+
return_value={"state": {"life_cycle_state": "RUNNING"}},
389+
)
390+
def test_create_workflow_from_notebooks_with_different_clusters(
391+
mock_run_api, mock_jobs_api, mock_api, mock_hook, dag
392+
):
393+
mock_jobs_api.return_value.create_job.return_value = {"job_id": 1}
394+
with dag:
395+
task_group = DatabricksWorkflowTaskGroup(
396+
group_id="test_workflow",
397+
databricks_conn_id="foo",
398+
job_clusters=[{"job_cluster_key": "foo"}],
399+
notebook_params={"notebook_path": "/foo/bar"},
400+
notebook_packages=[{"tg_index": {"package": "tg_package"}}],
401+
)
402+
with task_group:
403+
notebook_1 = DatabricksNotebookOperator(
404+
task_id="notebook_1",
405+
databricks_conn_id="foo",
406+
notebook_path="/foo/bar",
407+
notebook_packages=[{"nb_index": {"package": "nb_package"}}],
408+
source="WORKSPACE",
409+
job_cluster_key="foo",
410+
)
411+
notebook_2 = DatabricksNotebookOperator(
412+
task_id="notebook_2",
413+
databricks_conn_id="foo",
414+
notebook_path="/foo/bar",
415+
source="WORKSPACE",
416+
existing_cluster_id="foo",
417+
notebook_params={
418+
"foo": "bar",
419+
},
420+
)
421+
notebook_1 >> notebook_2
422+
423+
assert len(task_group.children) == 3
424+
task_group.children["test_workflow.launch"].execute(context={})
425+
mock_jobs_api.return_value.create_job.assert_called_once_with(
426+
json=expected_workflow_json_existing_cluster_id,
427+
version=DATABRICKS_JOBS_API_VERSION,
428+
)
429+
mock_jobs_api.return_value.run_now.assert_called_once_with(
430+
job_id=1,
431+
jar_params=[],
432+
notebook_params={"notebook_path": "/foo/bar"},
433+
python_params=[],
434+
spark_submit_params=[],
435+
version=DATABRICKS_JOBS_API_VERSION,
436+
)

0 commit comments

Comments
 (0)