
[BUG] Huge Memory Consumption for TFT & Small Dataset #1322

Open
@vilorel

Description

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 2.0.1+cpu
  • Python version: 3.10
  • Operating System: Ubuntu

Expected behavior

I followed this guide, which is mostly similar to yours except for a few changes to the model hyperparameters:

  • attention_head_size = 4
  • hidden_size = 160
  • hidden_continuous_size = 160

The datasets in the two posts are similar in size (40k-60k rows), and so are the features. I'm by no means an expert in the field, and I do realize that going from 16/8 to 160/160 for hidden_size/hidden_continuous_size changes the model significantly, including the number of parameters. Still, when I ran it on a 128 GB machine it ran out of memory, and I had to use a 512 GB server just to train on this small dataset.
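
A rough way to see why the hidden size matters: the LSTM encoder and decoder weight matrices are hidden_size × hidden_size, so their parameter count grows roughly quadratically. The sketch below uses the standard parameter formula for a single-layer torch.nn.LSTM with biases and assumes, as the "132 K" entries in the model summary of this report suggest, that TFT's LSTMs take hidden_size inputs. `lstm_param_count` is a hypothetical helper, not part of pytorch-forecasting.

```python
def lstm_param_count(input_size: int, hidden_size: int) -> int:
    """Parameters of a single-layer, unidirectional torch.nn.LSTM with biases."""
    # 4 gates, each with an input weight matrix (hidden x input), a recurrent
    # weight matrix (hidden x hidden) and two bias vectors (b_ih, b_hh).
    per_gate = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
    return 4 * per_gate


# Assumes input_size == hidden_size, as for TFT's lstm_encoder/lstm_decoder.
for hidden in (16, 128, 160):
    print(hidden, lstm_param_count(hidden, hidden))
```

Under these assumptions, hidden_size=128 gives 132,096 parameters per LSTM (the "132 K" in the summary), and going from 16 to 160 multiplies the per-LSTM count by about 95 (2,176 to 206,080). The absolute counts still stay in the hundreds of thousands.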

I then experimented with my own dataset and faced similar issues.

Actual behavior

To run the example below, I again need a server with 512 GB of RAM, and RAM consumption rises to about 74.5% (roughly 381 GB) and stays there throughout training. The dataset is not that large, as you can see. What happens if I want to train on 90M records or more?
The model is also not that large IMHO. Am I missing something?
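
To turn the "74.5%" reading into a number that can be compared across runs, one option is to read the process's peak resident set size from Python's standard `resource` module (Unix only, which covers the Ubuntu setup here). This is a minimal sketch: `peak_rss_gib` is a hypothetical helper name, and calling it right after `trainer.fit(...)` is only a suggested placement. Because both dataloaders use num_workers=0, data loading runs in the main process, so `RUSAGE_SELF` should cover it.

```python
import resource
import sys


def peak_rss_gib() -> float:
    """Return the peak resident set size of this process in GiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    if sys.platform != "darwin":
        peak *= 1024
    return peak / 2**30


if __name__ == "__main__":
    # e.g. call this right after trainer.fit(...) returns
    print(f"peak RSS: {peak_rss_gib():.2f} GiB")
```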

Code to reproduce the problem

I then built my own example and test dataset to give you more concrete numbers:

[172801 rows x 9 columns]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 172801 entries, 0 to 172800
Data columns (total 9 columns):
 #   Column         Non-Null Count   Dtype   
---  ------         --------------   -----   
 0   time_idx       172801 non-null  int32   
 1   dow            172801 non-null  int8    
 2   hod            172801 non-null  int8    
 3   item           172801 non-null  category
 4   m0             172801 non-null  float32 
 5   m1             172801 non-null  float32 
 6   m2             172801 non-null  float32 
 7   m3             172801 non-null  float32 
 8   y              172801 non-null  float32 
dtypes: category(1), float32(5), int32(1), int8(2)
memory usage: 4.4 MB
ML Data Size: 172801
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/user/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:197: UserWarning: Attribute 'loss' is an instance of nn.Module and is already saved during checkpointing. It is recommended to ignore them using self.save_hyperparameters(ignore=['loss']).
  rank_zero_warn(

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 1     
3  | prescalers                         | ModuleDict                      | 1.5 K 
4  | static_variable_selection          | VariableSelectionNetwork        | 104 K 
5  | encoder_variable_selection         | VariableSelectionNetwork        | 211 K 
6  | decoder_variable_selection         | VariableSelectionNetwork        | 104 K 
7  | static_context_variable_selection  | GatedResidualNetwork            | 66.3 K
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 66.3 K
9  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 66.3 K
10 | static_context_enrichment          | GatedResidualNetwork            | 66.3 K
11 | lstm_encoder                       | LSTM                            | 132 K 
12 | lstm_decoder                       | LSTM                            | 132 K 
13 | post_lstm_gate_encoder             | GatedLinearUnit                 | 33.0 K
14 | post_lstm_add_norm_encoder         | AddNorm                         | 256   
15 | static_enrichment                  | GatedResidualNetwork            | 82.7 K
16 | multihead_attn                     | InterpretableMultiHeadAttention | 41.2 K
17 | post_attn_gate_norm                | GateAddNorm                     | 33.3 K
18 | pos_wise_ff                        | GatedResidualNetwork            | 66.3 K
19 | pre_output_gate_norm               | GateAddNorm                     | 33.3 K
20 | output_layer                       | Linear                          | 903   
----------------------------------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.961     Total estimated model params size (MB)
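
The figures in this output can be cross-checked with a little arithmetic. The sketch below recomputes the DataFrame size from the dtypes listed by df.info() (assuming the `item` category is stored with int8 codes, which holds when there are fewer than about 128 items) and the fp32 weight size from the rounded per-module counts in the model summary, then compares both with the ~381 GB observed (treating 512 GB as 512e9 bytes). It only counts raw values and weights; it does not model activations, gradients, optimizer state, or any structures pytorch-forecasting builds internally from the DataFrame.

```python
# Back-of-envelope memory figures taken from the output above.
N_ROWS = 172_801

# Bytes per value for each DataFrame column, from the dtypes in df.info().
BYTES_PER_VALUE = {
    "time_idx": 4,                               # int32
    "dow": 1, "hod": 1,                          # int8
    "item": 1,                                   # category; ASSUMES int8 codes
    "m0": 4, "m1": 4, "m2": 4, "m3": 4, "y": 4,  # float32
}
data_bytes = N_ROWS * sum(BYTES_PER_VALUE.values())

# Per-module parameter counts from the model summary (rows 0-20). Entries such
# as "104 K" are already rounded, so the total is approximate.
SUMMARY_PARAMS = [
    0, 0, 1, 1_500, 104_000, 211_000, 104_000,
    66_300, 66_300, 66_300, 66_300,
    132_000, 132_000,
    33_000, 256, 82_700, 41_200, 33_300,
    66_300, 33_300, 903,
]
n_params = sum(SUMMARY_PARAMS)
param_bytes = n_params * 4  # 4 bytes per float32 weight

observed_bytes = 0.745 * 512e9  # ~74.5% of a 512 GB server, as reported

# pandas' info() uses 1024-based units, so compare in MiB.
print(f"DataFrame values: {data_bytes / 2**20:.1f} MiB")
print(f"Model weights:    {n_params:,} params = {param_bytes / 1e6:.2f} MB")
print(f"Observed RSS:     {observed_bytes / 1e9:.0f} GB "
      f"({observed_bytes / param_bytes:,.0f}x the weights)")
```

Both the raw data and the weights come out at a few megabytes, which matches the 4.4 MB and 4.961 MB reported, so the observed resident memory is tens of thousands of times larger than either.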

The configuration I used in this example is the following:

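        # Imports this snippet relies on (not shown in the original report).
        import lightning.pytorch as pl
        from lightning.pytorch.callbacks import LearningRateMonitor
        from lightning.pytorch.loggers import TensorBoardLogger
        from pytorch_forecasting import TemporalFusionTransformer
        from pytorch_forecasting.metrics import MAE, MAPE, RMSE, QuantileLoss
        from torchmetrics import MeanSquaredError  # assumed source of MeanSquaredError

        # `training` (the TimeSeriesDataSet), `train_dataloader`, `val_dataloader`,
        # `early_stop_callback` and `model_path` are defined earlier (not shown).
        lr_logger = LearningRateMonitor()  # logs the learning rate to the logger
        logger = TensorBoardLogger(model_path)

        trainer = pl.Trainer(
            max_epochs=45,
            accelerator='cpu',
            devices=1,
            enable_model_summary=True,
            gradient_clip_val=0.1,
            callbacks=[lr_logger, early_stop_callback],
            logger=logger)

        tft = TemporalFusionTransformer.from_dataset(
            training,
            learning_rate=0.001,
            hidden_size=128,
            hidden_continuous_size=64,
            attention_head_size=4,
            dropout=0.1,
            output_size=7,  # one output per quantile of the default QuantileLoss()
            loss=QuantileLoss(),
            logging_metrics=[MAE(), MeanSquaredError(), RMSE(), MAPE()],
            log_interval=10,
            reduce_on_plateau_patience=4)

        trainer.fit(
            tft,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader)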

If it makes any difference, the number of workers for both dataloaders is set to 0 via the num_workers parameter.
