Faster forward
Highlights
TorchMetrics v0.9 is now out, and it brings significant changes to how the forward method works. This blog post goes over these improvements and how they affect both users of TorchMetrics and those who implement custom metrics. TorchMetrics v0.9 also includes several new metrics and bug fixes.
Blog: TorchMetrics v0.9 — Faster forward
The Story of the Forward Method
Since the beginning of TorchMetrics, forward has served a dual purpose: calculating the metric on the current batch and accumulating it into a global state. Internally, this was achieved by calling update twice, once for each purpose, which meant repeating the same computation. However, for many metrics the second call is unnecessary: the global state is a simple reduction (for example, a sum) of the per-batch states, so it can be obtained directly from the batch state without running update again.
In v0.9, we have finally implemented logic that takes advantage of this: forward now calls update only once and then applies a simple reduction to merge the batch state into the global state. As the figure below shows, this can make a single call of forward up to 2x faster in v0.9 than in v0.8 for the same metric.
With the improvements to forward, many metrics have become significantly faster (up to 2x)
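The difference between the two strategies can be sketched in plain Python. The ToyMean class and both forward functions below are hypothetical stand-ins written for illustration, not TorchMetrics' actual implementation; real metric states are tensors, and the reduction is assumed here to be a sum, which is what makes the single-update path valid for a mean.

```python
from dataclasses import dataclass


@dataclass
class ToyMean:
    """Toy mean metric with two sum-reducible states (hypothetical, not TorchMetrics code)."""
    total: float = 0.0
    count: int = 0
    update_calls: int = 0  # counts how often update() runs

    def update(self, batch):
        self.update_calls += 1
        self.total += sum(batch)
        self.count += len(batch)

    def compute(self):
        return self.total / self.count


def forward_update_twice(metric, batch):
    """Old-style forward: accumulate globally, then redo update on a fresh state."""
    metric.update(batch)  # 1st call: global accumulation
    saved = (metric.total, metric.count)
    metric.total, metric.count = 0.0, 0
    metric.update(batch)  # 2nd call: batch-only state
    batch_value = metric.compute()
    metric.total, metric.count = saved  # restore the global state
    return batch_value


def forward_update_once(metric, batch):
    """Reduce-style forward: update a fresh state once, then sum it into the global state."""
    saved = (metric.total, metric.count)
    metric.total, metric.count = 0.0, 0
    metric.update(batch)  # the only call
    batch_value = metric.compute()
    metric.total += saved[0]  # simple reduction: element-wise sum of states
    metric.count += saved[1]
    return batch_value


old, new = ToyMean(), ToyMean()
for b in [[1.0, 2.0, 3.0], [4.0, 5.0]]:
    assert forward_update_twice(old, b) == forward_update_once(new, b)

print(old.compute(), new.compute())  # 3.0 3.0
print(old.update_calls, new.update_calls)  # 4 2
```

Both strategies return the same batch value and end with the same global state; the reduce-style path simply runs update half as often.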
This change mainly benefits metrics where calling update is expensive, for example ConfusionMatrix.
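A confusion matrix fits the single-update path because its state is a count matrix: the global matrix is the element-wise sum of the per-batch matrices, so the costly per-sample counting only has to run once per batch. The minimal pure-Python sketch below assumes integer class labels and uses hypothetical helper names (empty_confmat, update_confmat, merge); TorchMetrics itself stores this state as a tensor.

```python
def empty_confmat(num_classes):
    """Zero-initialised num_classes x num_classes count matrix."""
    return [[0] * num_classes for _ in range(num_classes)]


def update_confmat(confmat, preds, target):
    """Add one batch of (pred, target) class indices into the matrix."""
    for p, t in zip(preds, target):
        confmat[t][p] += 1  # in this sketch: rows = true class, columns = predicted class
    return confmat


def merge(a, b):
    """Element-wise sum: the 'simple reduction' for a confusion-matrix state."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


num_classes = 3
batches = [([0, 1, 2, 2], [0, 1, 1, 2]), ([2, 0, 1], [2, 0, 0])]

global_state = empty_confmat(num_classes)
for preds, target in batches:
    # Single update on a fresh per-batch state, then merge into the global state
    batch_state = update_confmat(empty_confmat(num_classes), preds, target)
    global_state = merge(global_state, batch_state)

# Same result as one update over all samples at once
all_preds = [p for preds, _ in batches for p in preds]
all_target = [t for _, target in batches for t in target]
assert global_state == update_confmat(empty_confmat(num_classes), all_preds, all_target)
print(global_state)  # [[2, 1, 0], [0, 1, 1], [0, 0, 2]]
```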
We went through all existing metrics in TorchMetrics and enabled this feature for all appropriate metrics, which was almost 95% of them. We want to stress that if you are using metrics from TorchMetrics, the API is unchanged and no code changes are necessary.
[0.9.0] - 2022-05-31
Added
- Added RetrievalPrecisionRecallCurve and RetrievalRecallAtFixedPrecision to the retrieval package (#951)
- Added class property full_state_update that determines whether forward should call update once or twice (#984, #1033)
- Added support for nested metric collections (#1003)
- Added Dice to the classification package (#1021)
- Added support for segmentation type segm as IOU for mean average precision (#822)
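The full_state_update property matters when a metric's update reads its own accumulated state. The toy exponential moving average below (hypothetical, not a TorchMetrics metric) shows why such a metric cannot use the reduce path: summing per-batch states gives a different value than updating the accumulated state. In TorchMetrics the property is declared on the metric class; our understanding is that True keeps the update-twice behavior for cases like this, while False enables the faster single update.

```python
class ToyEMA:
    """Toy metric whose update depends on its previous state (hypothetical, not TorchMetrics code)."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha
        self.value = 0.0

    def update(self, batch):
        batch_mean = sum(batch) / len(batch)
        # The new state depends on the old state, so it is not a plain sum of per-batch states
        self.value = self.alpha * self.value + (1 - self.alpha) * batch_mean


batches = [[2.0, 2.0], [4.0, 4.0]]

# Full-state path: update always runs on the accumulated state
full = ToyEMA()
for b in batches:
    full.update(b)

# Reduce path: update a fresh state per batch, then sum into the global value
reduced_value = 0.0
for b in batches:
    fresh = ToyEMA()
    fresh.update(b)
    reduced_value += fresh.value

print(full.value, reduced_value)  # 2.5 3.0 (the reduce path gives the wrong answer here)
```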
Changed
- Renamed reduction argument to average in Jaccard score and added additional options (#874)
Removed
- Removed deprecated compute_on_step argument (#962, #967, #979, #990, #991, #993, #1005, #1004, #1007)
Fixed
- Fixed non-empty state dict for a few metrics (#1012)
- Fixed bug when comparing states while finding compute groups (#1022)
- Fixed torch.double support in stat score metrics (#1023)
- Fixed FID calculation for real and fake inputs of unequal size (#1028)
- Fixed case where KLDivergence could output NaN (#1030)
- Fixed deterministic behavior for PyTorch < 1.8 (#1035)
- Fixed default value for mdmc_average in Accuracy (#1036)
- Fixed missing copy of property when using compute groups in MetricCollection (#1052)
Contributors
@Borda, @burglarhobbit, @charlielito, @gianscarpe, @MrShevan, @phaseolud, @razmikmelikbekyan, @SkafteNicki, @tanmoyio, @vumichien
If we forgot someone because their commit email does not match their GitHub account, let us know :]