Commit e042201
Add Built-in Support for Model Stacking (#520)
* add stacking model & config
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* refactor: add StackingEmbeddingLayer so that `forward` can be removed from StackingModel
* refactor: remove the use of `eval` to pass ruff format checks
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Add stacking model documentation and tutorial: update the API docs to include `StackingModelConfig` and `StackingModel`, add a tutorial notebook demonstrating model stacking in PyTorch Tabular (setup, configuration, training, and evaluation), and expand the existing docs to explain the stacking concept and its benefits
* Refactor: remove the GatedAdditiveTreeEnsembleConfig lambda from the get_model_configs function in test_model_stacking.py, keeping only the model configurations relevant to the stacking tests
* Update mkdocs.yml to add a "Model Stacking" navigation entry linking to the tutorial notebook "tutorials/16-Model Stacking.ipynb"
* Refactor mkdocs.yml: fix the indentation of the "Model Stacking" navigation entry
* Refactor StackingModelConfig: simplify the type annotation of model_configs from list[ModelConfig] to list
* Refactor StackingBackbone to remove the type annotation from its forward method
* Refactor StackingEmbeddingLayer to remove the type annotation from its forward method
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Add model stacking diagram and enhance documentation

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: f04a05c · Commit: e042201

File tree: 11 files changed (+1915 −1 lines)

docs/apidocs_model.md

Lines changed: 6 additions & 1 deletion

```diff
@@ -30,6 +30,9 @@
 ::: pytorch_tabular.models.TabTransformerConfig
     options:
         heading_level: 3
+::: pytorch_tabular.models.StackingModelConfig
+    options:
+        heading_level: 3
 ::: pytorch_tabular.config.ModelConfig
     options:
         heading_level: 3
@@ -66,7 +69,9 @@
 ::: pytorch_tabular.models.TabTransformerModel
     options:
         heading_level: 3
-
+::: pytorch_tabular.models.StackingModel
+    options:
+        heading_level: 3
 ## Base Model Class
 ::: pytorch_tabular.models.BaseModel
     options:
```

docs/imgs/model_stacking_concept.png

Binary file added (59.2 KB).

docs/models.md

Lines changed: 24 additions & 0 deletions

```diff
@@ -253,6 +253,30 @@ All the parameters have been set to recommended values from the paper. Let's look at them:
 **For a complete list of parameters refer to the API Docs**
 [pytorch_tabular.models.DANetConfig][]

+## Model Stacking
+
+Model stacking is an ensemble learning technique that combines multiple base models to create a more powerful predictive model. Each base model processes the input features independently, and their outputs are concatenated before making the final prediction. This allows the model to leverage the different learning patterns captured by each backbone architecture. You can use it by choosing `StackingModelConfig`.
+
+The following diagram shows the concept of model stacking in PyTorch Tabular.
+![Model Stacking](imgs/model_stacking_concept.png)
+
+The following model architectures are supported for stacking:
+- Category Embedding Model
+- TabNet Model
+- FTTransformer Model
+- Gated Additive Tree Ensemble Model
+- DANet Model
+- AutoInt Model
+- GANDALF Model
+- Node Model
+
+All the parameters have been set to provide flexibility while maintaining ease of use. Let's look at them:
+
+- `model_configs`: List[ModelConfig]: List of configurations for each base model. Each config should be a valid PyTorch Tabular model config (e.g., NodeConfig, GANDALFConfig)
+
+**For a complete list of parameters refer to the API Docs**
+[pytorch_tabular.models.StackingModelConfig][]
+
 ## Implementing New Architectures

 PyTorch Tabular is very easy to extend and infinitely customizable. All the models implemented in PyTorch Tabular inherit from an abstract class `BaseModel`, which is in fact a PyTorch Lightning model.
```
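To make the new option concrete, here is a minimal end-to-end sketch of how `StackingModelConfig` slots into the usual `TabularModel` workflow. It is not part of this commit, and the column names, target, and hyperparameters below are illustrative assumptions:

```python
# Minimal sketch of the stacking workflow. Column names, target, and
# hyperparameters are illustrative assumptions, not part of this commit.
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    GANDALFConfig,
    StackingModelConfig,
)

data_config = DataConfig(
    target=["target"],                    # hypothetical target column
    continuous_cols=["num_1", "num_2"],   # hypothetical numeric features
    categorical_cols=["cat_1"],           # hypothetical categorical feature
)

# Each base model keeps its own config; the stacking model concatenates
# the backbone outputs of all of them before the final head.
model_config = StackingModelConfig(
    task="classification",
    model_configs=[
        CategoryEmbeddingModelConfig(task="classification", layers="64-32"),
        GANDALFConfig(task="classification"),
    ],
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=10),
)
tabular_model.fit(train=train_df)  # train_df: your pandas DataFrame
```

Note that the base models are not trained separately: every entry in `model_configs` becomes a backbone inside one stacked network, and the whole ensemble is trained jointly in a single fit.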

docs/tutorials/16-Model Stacking.ipynb

Lines changed: 1486 additions & 0 deletions
Large diffs are not rendered by default.

mkdocs.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -24,6 +24,7 @@ nav:
     - SHAP, Deep LIFT and so on through Captum Integration: "tutorials/14-Explainability.ipynb"
   - Custom PyTorch Models:
     - Implementing New Supervised Architectures: "tutorials/04-Implementing New Architectures.ipynb"
+    - Model Stacking: "tutorials/16-Model Stacking.ipynb"
   - Other Features:
     - Using Neural Categorical Embeddings in Scikit-Learn Workflows: "tutorials/03-Neural Embedding in Scikit-Learn Workflows.ipynb"
     - Self-Supervised Learning using Denoising Autoencoders: "tutorials/08-Self-Supervised Learning-DAE.ipynb"
```

src/pytorch_tabular/models/__init__.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -19,6 +19,7 @@
 from .gate import GatedAdditiveTreeEnsembleConfig, GatedAdditiveTreeEnsembleModel
 from .mixture_density import MDNConfig, MDNModel
 from .node import NodeConfig, NODEModel
+from .stacking import StackingModel, StackingModelConfig
 from .tab_transformer import TabTransformerConfig, TabTransformerModel
 from .tabnet import TabNetModel, TabNetModelConfig

@@ -45,6 +46,8 @@
     "GANDALFBackbone",
     "DANetConfig",
     "DANetModel",
+    "StackingModel",
+    "StackingModelConfig",
     "category_embedding",
     "node",
     "mixture_density",
@@ -55,4 +58,5 @@
     "gate",
     "gandalf",
     "danet",
+    "stacking",
 ]
```
src/pytorch_tabular/models/stacking/__init__.py (new file)

Lines changed: 4 additions & 0 deletions

```python
from .config import StackingModelConfig
from .stacking_model import StackingBackbone, StackingModel

__all__ = ["StackingModel", "StackingModelConfig", "StackingBackbone"]
```
src/pytorch_tabular/models/stacking/config.py (new file)

Lines changed: 26 additions & 0 deletions

```python
from dataclasses import dataclass, field

from pytorch_tabular.config import ModelConfig


@dataclass
class StackingModelConfig(ModelConfig):
    """StackingModelConfig is a configuration class for the StackingModel. It is used to stack multiple models
    together. Currently, CategoryEmbeddingModel, TabNetModel, FTTransformerModel, GatedAdditiveTreeEnsembleModel,
    DANetModel, AutoIntModel, GANDALFModel, and NodeModel are supported.

    Args:
        model_configs (list[ModelConfig]): List of model configs to stack.

    """

    model_configs: list = field(default_factory=list, metadata={"help": "List of model configs to stack"})
    _module_src: str = field(default="models.stacking")
    _model_name: str = field(default="StackingModel")
    _backbone_name: str = field(default="StackingBackbone")
    _config_name: str = field(default="StackingConfig")


# if __name__ == "__main__":
#     from pytorch_tabular.utils import generate_doc_dataclass
#     print(generate_doc_dataclass(StackingModelConfig))
```
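As a quick illustration of the dataclass (the values below are only examples, not from this commit), a stacking config is built like any other `ModelConfig`, with the base configs passed as a plain list:

```python
# Illustrative only: any supported PyTorch Tabular model configs can be stacked.
from pytorch_tabular.models import GANDALFConfig, NodeConfig, StackingModelConfig

stacking_config = StackingModelConfig(
    task="regression",                   # inherited from ModelConfig
    model_configs=[
        NodeConfig(task="regression"),   # base model 1
        GANDALFConfig(task="regression"),  # base model 2
    ],
)
print(stacking_config._backbone_name)    # "StackingBackbone"
```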
src/pytorch_tabular/models/stacking/stacking_model.py (new file)

Lines changed: 140 additions & 0 deletions

```python
import inspect

import torch
import torch.nn as nn
from omegaconf import DictConfig

import pytorch_tabular.models as models
from pytorch_tabular.models import BaseModel
from pytorch_tabular.models.common.heads import blocks
from pytorch_tabular.models.gate import GatedAdditiveTreesBackbone
from pytorch_tabular.models.node import NODEBackbone


def instantiate_backbone(hparams, backbone_name):
    backbone_class = getattr(getattr(models, hparams._module_src.split(".")[-1]), backbone_name)
    class_args = list(inspect.signature(backbone_class).parameters.keys())
    if "config" in class_args:
        return backbone_class(config=hparams)
    else:
        return backbone_class(
            **{
                arg: getattr(hparams, arg) if arg != "block_activation" else getattr(nn, getattr(hparams, arg))()
                for arg in class_args
            }
        )


class StackingEmbeddingLayer(nn.Module):
    def __init__(self, embedding_layers: nn.ModuleList):
        super().__init__()
        self.embedding_layers = embedding_layers

    def forward(self, x):
        outputs = []
        for embedding_layer in self.embedding_layers:
            em_output = embedding_layer(x)
            outputs.append(em_output)
        return outputs


class StackingBackbone(nn.Module):
    def __init__(self, config: DictConfig):
        super().__init__()
        self.hparams = config
        self._build_network()

    def _build_network(self):
        self._backbones = nn.ModuleList()
        self._heads = nn.ModuleList()
        self._backbone_output_dims = []
        assert len(self.hparams.model_configs) > 0, "Stacking requires at least one model config"
        for model_i in range(len(self.hparams.model_configs)):
            # move necessary params to each model config
            self.hparams.model_configs[model_i].embedded_cat_dim = self.hparams.embedded_cat_dim
            self.hparams.model_configs[model_i].continuous_dim = self.hparams.continuous_dim
            self.hparams.model_configs[model_i].n_continuous_features = self.hparams.continuous_dim

            self.hparams.model_configs[model_i].embedding_dims = self.hparams.embedding_dims
            self.hparams.model_configs[model_i].categorical_cardinality = self.hparams.categorical_cardinality
            self.hparams.model_configs[model_i].categorical_dim = self.hparams.categorical_dim
            self.hparams.model_configs[model_i].cat_embedding_dims = self.hparams.embedding_dims

            # if output_dim is not set, default it to 128
            if getattr(self.hparams.model_configs[model_i], "output_dim", None) is None:
                self.hparams.model_configs[model_i].output_dim = 128

            # propagate the inferred config to each model config when available
            if getattr(self.hparams, "inferred_config", None) is not None:
                self.hparams.model_configs[model_i].inferred_config = self.hparams.inferred_config

            # instantiate backbone
            _backbone = instantiate_backbone(
                self.hparams.model_configs[model_i], self.hparams.model_configs[model_i]._backbone_name
            )
            # set continuous_dim
            _backbone.continuous_dim = self.hparams.continuous_dim
            # if the backbone does not expose output_dim, take it from the model_config
            if getattr(_backbone, "output_dim", None) is None:
                _backbone.output_dim = self.hparams.model_configs[model_i].output_dim
            self._backbones.append(_backbone)
            self._backbone_output_dims.append(_backbone.output_dim)

        self.output_dim = sum(self._backbone_output_dims)

    def _build_embedding_layer(self):
        assert getattr(self, "_backbones", None) is not None, "Backbones are not built"
        embedding_layers = nn.ModuleList()
        for backbone in self._backbones:
            if getattr(backbone, "_build_embedding_layer", None) is None:
                embedding_layers.append(nn.Identity())
            else:
                embedding_layers.append(backbone._build_embedding_layer())
        return StackingEmbeddingLayer(embedding_layers)

    def forward(self, x_list):
        outputs = []
        for i, backbone in enumerate(self._backbones):
            bb_output = backbone(x_list[i])
            if len(bb_output.shape) == 3 and isinstance(backbone, GatedAdditiveTreesBackbone):
                bb_output = bb_output.mean(dim=-1)
            elif len(bb_output.shape) == 3 and isinstance(backbone, NODEBackbone):
                bb_output = bb_output.mean(dim=1)
            outputs.append(bb_output)
        x = torch.cat(outputs, dim=1)
        return x


class StackingModel(BaseModel):
    def __init__(self, config: DictConfig, **kwargs):
        super().__init__(config, **kwargs)

    def _build_network(self):
        self._backbone = StackingBackbone(self.hparams)
        self._embedding_layer = self._backbone._build_embedding_layer()
        self.output_dim = self._backbone.output_dim
        self._head = self._get_head_from_config()

    def _get_head_from_config(self):
        _head_callable = getattr(blocks, self.hparams.head)
        return _head_callable(
            in_units=self.output_dim,
            output_dim=self.hparams.output_dim,
            config=_head_callable._config_template(**self.hparams.head_config),
        )

    @property
    def backbone(self):
        return self._backbone

    @property
    def embedding_layer(self):
        return self._embedding_layer

    @property
    def head(self):
        return self._head
```
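The core idea of `StackingBackbone.forward` — fan the shared input through one embedding per backbone, then concatenate the backbone outputs along the feature dimension — can be seen in isolation with plain `nn.Module`s. This is a toy sketch for illustration, not the library code (note how 3-D outputs from GATE/NODE backbones are averaged down to 2-D in the real implementation before this concat):

```python
import torch
import torch.nn as nn

# Toy stand-ins for two backbones with different output widths.
backbones = nn.ModuleList([nn.Linear(8, 16), nn.Linear(8, 32)])
# One embedding per backbone; Identity mirrors what the real code uses
# for backbones without their own _build_embedding_layer.
embeddings = nn.ModuleList([nn.Identity(), nn.Identity()])

x = torch.randn(4, 8)                    # batch of 4 rows, 8 features
x_list = [emb(x) for emb in embeddings]  # StackingEmbeddingLayer: one view per backbone
outputs = [bb(x_list[i]) for i, bb in enumerate(backbones)]
stacked = torch.cat(outputs, dim=1)      # StackingBackbone.forward: concat on feature dim
print(stacked.shape)                     # torch.Size([4, 48]) -> 16 + 32
```

The head then sees a `sum(backbone_output_dims)`-wide vector, which is why `StackingBackbone` sets `self.output_dim = sum(self._backbone_output_dims)`.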

src/pytorch_tabular/models/tabnet/config.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -129,6 +129,7 @@ class TabNetModelConfig(ModelConfig):
     _module_src: str = field(default="models.tabnet")
     _model_name: str = field(default="TabNetModel")
     _config_name: str = field(default="TabNetModelConfig")
+    _backbone_name: str = field(default="TabNetBackbone")


 # if __name__ == "__main__":
```
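This one-line addition is what lets TabNet participate in stacking: `instantiate_backbone` in stacking_model.py looks up each base model's backbone class by its config's `_backbone_name`, which `TabNetModelConfig` previously did not define.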
