Skip to content

Commit c8495ec

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
optionally checkpoint r_squared metrics when required (pytorch#3083)
Summary: Pull Request resolved: pytorch#3083 r_squared states get checkpointed when not needed, resulting in loading issues. this bypasses by making those states persistent only when needed, ie: r_squared=True Reviewed By: zxpmirror1994 Differential Revision: D76162934 fbshipit-source-id: adf76774b8d6331c0a67250eef11060b1bbb419e
1 parent fd68a70 commit c8495ec

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchrec/metrics/mse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ def __init__(
110110
torch.zeros(self._n_tasks, dtype=torch.double),
111111
add_window_state=True,
112112
dist_reduce_fx="sum",
113-
persistent=True,
113+
persistent=include_r_squared,
114114
)
115115
self._add_state(
116116
"label_squared_sum",
117117
torch.zeros(self._n_tasks, dtype=torch.double),
118118
add_window_state=True,
119119
dist_reduce_fx="sum",
120-
persistent=True,
120+
persistent=include_r_squared,
121121
)
122122

123123
def update(

0 commit comments

Comments
 (0)