|
4 | 4 | from unittest import mock |
5 | 5 |
|
6 | 6 | import pytest |
| 7 | +import copy |
7 | 8 | from airflow.exceptions import AirflowException |
8 | 9 | from airflow.utils.task_group import TaskGroup |
9 | 10 | from astro_databricks.operators.notebook import DatabricksNotebookOperator |
|
50 | 51 | "timeout_seconds": 0, |
51 | 52 | } |
52 | 53 |
|
| 54 | +expected_workflow_json_existing_cluster_id = copy.deepcopy(expected_workflow_json) |
| 55 | +# remove job_cluster_key and add existing_cluster_id |
| 56 | +expected_workflow_json_existing_cluster_id['tasks'][1].pop('job_cluster_key') |
| 57 | +expected_workflow_json_existing_cluster_id['tasks'][1]['existing_cluster_id'] = 'foo' |
53 | 58 |
|
54 | 59 | @mock.patch("astro_databricks.operators.workflow.DatabricksHook") |
55 | 60 | @mock.patch("astro_databricks.operators.workflow.ApiClient") |
@@ -374,3 +379,58 @@ def test_create_workflow_with_nested_task_groups( |
374 | 379 | == "unit_test_dag__test_workflow__middle_task_group__inner_task_group__inner_notebook" |
375 | 380 | ) |
376 | 381 | assert outer_notebook_json["libraries"] == [{"pypi": {"package": "mlflow==2.4.0"}}] |
| 382 | + |
| 383 | +@mock.patch("astro_databricks.operators.workflow.DatabricksHook") |
| 384 | +@mock.patch("astro_databricks.operators.workflow.ApiClient") |
| 385 | +@mock.patch("astro_databricks.operators.workflow.JobsApi") |
| 386 | +@mock.patch( |
| 387 | + "astro_databricks.operators.workflow.RunsApi.get_run", |
| 388 | + return_value={"state": {"life_cycle_state": "RUNNING"}}, |
| 389 | +) |
| 390 | +def test_create_workflow_from_notebooks_with_different_clusters( |
| 391 | + mock_run_api, mock_jobs_api, mock_api, mock_hook, dag |
| 392 | +): |
| 393 | + mock_jobs_api.return_value.create_job.return_value = {"job_id": 1} |
| 394 | + with dag: |
| 395 | + task_group = DatabricksWorkflowTaskGroup( |
| 396 | + group_id="test_workflow", |
| 397 | + databricks_conn_id="foo", |
| 398 | + job_clusters=[{"job_cluster_key": "foo"}], |
| 399 | + notebook_params={"notebook_path": "/foo/bar"}, |
| 400 | + notebook_packages=[{"tg_index": {"package": "tg_package"}}], |
| 401 | + ) |
| 402 | + with task_group: |
| 403 | + notebook_1 = DatabricksNotebookOperator( |
| 404 | + task_id="notebook_1", |
| 405 | + databricks_conn_id="foo", |
| 406 | + notebook_path="/foo/bar", |
| 407 | + notebook_packages=[{"nb_index": {"package": "nb_package"}}], |
| 408 | + source="WORKSPACE", |
| 409 | + job_cluster_key="foo", |
| 410 | + ) |
| 411 | + notebook_2 = DatabricksNotebookOperator( |
| 412 | + task_id="notebook_2", |
| 413 | + databricks_conn_id="foo", |
| 414 | + notebook_path="/foo/bar", |
| 415 | + source="WORKSPACE", |
| 416 | + existing_cluster_id="foo", |
| 417 | + notebook_params={ |
| 418 | + "foo": "bar", |
| 419 | + }, |
| 420 | + ) |
| 421 | + notebook_1 >> notebook_2 |
| 422 | + |
| 423 | + assert len(task_group.children) == 3 |
| 424 | + task_group.children["test_workflow.launch"].execute(context={}) |
| 425 | + mock_jobs_api.return_value.create_job.assert_called_once_with( |
| 426 | + json=expected_workflow_json_existing_cluster_id, |
| 427 | + version=DATABRICKS_JOBS_API_VERSION, |
| 428 | + ) |
| 429 | + mock_jobs_api.return_value.run_now.assert_called_once_with( |
| 430 | + job_id=1, |
| 431 | + jar_params=[], |
| 432 | + notebook_params={"notebook_path": "/foo/bar"}, |
| 433 | + python_params=[], |
| 434 | + spark_submit_params=[], |
| 435 | + version=DATABRICKS_JOBS_API_VERSION, |
| 436 | + ) |
0 commit comments