 import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.utils.flop_counter import FlopCounterMode
 from torch.utils.hooks import RemovableHandle

 import lightning.pytorch as pl
 from lightning.fabric.utilities.distributed import _is_dtensor
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.utilities.model_helpers import _ModuleMode
 from lightning.pytorch.utilities.rank_zero import WarningCache

@@ -180,29 +182,31 @@ class ModelSummary:
         ...
         >>> model = LitModel()
         >>> ModelSummary(model, max_depth=1)  # doctest: +NORMALIZE_WHITESPACE
-          | Name | Type       | Params | Mode  | In sizes  | Out sizes
-        --------------------------------------------------------------------
-        0 | net  | Sequential | 132 K  | train | [10, 256] | [10, 512]
-        --------------------------------------------------------------------
+          | Name | Type       | Params | Mode  | FLOPs | In sizes  | Out sizes
+        ----------------------------------------------------------------------------
+        0 | net  | Sequential | 132 K  | train | 2.6 M | [10, 256] | [10, 512]
+        ----------------------------------------------------------------------------
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
         0.530     Total estimated model params size (MB)
         3         Modules in train mode
         0         Modules in eval mode
+        2.6 M     Total Flops
         >>> ModelSummary(model, max_depth=-1)  # doctest: +NORMALIZE_WHITESPACE
-          | Name  | Type        | Params | Mode  | In sizes  | Out sizes
-        ----------------------------------------------------------------------
-        0 | net   | Sequential  | 132 K  | train | [10, 256] | [10, 512]
-        1 | net.0 | Linear      | 131 K  | train | [10, 256] | [10, 512]
-        2 | net.1 | BatchNorm1d | 1.0 K  | train | [10, 512] | [10, 512]
-        ----------------------------------------------------------------------
+          | Name  | Type        | Params | Mode  | FLOPs | In sizes  | Out sizes
+        ------------------------------------------------------------------------------
+        0 | net   | Sequential  | 132 K  | train | 2.6 M | [10, 256] | [10, 512]
+        1 | net.0 | Linear      | 131 K  | train | 2.6 M | [10, 256] | [10, 512]
+        2 | net.1 | BatchNorm1d | 1.0 K  | train | 0     | [10, 512] | [10, 512]
+        ------------------------------------------------------------------------------
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
         0.530     Total estimated model params size (MB)
         3         Modules in train mode
         0         Modules in eval mode
+        2.6 M     Total Flops

     """

@@ -212,6 +216,13 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         if not isinstance(max_depth, int) or max_depth < -1:
             raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")

+        # The max depth needs to be plus one because the root module is already counted as depth 0.
+        self._flop_counter = FlopCounterMode(
+            mods=None if _TORCH_GREATER_EQUAL_2_4 else self._model,
+            display=False,
+            depth=max_depth + 1,
+        )
+
         self._max_depth = max_depth
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
@@ -279,6 +290,22 @@ def total_layer_params(self) -> int:
     def model_size(self) -> float:
         return self.total_parameters * self._precision_megabytes

+    @property
+    def total_flops(self) -> int:
+        return self._flop_counter.get_total_flops()
+
+    @property
+    def flop_counts(self) -> dict[str, dict[Any, int]]:
+        flop_counts = self._flop_counter.get_flop_counts()
+        ret = {
+            name: flop_counts.get(
+                f"{type(self._model).__name__}.{name}",
+                {},
+            )
+            for name in self.layer_names
+        }
+        return ret
+
     def summarize(self) -> dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
@@ -307,8 +334,18 @@ def _forward_example_input(self) -> None:
         mode.capture(model)
         model.eval()

+        # FlopCounterMode does not support ScriptModules before torch 2.4.0, so we use a null context
+        flop_context = (
+            contextlib.nullcontext()
+            if (
+                not _TORCH_GREATER_EQUAL_2_4
+                and any(isinstance(m, torch.jit.ScriptModule) for m in self._model.modules())
+            )
+            else self._flop_counter
+        )
+
         forward_context = contextlib.nullcontext() if trainer is None else trainer.precision_plugin.forward_context()
-        with torch.no_grad(), forward_context:
+        with torch.no_grad(), forward_context, flop_context:
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
                 model(*input_)
@@ -330,6 +367,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
             ("Type", self.layer_types),
             ("Params", list(map(get_human_readable_count, self.param_nums))),
             ("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
+            ("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
         ]
         if self._model.example_input_array is not None:
             arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
@@ -349,6 +387,7 @@ def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], t
         layer_summaries["Type"].append(NOT_APPLICABLE)
         layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
         layer_summaries["Mode"].append(NOT_APPLICABLE)
+        layer_summaries["FLOPs"].append(NOT_APPLICABLE)
         if "In sizes" in layer_summaries:
             layer_summaries["In sizes"].append(NOT_APPLICABLE)
         if "Out sizes" in layer_summaries:
@@ -361,8 +400,16 @@ def __str__(self) -> str:
         trainable_parameters = self.trainable_parameters
         model_size = self.model_size
         total_training_modes = self.total_training_modes
-
-        return _format_summary_table(total_parameters, trainable_parameters, model_size, total_training_modes, *arrays)
+        total_flops = self.total_flops
+
+        return _format_summary_table(
+            total_parameters,
+            trainable_parameters,
+            model_size,
+            total_training_modes,
+            total_flops,
+            *arrays,
+        )

     def __repr__(self) -> str:
         return str(self)
@@ -383,6 +430,7 @@ def _format_summary_table(
     trainable_parameters: int,
     model_size: float,
     total_training_modes: dict[str, int],
+    total_flops: int,
     *cols: tuple[str, list[str]],
 ) -> str:
     """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
@@ -423,6 +471,8 @@
     summary += "Modules in train mode"
     summary += "\n" + s.format(total_training_modes["eval"], 10)
     summary += "Modules in eval mode"
+    summary += "\n" + s.format(get_human_readable_count(total_flops), 10)
+    summary += "Total Flops"

     return summary

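Usage note: the snippet below is a minimal sketch of how the new FLOPs reporting could be exercised programmatically, not part of the diff itself. The LitModel mirrors the example from the docstring hunk above; the printed totals are illustrative and depend on the FlopCounterMode implementation of the installed torch version.

import torch
import torch.nn as nn

import lightning.pytorch as pl
from lightning.pytorch.utilities.model_summary import ModelSummary


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
        # example_input_array is what lets the summary trace shapes and count FLOPs
        self.example_input_array = torch.zeros(10, 256)

    def forward(self, x):
        return self.net(x)


summary = ModelSummary(LitModel(), max_depth=-1)
print(summary)              # table now includes a FLOPs column and a "Total Flops" row
print(summary.total_flops)  # aggregate FLOPs for one forward pass of the example input
print(summary.flop_counts)  # per-layer dict of FLOP counts, keyed by layer name

Since the counter is only entered inside _forward_example_input, a model without example_input_array would presumably report zero FLOPs everywhere.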