diff --git a/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.ipynb b/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.ipynb index f023b748..9d262e79 100644 --- a/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.ipynb +++ b/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.ipynb @@ -1084,12 +1084,14 @@ "# all available models are included by default)\n", "# - \"fold\" defines the number of folds to use for cross-validation.\n", "\n", - "# Note: This is only relevant if we are executing automated tests\n", + "# On CI/testing, only evaluate a single cheap model.\n", + "# # Alternatives: arima, ets, et_cds_dt, exp_smooth, naive.\n", "if \"PYTEST_CURRENT_TEST\" in os.environ:\n", " best_models = compare_models(sort=\"MASE\",\n", - " include=[\"ets\", \"et_cds_dt\", \"naive\"],\n", - " n_select=3)\n", - "# If we are not in an automated test, compare all available models\n", + " include=[\"ets\"],\n", + " n_select=1)\n", + "\n", + "# When not on CI/testing, compare all available models.\n", "else:\n", " best_models = compare_models(sort=\"MASE\", n_select=3)" ] diff --git a/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.py b/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.py index 06aa7b33..4d15f3ed 100644 --- a/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.py +++ b/topic/machine-learning/pycaret/automl_timeseries_forecasting_with_pycaret.py @@ -82,10 +82,15 @@ def fetch_data(): def run_experiment(data): setup(data = data, fh=15, target="total_sales", index="month", log_experiment=True) + + # On CI/testing, only evaluate a single cheap model. + # Alternatives: arima, ets, et_cds_dt, exp_smooth, naive. if "PYTEST_CURRENT_TEST" in os.environ: best_models = compare_models(sort="MASE", - include=["arima", "ets", "exp_smooth"], - n_select=3) + include=["ets"], + n_select=1) + + # When not on CI/testing, compare all available models. else: best_models = compare_models(sort="MASE", n_select=3)