[Prototype] Block interface: workspace #342


Draft: wants to merge 25 commits into base branch block_interface_config
Changes from all commits (25):
7174e32  stuff (jlamypoirier, Jul 31, 2025)
9c8f1a6  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Jul 31, 2025)
267ba29  stuff (jlamypoirier, Jul 31, 2025)
46deb66  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Jul 31, 2025)
faa83fe  misc (jlamypoirier, Jul 31, 2025)
9e79555  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Jul 31, 2025)
729ee1d  misc (jlamypoirier, Jul 31, 2025)
c3ae392  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Jul 31, 2025)
aec2989  misc (jlamypoirier, Jul 31, 2025)
91ae15e  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Jul 31, 2025)
bdb738b  misc (jlamypoirier, Jul 31, 2025)
ebdf8ee  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Jul 31, 2025)
8c7a451  misc (jlamypoirier, Jul 31, 2025)
4fc292a  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Aug 1, 2025)
dc0741d  Merge branch 'block_interface' into block_interface_2 (jlamypoirier, Aug 1, 2025)
c63ec17  fixes (jlamypoirier, Aug 1, 2025)
5bf6639  Merge branch 'block_interface_config' into block_interface_2 (jlamypoirier, Aug 1, 2025)
47dc12d  stuff (jlamypoirier, Aug 1, 2025)
17f7c80  stuff (jlamypoirier, Aug 1, 2025)
9f88310  Merge branch 'block_interface_config' into block_interface_2 (jlamypoirier, Aug 1, 2025)
f99bbab  stuff (jlamypoirier, Aug 1, 2025)
17e66a8  Merge branch 'block_interface_config' into block_interface_2 (jlamypoirier, Aug 1, 2025)
5b4c2b3  stuff (jlamypoirier, Aug 1, 2025)
13b7370  Merge branch 'block_interface_config' into block_interface_2 (jlamypoirier, Aug 1, 2025)
69ac11b  Merge branch 'block_interface_config' into block_interface_2 (jlamypoirier, Aug 8, 2025)
68 changes: 58 additions & 10 deletions fast_llm/layers/block/config.py
@@ -1,4 +1,6 @@
import abc
import enum
import functools
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class
@@ -9,7 +11,8 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.block import Block, BlockLayer


# TODO: Generalize these beyond language models? (Ex. vision)

@@ -156,6 +159,43 @@ class BlockConfig(BaseModelConfig):
hint=FieldHint.architecture,
)

block_sequence: "BlockSequenceConfig" = Field(init=False)

def _validate(self) -> None:
assert hasattr(self, "block_sequence")
Assert.incl(self, self.block_sequence.blocks.values())
self.mixer.block = self
self.mlp.block = self
super()._validate()

def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
self.mlp.setup_tensor_space(tensor_space)
self.mixer.setup_tensor_space(tensor_space)

# Hidden dimension
tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.block_sequence.hidden_size))

@abc.abstractmethod
def get_block(self) -> "Block":
pass


@config_class()
class BlockSequenceConfig(BaseModelConfig):
_abstract = True

blocks: dict[str, BlockConfig] = Field()
block_pattern: tuple[str, ...] = Field(
default=None,
desc="The pattern of blocks (referred by name) to use. The sequence is repeated until reaching `num_blocks`."
" Default: cycle over `blocks` in the order they are defined.",
)
default_block: str = Field(
default=None,
desc="The default block configuration to use when referring to the model."
" Used to set some defaults in the language model.",
)

# TODO: Move these, not specific to a single block.
num_blocks: int = Field(
default=12,
@@ -174,15 +214,23 @@ class BlockConfig(BaseModelConfig):
desc="Store the residuals for the transformer in full precision (`optimization_dtype`).",
hint=FieldHint.stability,
)
per_layer_lr_scale: list[float] | None = Field(
default=None,
desc="Custom learning rate scale for each layer.",
doc="May be used to freeze some layers by setting their scale to zero.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
for block in self.blocks.values():
block.validate()
if self.block_pattern is None:
self.block_pattern = tuple(self.blocks)
if self.default_block is None:
self.default_block = self.block_pattern[0]
super()._validate()

def get_block_config(self, block_index: int) -> BlockConfig:
return self.blocks[self.block_pattern[block_index % len(self.block_pattern)]]

def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
super().setup_tensor_space(tensor_space)
for block in self.blocks.values():
block.setup_tensor_space(tensor_space)

# Hidden dimension
tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size))
@functools.cached_property
def default_block_config(self) -> BlockConfig:
return self.blocks[self.default_block]
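
Note: a minimal standalone sketch of the block-pattern resolution introduced above, using toy classes rather than the real fast_llm config machinery (the block names and pattern below are made up for illustration):

```python
# Toy sketch of BlockSequenceConfig's pattern resolution: the pattern defaults
# to the definition order of `blocks`, `default_block` defaults to the first
# pattern entry, and the pattern repeats until `num_blocks` is reached.
from dataclasses import dataclass


@dataclass
class ToyBlockConfig:
    name: str


@dataclass
class ToyBlockSequenceConfig:
    blocks: dict[str, ToyBlockConfig]
    block_pattern: tuple[str, ...] | None = None
    default_block: str | None = None
    num_blocks: int = 12

    def validate(self) -> None:
        if self.block_pattern is None:
            self.block_pattern = tuple(self.blocks)
        if self.default_block is None:
            self.default_block = self.block_pattern[0]

    def get_block_config(self, block_index: int) -> ToyBlockConfig:
        return self.blocks[self.block_pattern[block_index % len(self.block_pattern)]]


config = ToyBlockSequenceConfig(
    blocks={"attn": ToyBlockConfig("attn"), "ssm": ToyBlockConfig("ssm")},
    block_pattern=("attn", "attn", "ssm"),
    num_blocks=6,
)
config.validate()
print([config.get_block_config(i).name for i in range(config.num_blocks)])
# ['attn', 'attn', 'ssm', 'attn', 'attn', 'ssm']
```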
8 changes: 5 additions & 3 deletions fast_llm/layers/block/mlp/config.py
@@ -181,7 +181,7 @@ def _validate(self) -> None:
self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu
# TODO: `hidden_size` not yet validated.
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.block.hidden_size
self.ffn_hidden_size = 4 * self.block.block_sequence.hidden_size

self.num_unshared_experts = self.num_experts - self.num_shared_experts

@@ -206,7 +206,7 @@ def layer_1_weight_initialization_method(self) -> Initializer:
if self.layer_1_weight_initialization.has_initialization:
return self.layer_1_weight_initialization.get_initializer()
else:
return init_normal_(0, self.block.hidden_size**-0.5)
return init_normal_(0, self.block.block_sequence.hidden_size**-0.5)

@functools.cached_property
def layer_1_bias_initialization_method(self) -> Initializer:
@@ -220,7 +220,7 @@ def layer_2_weight_initialization_method(self) -> Initializer:
if self.layer_2_weight_initialization.has_initialization:
return self.layer_2_weight_initialization.get_initializer()
else:
return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1))
return init_normal_(
0, self.block.block_sequence.hidden_size**-0.5 / max(2 * self.block.block_sequence.num_blocks, 1)
)

@functools.cached_property
def layer_2_bias_initialization_method(self) -> Initializer:
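Note: the layer-2 initialization above now reads `hidden_size` and `num_blocks` through `block.block_sequence`; the resulting standard deviation is unchanged. A rough standalone illustration of the formula (the function name is made up):

```python
# Layer-2 weight init std as computed above: hidden_size ** -0.5, further
# divided by twice the total block count (clamped to at least 1).
def layer_2_init_std(hidden_size: int, num_blocks: int) -> float:
    return hidden_size**-0.5 / max(2 * num_blocks, 1)


print(layer_2_init_std(hidden_size=4096, num_blocks=32))  # ~0.000244
```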
14 changes: 9 additions & 5 deletions fast_llm/layers/block/mlp/mlp.py
@@ -6,22 +6,26 @@
from fast_llm.functional.config import TritonConfig
from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.config import BlockConfig, BlockDimNames
from fast_llm.layers.block.config import BlockDimNames
from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames
from fast_llm.layers.block.peft import TransformerSubLayerName
from fast_llm.layers.common.linear import LinearBase
from fast_llm.utils import get_lr_scale


class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]):
def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int, name: str):
def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str):
super().__init__(config, tensor_space, block_index, name)

hidden_dim = self._tensor_space[BlockDimNames.hidden]
self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp]
self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation

layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None
layer_lr_scale = (
self._config.block.block_sequence.per_layer_lr_scale[self._block_index]
if self._config.block.block_sequence.per_layer_lr_scale
else None
)
lr_scale = (
tuple(self._config.mlp_lr_scale)
if isinstance(self._config.mlp_lr_scale, list)
@@ -50,8 +54,8 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index:
)

# PEFT.
self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1)
self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2)
self.layer_1 = self._config.block.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1)
self.layer_2 = self._config.block.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2)


class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]):
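Note: the per-layer LR scale lookup above now goes through `config.block.block_sequence` rather than the block config itself; a tiny standalone sketch of the resolution behavior (names are illustrative):

```python
# Per-layer LR scale resolution as in MLPBase.__init__ above: the list lives on
# the block-sequence config, is indexed by block index, and None means "no override".
def resolve_layer_lr_scale(per_layer_lr_scale: list[float] | None, block_index: int) -> float | None:
    if not per_layer_lr_scale:
        return None
    return per_layer_lr_scale[block_index]


assert resolve_layer_lr_scale(None, 3) is None
assert resolve_layer_lr_scale([1.0, 0.5, 0.0, 1.0], 2) == 0.0  # scale 0 freezes layer 2
```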
32 changes: 16 additions & 16 deletions fast_llm/layers/language_model/config.py
@@ -1,12 +1,11 @@
import functools

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_
from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
from fast_llm.engine.distributed.config import DistributedDimNames
from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl
from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs
from fast_llm.layers.block.config import BlockDimNames, BlockKwargs, BlockSequenceConfig
from fast_llm.utils import Assert


@@ -45,10 +44,8 @@ class LanguageModelKwargs(BlockKwargs):


@config_class()
class LanguageModelBaseConfig(BaseModelConfig):
# TODO: block
transformer: BlockConfig = Field(
desc="Configuration for the transformer architecture.",
class LanguageModelBaseConfig(BlockSequenceConfig):
decoder: BlockSequenceConfig = Field(
hint=FieldHint.architecture,
)
vocab_size: int = Field(
@@ -57,6 +54,13 @@ class LanguageModelBaseConfig(BaseModelConfig):
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
embedding_dropout: float = Field(
# TODO: backward compatibility?
default=0.0,
desc="Dropout applied to the embedding layer.",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
absolute_position_embeddings: int | None = Field(
# TODO: backward compatibility?
default=None,
@@ -209,19 +213,14 @@ def _validate(self) -> None:
Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads)
for coeff in self.prediction_loss_coefficient:
Assert.geq(coeff, 0)
if self.transformer.per_layer_lr_scale is not None:
# -1 because the first prediction head's transformer layer is accounted for in num_layers
# +1 because the layer index starts at 1
Assert.eq(
len(self.transformer.per_layer_lr_scale), self.transformer.num_blocks + self.prediction_heads - 1 + 1
)

if self.output_weight_initialization.has_initialization:
assert self.use_absolute_position_embeddings
if self.output_weight_initialization.has_initialization:
assert not self.tie_word_embeddings

def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
self.transformer.setup_tensor_space(tensor_space)
super().setup_tensor_space(tensor_space)
tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor)

# Embedding dimensions
@@ -235,25 +234,26 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None:

@property
def use_absolute_position_embeddings(self) -> int:
# TODO: Set through num embeddings instead instead.
return self.absolute_position_embeddings is not None

@functools.cached_property
def word_embedding_weight_initialization_method(self) -> Initializer:
if self.word_embedding_weight_initialization.has_initialization:
return self.word_embedding_weight_initialization.get_initializer()
else:
return init_normal_(self.transformer.hidden_size**-0.5)
return init_normal_(self.hidden_size**-0.5)

@functools.cached_property
def position_embedding_weight_initialization_method(self) -> Initializer:
if self.position_embedding_weight_initialization.has_initialization:
return self.position_embedding_weight_initialization.get_initializer()
else:
return init_normal_(self.transformer.hidden_size**-0.5)
return init_normal_(self.hidden_size**-0.5)

@functools.cached_property
def output_weight_initialization_method(self) -> Initializer:
if self.output_weight_initialization.has_initialization:
return self.output_weight_initialization.get_initializer()
else:
return init_normal_(self.transformer.hidden_size**-0.5)
return init_normal_(self.hidden_size**-0.5)
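
Note: the net effect of this file is that `LanguageModelBaseConfig` is now itself a `BlockSequenceConfig`, so `hidden_size` and `num_blocks` are read directly instead of through a nested `transformer` config. A toy sketch of that shape (classes and defaults below are illustrative, not the real fast_llm definitions):

```python
# Toy sketch of the inheritance change: the language-model config is now a
# block-sequence config, so hidden_size is a direct attribute.
from dataclasses import dataclass


@dataclass
class ToyBlockSequenceConfig:
    hidden_size: int = 1024
    num_blocks: int = 12


@dataclass
class ToyLanguageModelConfig(ToyBlockSequenceConfig):
    vocab_size: int = 32000
    embedding_dropout: float = 0.0

    def word_embedding_init_std(self) -> float:
        # Previously: self.transformer.hidden_size ** -0.5
        return self.hidden_size**-0.5


config = ToyLanguageModelConfig(hidden_size=2048, vocab_size=50257, embedding_dropout=0.1)
print(config.word_embedding_init_std())  # ~0.0221
```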
5 changes: 3 additions & 2 deletions fast_llm/layers/language_model/embedding.py
@@ -14,7 +14,7 @@
WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight"


class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer):
class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer):
"""
A language model embedding layer.
Consists of word embeddings (tensor-parallel or sequence-tensor-parallel),
@@ -44,6 +44,7 @@ def __init__(
self._parallel_embeddings = (
self._tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings
)

hidden_dim = self._tensor_space[LanguageModelDimNames.hidden]
vocab_dim = self._tensor_space[
LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
@@ -107,7 +108,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask
if self._sequence_parallel
else self._tensor_space.distributed.pp_generator
):
embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training)
embeddings = torch.dropout(embeddings, self._config.embedding_dropout, self.training)
return embeddings.to(dtype=self._residual_dtype)

def forward(
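Note: dropout on the embedding output is now controlled by the new `embedding_dropout` field instead of the transformer's `hidden_dropout`. A minimal PyTorch illustration of the call above (shapes and the dropout value are made up):

```python
# The embedding layer now applies its own dropout probability.
import torch

embedding_dropout = 0.1  # would come from LanguageModelBaseConfig.embedding_dropout
embeddings = torch.randn(4, 16, 64)  # (batch, sequence, hidden)
embeddings = torch.dropout(embeddings, embedding_dropout, True)  # True = training mode
```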
17 changes: 13 additions & 4 deletions fast_llm/layers/language_model/head.py
@@ -39,14 +39,22 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[Config

config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig

def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_distance: int):
def __init__(
self,
config: ConfigType,
tensor_space: TensorSpace,
prediction_distance: int,
):
super().__init__(config)
# TODO: Avoid default_block_config?
self._debug = DebugLayer(
tensor_space,
f"Language model head",
self._config.transformer.debug_transformer,
self._config.transformer.debug_transformer_memory,
self._config.default_block_config.debug_transformer,
self._config.default_block_config.debug_transformer_memory,
)
self._tensor_space = tensor_space

self._group_size = tensor_space.distributed_config.tensor_parallel
self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel
self._parallel_embeddings = (
@@ -67,7 +75,8 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_dis
else 1.0
)
self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance)
self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim)
# TODO: Avoid default_block_config?
self.final_norm = self._config.default_block_config.normalization.get_layer(hidden_dim)
self._logits_scale_factor = self._config.logits_scale_factor
self._language_model_loss_factor = self._config.language_model_loss_factor
self._distillation_loss_factor = self._config.distillation_loss_factor
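Note: since the head no longer has a single `transformer` config, it falls back to `default_block_config` for its debug flags and final normalization (flagged with TODOs above). A toy sketch of that lookup (plain dicts, illustrative values only):

```python
# With several block types in the sequence, the head picks the block named by
# `default_block` (defaulting to the first pattern entry) for its settings.
blocks = {"attn": {"normalization": "rms_norm"}, "ssm": {"normalization": "layer_norm"}}
default_block = "attn"  # block_pattern[0] when not set explicitly
default_block_config = blocks[default_block]
print(default_block_config["normalization"])  # rms_norm
```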
8 changes: 1 addition & 7 deletions fast_llm/layers/ssm/discrete_mamba2.py
@@ -46,13 +46,7 @@ def __init__(
tensor_space: TensorSpace,
block_config: BlockConfig,
):
super().__init__(
tensor_space,
block_index,
self._mixer_name,
debug_level=block_config.debug_transformer,
debug_memory=block_config.debug_transformer_memory,
)
super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer)
self._config: SSMConfig = config
layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None
lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale)
8 changes: 1 addition & 7 deletions fast_llm/layers/ssm/mamba2.py
@@ -57,13 +57,7 @@ def __init__(
block_index: int,
block_config: BlockConfig,
):
super().__init__(
tensor_space,
block_index,
self._mixer_name,
debug_level=block_config.debug_transformer,
debug_memory=block_config.debug_transformer_memory,
)
super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer)
self._config: SSMConfig = config
Assert.eq(self._config.activation_type, ActivationType.silu)
layer_lr_scale: float | None = (
8 changes: 1 addition & 7 deletions fast_llm/layers/ssm/mamba_layer.py
@@ -63,13 +63,7 @@ def __init__(
tensor_space: TensorSpace,
block_config: BlockConfig,
):
super().__init__(
tensor_space,
block_index,
self._mixer_name,
debug_level=block_config.debug_transformer,
debug_memory=block_config.debug_transformer_memory,
)
super().__init__(tensor_space, block_index, debug_level=block_config.debug_transformer)
assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer"
self._config = config
# TODO: It's not silu?
2 changes: 1 addition & 1 deletion fast_llm/layers/transformer/attention.py
@@ -78,7 +78,7 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i
self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size
self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size
self._local_heads = self._local_head_groups * self._local_heads_per_group
self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power)
self._softmax_scale: float = self._kv_channels ** (-self._config.attention_softmax_scale_power)

hidden_dim = self._tensor_space[AttentionDimNames.hidden]

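Note: the only change here is the explicit `float` annotation on `_softmax_scale`; the value itself is the usual attention scaling. A tiny illustration (the default power of 0.5 is an assumption based on the standard convention):

```python
# Softmax scale as computed above: kv_channels ** -power.
# With power = 0.5 this is the familiar 1/sqrt(d) attention scaling.
def softmax_scale(kv_channels: int, power: float = 0.5) -> float:
    return kv_channels ** (-power)


print(softmax_scale(128))  # ~0.0884 == 1/sqrt(128)
```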