import inspect

import torch
import torch.nn as nn
from omegaconf import DictConfig

import pytorch_tabular.models as models
from pytorch_tabular.models import BaseModel
from pytorch_tabular.models.common.heads import blocks
from pytorch_tabular.models.gate import GatedAdditiveTreesBackbone
from pytorch_tabular.models.node import NODEBackbone


def instantiate_backbone(hparams, backbone_name):
    """Resolve a backbone class from its model module and instantiate it.

    Backbones that accept a ``config`` argument receive the full config; otherwise
    only the constructor arguments present in the config are passed through, with
    ``block_activation`` resolved to an ``nn`` activation instance.
    """
    backbone_class = getattr(getattr(models, hparams._module_src.split(".")[-1]), backbone_name)
    class_args = list(inspect.signature(backbone_class).parameters.keys())
    if "config" in class_args:
        return backbone_class(config=hparams)
    else:
        return backbone_class(
            **{
                arg: getattr(hparams, arg) if arg != "block_activation" else getattr(nn, getattr(hparams, arg))()
                for arg in class_args
            }
        )


class StackingEmbeddingLayer(nn.Module):
    """Runs every stacked model's own embedding layer on the same input."""

    def __init__(self, embedding_layers: nn.ModuleList):
        super().__init__()
        self.embedding_layers = embedding_layers

    def forward(self, x):
        # one embedded view of the input per stacked backbone
        outputs = []
        for embedding_layer in self.embedding_layers:
            em_output = embedding_layer(x)
            outputs.append(em_output)
        return outputs


class StackingBackbone(nn.Module):
    """Builds one backbone per model config and concatenates their outputs."""

    def __init__(self, config: DictConfig):
        super().__init__()
        self.hparams = config
        self._build_network()

    def _build_network(self):
        self._backbones = nn.ModuleList()
        self._heads = nn.ModuleList()
        self._backbone_output_dims = []
        assert len(self.hparams.model_configs) > 0, "Stacking requires at least one model config"
        for model_i in range(len(self.hparams.model_configs)):
            # copy the shared data parameters into each model config
            self.hparams.model_configs[model_i].embedded_cat_dim = self.hparams.embedded_cat_dim
            self.hparams.model_configs[model_i].continuous_dim = self.hparams.continuous_dim
            self.hparams.model_configs[model_i].n_continuous_features = self.hparams.continuous_dim

            self.hparams.model_configs[model_i].embedding_dims = self.hparams.embedding_dims
            self.hparams.model_configs[model_i].categorical_cardinality = self.hparams.categorical_cardinality
            self.hparams.model_configs[model_i].categorical_dim = self.hparams.categorical_dim
            self.hparams.model_configs[model_i].cat_embedding_dims = self.hparams.embedding_dims

            # if output_dim is not set, default it to 128
            if getattr(self.hparams.model_configs[model_i], "output_dim", None) is None:
                self.hparams.model_configs[model_i].output_dim = 128

            # propagate inferred_config to the model config when it is available
            if getattr(self.hparams, "inferred_config", None) is not None:
                self.hparams.model_configs[model_i].inferred_config = self.hparams.inferred_config

            # instantiate backbone
            _backbone = instantiate_backbone(
                self.hparams.model_configs[model_i], self.hparams.model_configs[model_i]._backbone_name
            )
            # set continuous_dim
            _backbone.continuous_dim = self.hparams.continuous_dim
            # if output_dim is not set, set it to the output_dim in model_config
            if getattr(_backbone, "output_dim", None) is None:
                setattr(
                    _backbone,
                    "output_dim",
                    self.hparams.model_configs[model_i].output_dim,
                )
            self._backbones.append(_backbone)
            self._backbone_output_dims.append(_backbone.output_dim)

        self.output_dim = sum(self._backbone_output_dims)

    def _build_embedding_layer(self):
        assert getattr(self, "_backbones", None) is not None, "Backbones are not built"
        embedding_layers = nn.ModuleList()
        for backbone in self._backbones:
            if getattr(backbone, "_build_embedding_layer", None) is None:
                embedding_layers.append(nn.Identity())
            else:
                embedding_layers.append(backbone._build_embedding_layer())
        return StackingEmbeddingLayer(embedding_layers)

    def forward(self, x_list):
        outputs = []
        for i, backbone in enumerate(self._backbones):
            bb_output = backbone(x_list[i])
            # GATE and NODE backbones can emit 3-D tensors; average the extra
            # dimension away so every output is (batch, features) before concatenation
            if len(bb_output.shape) == 3 and isinstance(backbone, GatedAdditiveTreesBackbone):
                bb_output = bb_output.mean(dim=-1)
            elif len(bb_output.shape) == 3 and isinstance(backbone, NODEBackbone):
                bb_output = bb_output.mean(dim=1)
            outputs.append(bb_output)
        x = torch.cat(outputs, dim=1)
        return x


class StackingModel(BaseModel):
    """Stacks several backbones and feeds their concatenated output to one head."""

    def __init__(self, config: DictConfig, **kwargs):
        super().__init__(config, **kwargs)

    def _build_network(self):
        self._backbone = StackingBackbone(self.hparams)
        self._embedding_layer = self._backbone._build_embedding_layer()
        self.output_dim = self._backbone.output_dim
        self._head = self._get_head_from_config()

    def _get_head_from_config(self):
        _head_callable = getattr(blocks, self.hparams.head)
        return _head_callable(
            in_units=self.output_dim,
            output_dim=self.hparams.output_dim,
            config=_head_callable._config_template(**self.hparams.head_config),
        )

    @property
    def backbone(self):
        return self._backbone

    @property
    def embedding_layer(self):
        return self._embedding_layer

    @property
    def head(self):
        return self._head
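
# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of this module. It assumes a companion
# ``StackingModelConfig`` exposing the ``model_configs`` list read above is
# defined alongside this model; that name, the sub-config settings, and the
# placeholders ``num_cols``, ``cat_cols``, ``train_df`` are assumptions, not
# guaranteed API.
#
# from pytorch_tabular import TabularModel
# from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
# from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig
#
# stacking_config = StackingModelConfig(
#     task="classification",
#     model_configs=[
#         CategoryEmbeddingModelConfig(task="classification"),
#         GANDALFConfig(task="classification"),
#     ],
#     head="LinearHead",
#     head_config={"layers": ""},
# )
# tabular_model = TabularModel(
#     data_config=DataConfig(target=["target"], continuous_cols=num_cols, categorical_cols=cat_cols),
#     model_config=stacking_config,
#     optimizer_config=OptimizerConfig(),
#     trainer_config=TrainerConfig(max_epochs=10),
# )
# tabular_model.fit(train=train_df)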