Skip to content

Commit 49b2f4c

Browse files
GdoongMathew and sudiptob2
authored and committed
feat: add flops in ModelSummary columns (#20868)
1 parent 1286cdd commit 49b2f4c

File tree

6 files changed

+78
-14
lines changed

6 files changed

+78
-14
lines changed

src/lightning/pytorch/callbacks/model_summary.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
6868
model_size = model_summary.model_size
6969
total_training_modes = model_summary.total_training_modes
7070

71+
# todo Add `total_flops` in DeepSpeedSummary.
72+
total_flops = model_summary.total_flops if hasattr(model_summary, "total_flops") else 0
73+
7174
if trainer.is_global_zero:
7275
self.summarize(
7376
summary_data,
7477
total_parameters,
7578
trainable_parameters,
7679
model_size,
7780
total_training_modes,
81+
total_flops=total_flops,
7882
**self._summarize_kwargs,
7983
)
8084

@@ -92,13 +96,15 @@ def summarize(
9296
trainable_parameters: int,
9397
model_size: float,
9498
total_training_modes: dict[str, int],
99+
total_flops: int,
95100
**summarize_kwargs: Any,
96101
) -> None:
97102
summary_table = _format_summary_table(
98103
total_parameters,
99104
trainable_parameters,
100105
model_size,
101106
total_training_modes,
107+
total_flops,
102108
*summary_data,
103109
)
104110
log.info("\n" + summary_table)

src/lightning/pytorch/callbacks/rich_model_summary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def summarize(
7272
trainable_parameters: int,
7373
model_size: float,
7474
total_training_modes: dict[str, int],
75+
total_flops: int,
7576
**summarize_kwargs: Any,
7677
) -> None:
7778
from rich import get_console
@@ -86,6 +87,7 @@ def summarize(
8687
table.add_column("Type")
8788
table.add_column("Params", justify="right")
8889
table.add_column("Mode")
90+
table.add_column("FLOPs", justify="right")
8991

9092
column_names = list(zip(*summary_data))[0]
9193

@@ -113,5 +115,6 @@ def summarize(
113115
grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
114116
grid.add_row(f"[bold]Modules in train mode[/]: {total_training_modes['train']}")
115117
grid.add_row(f"[bold]Modules in eval mode[/]: {total_training_modes['eval']}")
118+
grid.add_row(f"[bold]Total FLOPs[/]: {get_human_readable_count(total_flops)}")
116119

117120
console.print(grid)

src/lightning/pytorch/utilities/model_summary/model_summary.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
import torch
2323
import torch.nn as nn
2424
from torch import Tensor
25+
from torch.utils.flop_counter import FlopCounterMode
2526
from torch.utils.hooks import RemovableHandle
2627

2728
import lightning.pytorch as pl
2829
from lightning.fabric.utilities.distributed import _is_dtensor
30+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
2931
from lightning.pytorch.utilities.model_helpers import _ModuleMode
3032
from lightning.pytorch.utilities.rank_zero import WarningCache
3133

@@ -180,29 +182,31 @@ class ModelSummary:
180182
...
181183
>>> model = LitModel()
182184
>>> ModelSummary(model, max_depth=1) # doctest: +NORMALIZE_WHITESPACE
183-
| Name | Type | Params | Mode | In sizes | Out sizes
184-
--------------------------------------------------------------------
185-
0 | net | Sequential | 132 K | train | [10, 256] | [10, 512]
186-
--------------------------------------------------------------------
185+
| Name | Type | Params | Mode | FLOPs | In sizes | Out sizes
186+
----------------------------------------------------------------------------
187+
0 | net | Sequential | 132 K | train | 2.6 M | [10, 256] | [10, 512]
188+
----------------------------------------------------------------------------
187189
132 K Trainable params
188190
0 Non-trainable params
189191
132 K Total params
190192
0.530 Total estimated model params size (MB)
191193
3 Modules in train mode
192194
0 Modules in eval mode
195+
2.6 M Total Flops
193196
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
194-
| Name | Type | Params | Mode | In sizes | Out sizes
195-
----------------------------------------------------------------------
196-
0 | net | Sequential | 132 K | train | [10, 256] | [10, 512]
197-
1 | net.0 | Linear | 131 K | train | [10, 256] | [10, 512]
198-
2 | net.1 | BatchNorm1d | 1.0 K | train | [10, 512] | [10, 512]
199-
----------------------------------------------------------------------
197+
| Name | Type | Params | Mode | FLOPs | In sizes | Out sizes
198+
------------------------------------------------------------------------------
199+
0 | net | Sequential | 132 K | train | 2.6 M | [10, 256] | [10, 512]
200+
1 | net.0 | Linear | 131 K | train | 2.6 M | [10, 256] | [10, 512]
201+
2 | net.1 | BatchNorm1d | 1.0 K | train | 0 | [10, 512] | [10, 512]
202+
------------------------------------------------------------------------------
200203
132 K Trainable params
201204
0 Non-trainable params
202205
132 K Total params
203206
0.530 Total estimated model params size (MB)
204207
3 Modules in train mode
205208
0 Modules in eval mode
209+
2.6 M Total Flops
206210
207211
"""
208212

@@ -212,6 +216,13 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
212216
if not isinstance(max_depth, int) or max_depth < -1:
213217
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
214218

219+
# The max-depth needs to be plus one because the root module is already counted as depth 0.
220+
self._flop_counter = FlopCounterMode(
221+
mods=None if _TORCH_GREATER_EQUAL_2_4 else self._model,
222+
display=False,
223+
depth=max_depth + 1,
224+
)
225+
215226
self._max_depth = max_depth
216227
self._layer_summary = self.summarize()
217228
# 1 byte -> 8 bits
@@ -279,6 +290,22 @@ def total_layer_params(self) -> int:
279290
def model_size(self) -> float:
280291
return self.total_parameters * self._precision_megabytes
281292

293+
@property
294+
def total_flops(self) -> int:
295+
return self._flop_counter.get_total_flops()
296+
297+
@property
298+
def flop_counts(self) -> dict[str, dict[Any, int]]:
299+
flop_counts = self._flop_counter.get_flop_counts()
300+
ret = {
301+
name: flop_counts.get(
302+
f"{type(self._model).__name__}.{name}",
303+
{},
304+
)
305+
for name in self.layer_names
306+
}
307+
return ret
308+
282309
def summarize(self) -> dict[str, LayerSummary]:
283310
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
284311
if self._model.example_input_array is not None:
@@ -307,8 +334,18 @@ def _forward_example_input(self) -> None:
307334
mode.capture(model)
308335
model.eval()
309336

337+
# FlopCounterMode does not support ScriptModules before torch 2.4.0, so we use a null context
338+
flop_context = (
339+
contextlib.nullcontext()
340+
if (
341+
not _TORCH_GREATER_EQUAL_2_4
342+
and any(isinstance(m, torch.jit.ScriptModule) for m in self._model.modules())
343+
)
344+
else self._flop_counter
345+
)
346+
310347
forward_context = contextlib.nullcontext() if trainer is None else trainer.precision_plugin.forward_context()
311-
with torch.no_grad(), forward_context:
348+
with torch.no_grad(), forward_context, flop_context:
312349
# let the model hooks collect the input- and output shapes
313350
if isinstance(input_, (list, tuple)):
314351
model(*input_)
@@ -330,6 +367,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
330367
("Type", self.layer_types),
331368
("Params", list(map(get_human_readable_count, self.param_nums))),
332369
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
370+
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
333371
]
334372
if self._model.example_input_array is not None:
335373
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
@@ -349,6 +387,7 @@ def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], t
349387
layer_summaries["Type"].append(NOT_APPLICABLE)
350388
layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
351389
layer_summaries["Mode"].append(NOT_APPLICABLE)
390+
layer_summaries["FLOPs"].append(NOT_APPLICABLE)
352391
if "In sizes" in layer_summaries:
353392
layer_summaries["In sizes"].append(NOT_APPLICABLE)
354393
if "Out sizes" in layer_summaries:
@@ -361,8 +400,16 @@ def __str__(self) -> str:
361400
trainable_parameters = self.trainable_parameters
362401
model_size = self.model_size
363402
total_training_modes = self.total_training_modes
364-
365-
return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)
403+
total_flops = self.total_flops
404+
405+
return _format_summary_table(
406+
total_parameters,
407+
trainable_parameters,
408+
model_size,
409+
total_training_modes,
410+
total_flops,
411+
*arrays,
412+
)
366413

367414
def __repr__(self) -> str:
368415
return str(self)
@@ -383,6 +430,7 @@ def _format_summary_table(
383430
trainable_parameters: int,
384431
model_size: float,
385432
total_training_modes: dict[str, int],
433+
total_flops: int,
386434
*cols: tuple[str, list[str]],
387435
) -> str:
388436
"""Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
@@ -423,6 +471,8 @@ def _format_summary_table(
423471
summary += "Modules in train mode"
424472
summary += "\n" + s.format(total_training_modes["eval"], 10)
425473
summary += "Modules in eval mode"
474+
summary += "\n" + s.format(get_human_readable_count(total_flops), 10)
475+
summary += "Total Flops"
426476

427477
return summary
428478

tests/tests_pytorch/callbacks/test_model_summary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def summarize(
6565
assert summary_data[4][0] == "Mode"
6666
assert summary_data[4][1][0] == "train"
6767

68+
assert summary_data[5][0] == "FLOPs"
69+
assert all(isinstance(x, str) for x in summary_data[5][1])
70+
6871
assert total_training_modes == {"train": 1, "eval": 0}
6972

7073
model = BoringModel()

tests/tests_pytorch/callbacks/test_rich_model_summary.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,11 @@ def example_input_array(self) -> Any:
6262
trainable_parameters=1,
6363
model_size=1,
6464
total_training_modes=summary.total_training_modes,
65+
total_flops=1,
6566
)
6667

6768
# ensure that summary was logged + the breakdown of model parameters
6869
assert mock_console.call_count == 2
6970
# assert that the input summary data was converted correctly
7071
args, _ = mock_table_add_row.call_args_list[0]
71-
assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "[4, 32]", "[4, 2]")
72+
assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "512 ", "[4, 32]", "[4, 2]")

tests/tests_pytorch/utilities/test_model_summary.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def test_empty_model_summary_shapes(max_depth):
173173
assert summary.in_sizes == []
174174
assert summary.out_sizes == []
175175
assert summary.param_nums == []
176+
assert summary.total_flops == 0
176177

177178

178179
@pytest.mark.parametrize("max_depth", [-1, 1])

0 commit comments

Comments (0)