Commit b027a7c

YanXiong-Meta authored and facebook-github-bot committed
convert batch size to float before torch.std in params reporter
Summary:

X-link: facebookresearch/FBGEMM#1854

### Description

This diff converts the batch sizes to float before calculating the standard deviation in the `params_reporter` module. This change ensures the calculation is valid, as `torch.std` expects a floating-point input.

### Changes

- **Modified code in `bench_params_reporter.py`**
- Changed `torch.std(torch.tensor([b for bs in batch_size_per_feature_per_rank for b in bs]))` to `torch.std(torch.tensor([b for bs in batch_size_per_feature_per_rank for b in bs]).float())`.

### Reason

`torch.std` requires a floating-point tensor to compute the standard deviation. Converting the batch sizes to float before the call ensures the computation succeeds and is performed correctly.

Reviewed By: q10

Differential Revision: D81809491

fbshipit-source-id: c03ed8ae80923f89d531c648827bc22a8fcd9c70
1 parent 4d66f94 commit b027a7c

File tree

1 file changed (+1, -1 lines)


fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -209,7 +209,7 @@ def extract_params(
                     for bs in batch_size_per_feature_per_rank
                     for b in bs
                 ]
-            )
+            ).float()
         )
     )
 )
```
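The behavior motivating this change can be reproduced with a small standalone snippet. The batch-size values below are illustrative, not taken from the commit; the point is that flattening integer batch sizes yields an integer tensor, which `torch.std` rejects until it is converted to float, as this diff does.

```python
import torch

# Hypothetical batch sizes per feature per rank (illustrative values only).
batch_size_per_feature_per_rank = [[64, 128], [32, 64]]

# Flattening integer batch sizes yields an int64 tensor by default.
flat = torch.tensor([b for bs in batch_size_per_feature_per_rank for b in bs])

# torch.std only supports floating-point and complex dtypes; calling it
# directly on the int64 tensor raises a RuntimeError, so convert first --
# mirroring the `.float()` added in this commit.
std = torch.std(flat.float())
print(std.item())
```

Without the `.float()` conversion, `torch.std(flat)` raises a `RuntimeError` complaining that `std` only supports floating-point dtypes.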
