
Commit 2ef478c

Merge pull request #185 from jrzaurin/flash_attention
Flash attention
2 parents cd1ff79 + 67439c4 commit 2ef478c

File tree: 67 files changed (+7302 additions, −1293 deletions)


README.md

Lines changed: 14 additions & 8 deletions

@@ -130,26 +130,33 @@ passed through a series of ResNet blocks built with dense layers.
 3. **TabNet**: details on TabNet can be found in
 [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)
 
+Two simpler attention-based models that we call:
+
+4. **ContextAttentionMLP**: MLP with an attention mechanism "on top" that is based on
+[Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf)
+5. **SelfAttentionMLP**: MLP with an attention mechanism that is a simplified
+version of a transformer block that we refer to as "query-key self-attention".
+
 The ``Tabformer`` family, i.e. Transformers for Tabular data:
 
-4. **TabTransformer**: details on the TabTransformer can be found in
+6. **TabTransformer**: details on the TabTransformer can be found in
 [TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
-5. **SAINT**: Details on SAINT can be found in
+7. **SAINT**: Details on SAINT can be found in
 [SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
-6. **FT-Transformer**: details on the FT-Transformer can be found in
+8. **FT-Transformer**: details on the FT-Transformer can be found in
 [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
-7. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
+9. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
 on the FastFormer can be found in
 [FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382)
-8. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
+10. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
 the Perceiver can be found in
 [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)
 
 And probabilistic DL models for tabular data based on
 [Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424):
 
-9. **BayesianWide**: Probabilistic adaptation of the `Wide` model.
-10. **BayesianTabMlp**: Probabilistic adaptation of the `TabMlp` model
+11. **BayesianWide**: Probabilistic adaptation of the `Wide` model.
+12. **BayesianTabMlp**: Probabilistic adaptation of the `TabMlp` model
 
 Note that while there are scientific publications for the TabTransformer,
 SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own

@@ -196,7 +203,6 @@ using `Wide` and `DeepDense` and default settings.
 Building a wide (linear) and deep model with ``pytorch-widedeep``:
 
 ```python
-import pandas as pd
 import numpy as np
 import torch
 from sklearn.model_selection import train_test_split
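For context, here is a minimal sketch of how the flash-attention option added by this PR can be switched on. It only uses the `TabTransformer` arguments exercised by the benchmark script included further down in this commit; the column subset is illustrative and is not taken from the README itself.

```python
# Minimal sketch (assumptions: the adult dataset loader and the TabTransformer
# arguments used by this commit's benchmark script; `cat_cols` is an
# illustrative subset of adult's categorical columns).
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.models import TabTransformer, WideDeep
from pytorch_widedeep.preprocessing import TabPreprocessor

df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]

cat_cols = ["workclass", "education", "relationship"]  # illustrative subset

tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_cols, with_attention=True)
X_tab = tab_preprocessor.fit_transform(df)

flash_tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
    use_flash_attention=True,  # the new option introduced by this PR
)
model = WideDeep(deeptabular=flash_tab_transformer)
```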

VERSION

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-1.3.1
+1.3.2
Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
from time import time

from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep, TabTransformer
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor

# use_cuda = torch.cuda.is_available()

df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
target_colname = "income_label"

cat_embed_cols = []
for col in df.columns:
    if (df[col].dtype == "O" or df[col].nunique() < 200) and col != target_colname:
        cat_embed_cols.append(col)

train, test = train_test_split(
    df, test_size=0.1, random_state=1, stratify=df[[target_colname]]
)

with_cls_token = True
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, with_attention=True, with_cls_token=with_cls_token
)

X_tab_train = tab_preprocessor.fit_transform(train)
X_tab_test = tab_preprocessor.transform(test)
target = train[target_colname].values


tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
)

linear_tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
    use_linear_attention=True,
)

flash_tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
    use_flash_attention=True,
)

s_model = WideDeep(deeptabular=tab_transformer)
l_model = WideDeep(deeptabular=linear_tab_transformer)
f_model = WideDeep(deeptabular=flash_tab_transformer)

for name, model in [("standard", s_model), ("linear", l_model), ("flash", f_model)]:
    trainer = Trainer(
        model,
        objective="binary",
        metrics=[Accuracy],
    )

    s = time()
    trainer.fit(
        X_tab=X_tab_train,
        target=target,
        n_epochs=1,
        batch_size=64,
        val_split=0.2,
    )
    e = time() - s
    print(f"{name} attention time: {round(e, 3)} secs")

examples/scripts/wide_deep_for_recsys/ml100k_data_preparation.py

Lines changed: 5 additions & 2 deletions

@@ -10,7 +10,7 @@
 
 from pytorch_widedeep.datasets import load_movielens100k
 
-data, user, items = load_movielens100k(as_frame=True)
+data, users, items = load_movielens100k(as_frame=True)
 
 # Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:
 # list_of_genres = items.columns.tolist()[-19:]
@@ -37,7 +37,7 @@
 ]
 
 
-# adding a column with the number of movies watched per user
+# adding a column with the number of movies watched per users
 dataset = data.sort_values(["user_id", "timestamp"]).reset_index(drop=True)
 dataset["one"] = 1
 dataset["num_watched"] = dataset.groupby("user_id")["one"].cumsum()
@@ -61,6 +61,9 @@
 )
 dataset["prev_movies"] = dataset["prev_movies"].apply(lambda x: x.split())
 
+# Adding user feats
+dataset = dataset.merge(users, on="user_id", how="left")
+
 # Adding a genre_rate as the mean of all movies rated for a given genre per
 # user
 dataset = dataset.merge(items[["movie_id"] + list_of_genres], on="movie_id", how="left")
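The `genre_rate` comment above describes the derived feature, but its implementation falls outside this hunk. A hypothetical sketch of one way it could be computed, assuming `dataset` already carries the one-hot genre columns merged from `items` as in the line above; this is not the script's actual implementation.

```python
# Hypothetical sketch: per-user mean rating restricted to the movies of each
# genre, matching the comment above. Assumes the one-hot genre columns from
# `items` are already merged into `dataset`.
for genre in list_of_genres:
    genre_ratings = dataset[genre].mul(dataset["rating"])  # rating if the movie has this genre, else 0
    dataset[f"{genre}_rate"] = (
        genre_ratings.groupby(dataset["user_id"]).transform("sum")
        / dataset.groupby("user_id")[genre].transform("sum").replace(0, 1)  # avoid /0 for unseen genres
    )
```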

examples/scripts/wide_deep_for_recsys/pytorch_wide_deep_pt2.py

Lines changed: 4 additions & 4 deletions

@@ -47,7 +47,7 @@
 ).to_list()
 y_train = df_train.target.values.astype(int)
 
-df_test_user_item = df_train[["user_id", "movie_id", "rating"]]
+df_test_user_item = df_test[["user_id", "movie_id", "rating"]]
 test_movies_sequences = df_test.prev_movies.apply(
     lambda x: [int(el) for el in x]
 ).to_list()
@@ -89,7 +89,7 @@
 tab_mlp = TabMlp(
     column_idx=tab_preprocessor.column_idx,
     cat_embed_input=tab_preprocessor.cat_embed_input,
-    mlp_hidden_dims=[1024, 512, 256],
+    mlp_hidden_dims=[512, 256],
     mlp_activation="relu",
 )
 
@@ -124,7 +124,7 @@
         "X_text": X_test_text,
         "target": y_test,
     },
-    n_epochs=10,
-    batch_size=521,
+    n_epochs=2,
+    batch_size=32,
     shuffle=False,
 )

mkdocs/mkdocs.yml

Lines changed: 3 additions & 0 deletions

@@ -56,6 +56,9 @@ nav:
 - 16_Self-Supervised Pre-Training pt 1: examples/16_Self_Supervised_Pretraning_pt1.ipynb
 - 16_Self-Supervised Pre-Training pt 2: examples/16_Self_Supervised_Pretraning_pt2.ipynb
 - 17_Using_a_huggingface_model: examples/17_Usign_a_hugging_face_model.ipynb
+- 18_feature_importance_via_attention_weights: examples/18_feature_importance_via_attention_weights.ipynb
+- 19_wide_and_deep_for_recsys_pt1: examples/19_wide_and_deep_for_recsys_pt1.ipynb
+- 19_wide_and_deep_for_recsys_pt2: examples/19_wide_and_deep_for_recsys_pt2.ipynb
 - Contributing: contributing.md
 
 theme:

mkdocs/site/404.html

Lines changed: 48 additions & 0 deletions

@@ -739,6 +739,12 @@
+
@@ -1012,6 +1018,48 @@
+
+<li class="md-nav__item">
+  <a href="/examples/18_feature_importance_via_attention_weights.html" class="md-nav__link">
+    18_feature_importance_via_attention_weights
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="/examples/19_wide_and_deep_for_recsys_pt1.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt1
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="/examples/19_wide_and_deep_for_recsys_pt2.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt2
+  </a>
+</li>
+
 </ul>
 </nav>
 </li>

mkdocs/site/contributing.html

Lines changed: 50 additions & 2 deletions

@@ -743,6 +743,12 @@
+
@@ -1016,6 +1022,48 @@
+
+<li class="md-nav__item">
+  <a href="examples/18_feature_importance_via_attention_weights.html" class="md-nav__link">
+    18_feature_importance_via_attention_weights
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="examples/19_wide_and_deep_for_recsys_pt1.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt1
+  </a>
+</li>
+
+<li class="md-nav__item">
+  <a href="examples/19_wide_and_deep_for_recsys_pt2.html" class="md-nav__link">
+    19_wide_and_deep_for_recsys_pt2
+  </a>
+</li>
+
 </ul>
 </nav>
 </li>
@@ -1095,7 +1143,7 @@ <h1>Contributing</h1>
 <nav class="md-footer__inner md-grid" aria-label="Footer" >
 
-<a href="examples/17_Usign_a_hugging_face_model.html" class="md-footer__link md-footer__link--prev" aria-label="Previous: 17_Using_a_huggingface_model" rel="prev">
+<a href="examples/19_wide_and_deep_for_recsys_pt2.html" class="md-footer__link md-footer__link--prev" aria-label="Previous: 19_wide_and_deep_for_recsys_pt2" rel="prev">
 <div class="md-footer__button md-icon">
 <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11h12Z"/></svg>
 </div>
@@ -1104,7 +1152,7 @@ <h1>Contributing</h1>
 <span class="md-footer__direction">
 Previous
 </span>
-17_Using_a_huggingface_model
+19_wide_and_deep_for_recsys_pt2
 </div>
 </div>
 </a>
