Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions metaflow/plugins/aws/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def create_job(
ephemeral_storage=None,
log_driver=None,
log_options=None,
container_secrets=None,
offload_command_to_s3=False,
):
job_name = self._job_name(
Expand Down Expand Up @@ -303,6 +304,7 @@ def create_job(
ephemeral_storage=ephemeral_storage,
log_driver=log_driver,
log_options=log_options,
container_secrets=container_secrets,
)
.task_id(attrs.get("metaflow.task_id"))
.environment_variable("AWS_DEFAULT_REGION", self._client.region())
Expand Down Expand Up @@ -427,6 +429,7 @@ def launch_job(
ephemeral_storage=None,
log_driver=None,
log_options=None,
container_secrets=None,
):
if queue is None:
queue = next(self._client.active_job_queues(), None)
Expand Down Expand Up @@ -469,6 +472,7 @@ def launch_job(
ephemeral_storage=ephemeral_storage,
log_driver=log_driver,
log_options=log_options,
container_secrets=container_secrets,
)
self.num_parallel = num_parallel
self.job = job.execute()
Expand Down
20 changes: 20 additions & 0 deletions metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,25 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
if env_deco:
env.update(env_deco[0].attributes["vars"])

# Collect ECS-style container secrets from @secrets decorator, if any.
# Each entry has the shape {"name": <ENV>, "value_from": <Secrets Manager ARN>} (snake_case keys)
# and is injected at container startup by AWS Batch/ECS via the job definition.
container_secrets = []
secrets_deco = [deco for deco in node.decorators if deco.name == "secrets"]
if secrets_deco:
try:
for s in secrets_deco[0].attributes.get("sources", []) or []:
if isinstance(s, dict):
name = s.get("name")
value_from = s.get("value_from")
if isinstance(name, str) and isinstance(value_from, str):
container_secrets.append(
{"name": name, "value_from": value_from}
)
except Exception:
# best-effort only; ignore malformed entries silently to avoid breaking launches
pass

# Add the environment variables related to the input-paths argument
if split_vars:
env.update(split_vars)
Expand Down Expand Up @@ -366,6 +385,7 @@ def _sync_metadata():
log_driver=log_driver,
log_options=log_options,
num_parallel=num_parallel,
container_secrets=container_secrets,
)
except Exception:
traceback.print_exc()
Expand Down
16 changes: 16 additions & 0 deletions metaflow/plugins/aws/batch/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _register_job_definition(
ephemeral_storage,
log_driver,
log_options,
container_secrets=None,
):
# identify platform from any compute environment associated with the
# queue
Expand Down Expand Up @@ -199,6 +200,19 @@ def _register_job_definition(
"propagateTags": True,
}

# Inject ECS secrets for container-start environment variables, if provided.
if container_secrets:
norm = []
for item in container_secrets:
if not isinstance(item, dict):
continue
name = item.get("name")
value_from = item.get("value_from")
if isinstance(name, str) and isinstance(value_from, str):
norm.append({"name": name, "valueFrom": value_from})
if norm:
job_definition["containerProperties"]["secrets"] = norm

log_options_dict = {}
if log_options:
if isinstance(log_options, str):
Expand Down Expand Up @@ -480,6 +494,7 @@ def job_def(
ephemeral_storage,
log_driver,
log_options,
container_secrets=None,
):
self.payload["jobDefinition"] = self._register_job_definition(
image,
Expand All @@ -502,6 +517,7 @@ def job_def(
ephemeral_storage,
log_driver,
log_options,
container_secrets,
)
return self

Expand Down
6 changes: 6 additions & 0 deletions metaflow/plugins/secrets/secrets_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def task_pre_step(
SecretSpec.secret_spec_from_str(secret_spec_str_or_dict, role=role)
)
elif isinstance(secret_spec_str_or_dict, dict):
# If the dict is an ECS-style container-start secret spec, skip runtime fetching.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

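This would break workloads running on EKS: the check skips runtime fetching unconditionally, while these entries are only wired into the AWS Batch job definition.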

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @savingoyal, just checking if this works?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bump!

# These entries will be wired into the AWS Batch job definition as ECS secrets.
if "name" in secret_spec_str_or_dict and (
"value_from" in secret_spec_str_or_dict
):
continue
secret_specs.append(
SecretSpec.secret_spec_from_dict(secret_spec_str_or_dict, role=role)
)
Expand Down