Skip to content

Commit 7ba8948

Browse files
authored
feat: simple transformer embedding net (#1494)
* attention and rotary encoding * rm unnecessary parameters * standard rope * implemented transformer architecture * fixed the rope frequencies * added support for MoE (from Mixtral) * comments and vit embedding * added arxiv ref * final fixes * type linting * fixed types * type fixes * type fixes
1 parent dd882dc commit 7ba8948

File tree

3 files changed

+929
-0
lines changed

3 files changed

+929
-0
lines changed

sbi/neural_nets/embedding_nets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ResNetEmbedding1D,
1111
ResNetEmbedding2D,
1212
)
13+
from sbi.neural_nets.embedding_nets.transformer import TransformerEmbedding
1314

1415
__all__ = [
1516
"CausalCNNEmbedding",

0 commit comments

Comments
 (0)