Skip to content

Commit 1d26b39

Browse files
authored
Add deprecated_aliases to runopt and add warning (#1142) (#1142)
Summary: Similar to `runopt.alias` lets introduce and use `runopt.deprecated`. This will warn the user with a`UserWarning` when the user uses that specific name and suggests the primary one instead. bypass-github-export-checks Reviewed By: kiukchung Differential Revision: D84180061
1 parent 79e14da commit 1d26b39

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

torchx/specs/api.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pathlib
1616
import re
1717
import typing
18+
import warnings
1819
from dataclasses import asdict, dataclass, field
1920
from datetime import datetime
2021
from enum import Enum
@@ -894,11 +895,15 @@ class runopt:
894895
class alias(str):
895896
pass
896897

898+
class deprecated(str):
899+
pass
900+
897901
default: CfgVal
898902
opt_type: Type[CfgVal]
899903
is_required: bool
900904
help: str
901905
aliases: list[alias] | None = None
906+
deprecated_aliases: list[deprecated] | None = None
902907

903908
@property
904909
def is_type_list_of_str(self) -> bool:
@@ -990,7 +995,7 @@ class runopts:
990995

991996
def __init__(self) -> None:
992997
self._opts: Dict[str, runopt] = {}
993-
self._alias_to_key: dict[runopt.alias, str] = {}
998+
self._alias_to_key: dict[str, str] = {}
994999

9951000
def __iter__(self) -> Iterator[Tuple[str, runopt]]:
9961001
return self._opts.items().__iter__()
@@ -1044,12 +1049,24 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
10441049
val = resolved_cfg.get(cfg_key)
10451050
resolved_name = None
10461051
aliases = runopt.aliases or []
1052+
deprecated_aliases = runopt.deprecated_aliases or []
10471053
if val is None:
10481054
for alias in aliases:
10491055
val = resolved_cfg.get(alias)
10501056
if alias in cfg or val is not None:
10511057
resolved_name = alias
10521058
break
1059+
for alias in deprecated_aliases:
1060+
val = resolved_cfg.get(alias)
1061+
if val is not None:
1062+
resolved_name = alias
1063+
use_instead = self._alias_to_key.get(alias)
1064+
warnings.warn(
1065+
f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
1066+
UserWarning,
1067+
stacklevel=2,
1068+
)
1069+
break
10531070
else:
10541071
resolved_name = cfg_key
10551072
for alias in aliases:
@@ -1175,20 +1192,32 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
11751192
def _get_primary_key_and_aliases(
11761193
self,
11771194
cfg_key: list[str] | str,
1178-
) -> tuple[str, list[runopt.alias]]:
1195+
) -> tuple[str, list[runopt.alias], list[runopt.deprecated]]:
11791196
"""
11801197
Returns the primary key and aliases for the given cfg_key.
11811198
"""
11821199
if isinstance(cfg_key, str):
1183-
return cfg_key, []
1200+
return cfg_key, [], []
11841201

11851202
if len(cfg_key) == 0:
11861203
raise ValueError("cfg_key must be a non-empty list")
1204+
1205+
if isinstance(cfg_key[0], runopt.alias) or isinstance(
1206+
cfg_key[0], runopt.deprecated
1207+
):
1208+
warnings.warn(
1209+
"The main name of the run option should be the head of the list.",
1210+
UserWarning,
1211+
stacklevel=2,
1212+
)
11871213
primary_key = None
11881214
aliases = list[runopt.alias]()
1215+
deprecated_aliases = list[runopt.deprecated]()
11891216
for name in cfg_key:
11901217
if isinstance(name, runopt.alias):
11911218
aliases.append(name)
1219+
elif isinstance(name, runopt.deprecated):
1220+
deprecated_aliases.append(name)
11921221
else:
11931222
if primary_key is not None:
11941223
raise ValueError(
@@ -1199,7 +1228,7 @@ def _get_primary_key_and_aliases(
11991228
raise ValueError(
12001229
"Missing cfg_key. Please provide one other than the aliases."
12011230
)
1202-
return primary_key, aliases
1231+
return primary_key, aliases, deprecated_aliases
12031232

12041233
def add(
12051234
self,
@@ -1214,7 +1243,9 @@ def add(
12141243
value (if any). If the ``default`` is not specified then this option
12151244
is a required option.
12161245
"""
1217-
primary_key, aliases = self._get_primary_key_and_aliases(cfg_key)
1246+
primary_key, aliases, deprecated_aliases = self._get_primary_key_and_aliases(
1247+
cfg_key
1248+
)
12181249
if required and default is not None:
12191250
raise ValueError(
12201251
f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -1225,9 +1256,11 @@ def add(
12251256
f"Option: {cfg_key}, must be of type: {type_}."
12261257
f" Given: {default} ({type(default).__name__})"
12271258
)
1228-
opt = runopt(default, type_, required, help, aliases)
1259+
opt = runopt(default, type_, required, help, aliases, deprecated_aliases)
12291260
for alias in aliases:
12301261
self._alias_to_key[alias] = primary_key
1262+
for deprecated_alias in deprecated_aliases:
1263+
self._alias_to_key[deprecated_alias] = primary_key
12311264
self._opts[primary_key] = opt
12321265

12331266
def update(self, other: "runopts") -> None:

torchx/specs/test/api_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import tempfile
1414
import time
1515
import unittest
16+
import warnings
1617
from dataclasses import asdict
1718
from pathlib import Path
1819
from typing import Dict, List, Mapping, Tuple, Union
@@ -621,6 +622,32 @@ def test_runopts_resolve_with_none_valued_aliases(self) -> None:
621622
with self.assertRaises(InvalidRunConfigException):
622623
opts.resolve({"model_type_name": None, "modelTypeName": "low"})
623624

625+
def test_runopts_add_with_deprecated_aliases(self) -> None:
626+
opts = runopts()
627+
with warnings.catch_warnings(record=True) as w:
628+
opts.add(
629+
[runopt.deprecated("jobPriority"), "job_priority"],
630+
type_=str,
631+
help="run as user",
632+
)
633+
self.assertEqual(len(w), 1)
634+
self.assertEqual(w[0].category, UserWarning)
635+
self.assertEqual(
636+
str(w[0].message),
637+
"The main name of the run option should be the head of the list.",
638+
)
639+
640+
opts.resolve({"job_priority": "high"})
641+
with warnings.catch_warnings(record=True) as w:
642+
warnings.simplefilter("always")
643+
opts.resolve({"jobPriority": "high"})
644+
self.assertEqual(len(w), 1)
645+
self.assertEqual(w[0].category, UserWarning)
646+
self.assertEqual(
647+
str(w[0].message),
648+
"Run option `jobPriority` is deprecated, use `job_priority` instead",
649+
)
650+
624651
def get_runopts(self) -> runopts:
625652
opts = runopts()
626653
opts.add("run_as", type_=str, help="run as user", required=True)

0 commit comments

Comments
 (0)