Skip to content

Commit db5a1a1

Browse files
committed
test for _check_arg
relative
1 parent 1aab672 commit db5a1a1

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ addopts =
1212
[coverage:run]
1313
parallel = True
1414
concurrency = thread
15+
relative_files = True
1516

1617
[coverage:report]
1718
exclude_lines =

tests/bases/test_collections.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,11 @@ def test_metric_collection_same_order():
182182
col2 = MetricCollection({"b": m2, "a": m1})
183183
for k1, k2 in zip(col1.keys(), col2.keys()):
184184
assert k1 == k2
185+
186+
187+
def test_collection_check_arg():
188+
assert MetricCollection._check_arg(None, 'prefix') is None
189+
assert MetricCollection._check_arg('sample', 'prefix') == 'sample'
190+
191+
with pytest.raises(ValueError, match="Expected input `postfix` to be a string, but got"):
192+
MetricCollection._check_arg(1, 'postfix')

torchmetrics/collections.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,7 @@ def _set_name(self, base: str) -> str:
191191
return name
192192

193193
@staticmethod
194-
def _check_arg(arg: str, name: str) -> Optional[str]:
195-
if arg is not None:
196-
if isinstance(arg, str):
197-
return arg
198-
raise ValueError(f'Expected input {name} to be a string')
199-
return None
194+
def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
195+
if arg is None or isinstance(arg, str):
196+
return arg
197+
raise ValueError(f'Expected input `{name}` to be a string, but got {type(arg)}')

0 commit comments

Comments
 (0)