Skip to content

Commit 1aab672

Browse files
SkafteNickiBorda
andauthored
Add postfix arg to MetricCollection (#188)
* postfix * chglog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent c1ae5ac commit 1aab672

File tree

3 files changed

+50
-18
lines changed

3 files changed

+50
-18
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3434
- Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142))
3535
- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154))
3636
- Added support for `average`, `ignore_index` and `mdmc_average` in `Accuracy` metric ([#166](https://github.com/PyTorchLightning/metrics/pull/166))
37+
- Added `postfix` arg to `MetricCollection` ([#188](https://github.com/PyTorchLightning/metrics/pull/188))
3738

3839
### Changed
3940

tests/bases/test_collections.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,30 +133,47 @@ def test_metric_collection_args_kwargs(tmpdir):
133133
assert metric_collection['DummyMetricDiff'].x == -20
134134

135135

136-
def test_metric_collection_prefix_arg(tmpdir):
136+
@pytest.mark.parametrize(
137+
"prefix, postfix", [
138+
[None, None],
139+
['prefix_', None],
140+
[None, '_postfix'],
141+
['prefix_', '_postfix'],
142+
]
143+
)
144+
def test_metric_collection_prefix_postfix_args(prefix, postfix):
137145
""" Test that the prefix arg alters the keywords in the output"""
138146
m1 = DummyMetricSum()
139147
m2 = DummyMetricDiff()
140148
names = ['DummyMetricSum', 'DummyMetricDiff']
149+
names = [prefix + n if prefix is not None else n for n in names]
150+
names = [n + postfix if postfix is not None else n for n in names]
141151

142-
metric_collection = MetricCollection([m1, m2], prefix='prefix_')
152+
metric_collection = MetricCollection([m1, m2], prefix=prefix, postfix=postfix)
143153

144154
# test forward
145155
out = metric_collection(5)
146156
for name in names:
147-
assert f"prefix_{name}" in out, 'prefix argument not working as intended with forward method'
157+
assert name in out, 'prefix or postfix argument not working as intended with forward method'
148158

149159
# test compute
150160
out = metric_collection.compute()
151161
for name in names:
152-
assert f"prefix_{name}" in out, 'prefix argument not working as intended with compute method'
162+
assert name in out, 'prefix or postfix argument not working as intended with compute method'
153163

154164
# test clone
155165
new_metric_collection = metric_collection.clone(prefix='new_prefix_')
156166
out = new_metric_collection(5)
167+
names = [n[len(prefix):] if prefix is not None else n for n in names] # strip away old prefix
157168
for name in names:
158169
assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'
159170

171+
new_metric_collection = new_metric_collection.clone(postfix='_new_postfix')
172+
out = new_metric_collection(5)
173+
names = [n[:-len(postfix)] if postfix is not None else n for n in names] # strip away old postfix
174+
for name in names:
175+
assert f"new_prefix_{name}_new_postfix" in out, 'postfix argument not working as intended with clone method'
176+
160177

161178
def test_metric_collection_same_order():
162179
m1 = DummyMetricSum()

torchmetrics/collections.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class name as key for the output dict.
4040
4141
prefix: a string to append in front of the keys of the output dict
4242
43+
postfix: a string to append after the keys of the output dict
44+
4345
Raises:
4446
ValueError:
4547
If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
@@ -48,7 +50,11 @@ class name as key for the output dict.
4850
ValueError:
4951
If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
5052
ValueError:
51-
If ``metrics`` is is ``dict`` and passed any additional_metrics.
53+
If ``metrics`` is ``dict`` and additional_metrics are passed in.
54+
ValueError:
55+
If ``prefix`` is set and it is not a string.
56+
ValueError:
57+
If ``postfix`` is set and it is not a string.
5258
5359
Example (input as list):
5460
>>> import torch
@@ -85,6 +91,7 @@ def __init__(
8591
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
8692
*additional_metrics: Metric,
8793
prefix: Optional[str] = None,
94+
postfix: Optional[str] = None
8895
):
8996
super().__init__()
9097
if isinstance(metrics, Metric):
@@ -128,15 +135,16 @@ def __init__(
128135
else:
129136
raise ValueError("Unknown input to MetricCollection.")
130137

131-
self.prefix = self._check_prefix_arg(prefix)
138+
self.prefix = self._check_arg(prefix, 'prefix')
139+
self.postfix = self._check_arg(postfix, 'postfix')
132140

133141
def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
134142
"""
135143
Iteratively call forward for each metric. Positional arguments (args) will
136144
be passed to every metric in the collection, while keyword arguments (kwargs)
137145
will be filtered based on the signature of the individual metric.
138146
"""
139-
return {self._set_prefix(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
147+
return {self._set_name(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
140148

141149
def update(self, *args, **kwargs): # pylint: disable=E0202
142150
"""
@@ -149,20 +157,25 @@ def update(self, *args, **kwargs): # pylint: disable=E0202
149157
m.update(*args, **m_kwargs)
150158

151159
def compute(self) -> Dict[str, Any]:
152-
return {self._set_prefix(k): m.compute() for k, m in self.items()}
160+
return {self._set_name(k): m.compute() for k, m in self.items()}
153161

154162
def reset(self) -> None:
155163
""" Iteratively call reset for each metric """
156164
for _, m in self.items():
157165
m.reset()
158166

159-
def clone(self, prefix: Optional[str] = None) -> 'MetricCollection':
167+
def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> 'MetricCollection':
160168
""" Make a copy of the metric collection
161169
Args:
162170
prefix: a string to append in front of the metric keys
171+
postfix: a string to append after the keys of the output dict
172+
163173
"""
164174
mc = deepcopy(self)
165-
mc.prefix = self._check_prefix_arg(prefix)
175+
if prefix:
176+
mc.prefix = self._check_arg(prefix, 'prefix')
177+
if postfix:
178+
mc.postfix = self._check_arg(postfix, 'postfix')
166179
return mc
167180

168181
def persistent(self, mode: bool = True) -> None:
@@ -172,14 +185,15 @@ def persistent(self, mode: bool = True) -> None:
172185
for _, m in self.items():
173186
m.persistent(mode)
174187

175-
def _set_prefix(self, k: str) -> str:
176-
return k if self.prefix is None else self.prefix + k
188+
def _set_name(self, base: str) -> str:
189+
name = base if self.prefix is None else self.prefix + base
190+
name = name if self.postfix is None else name + self.postfix
191+
return name
177192

178193
@staticmethod
179-
def _check_prefix_arg(prefix: str) -> Optional[str]:
180-
if prefix is not None:
181-
if isinstance(prefix, str):
182-
return prefix
183-
else:
184-
raise ValueError('Expected input `prefix` to be a string')
194+
def _check_arg(arg: str, name: str) -> Optional[str]:
195+
if arg is not None:
196+
if isinstance(arg, str):
197+
return arg
198+
raise ValueError(f'Expected input {name} to be a string')
185199
return None

0 commit comments

Comments
 (0)