
Commit dea274c: Add MPTGenerator (#296)

danieldk and shadeMe authored
* Add `MPTGenerator`

  This is pretty straightforward, except that I had to add a workaround because we don't support weight tying yet.

* No special JIT handling in causal LMs

  The special handling of JIT'ed code has not been necessary since the model outputs are also tuples.

* Whitespace fix

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
1 parent: 3e5c220

10 files changed: +112 additions, -22 deletions

README.md

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ Generator wrappers:
 - Dolly v2
 - Falcon
 - Llama 1/2
+- MPT
 
 All types of models can be loaded from Huggingface Hub.
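A minimal sketch of loading the newly listed MPT wrapper from the Hub; the checkpoint name and the keyword-only `name` argument are taken from the test added in this commit, while device placement and generation settings are omitted:

    from curated_transformers.generation import AutoGenerator, MPTGenerator

    # AutoGenerator picks the matching wrapper for the checkpoint; for an MPT
    # checkpoint this is now the new MPTGenerator.
    generator = AutoGenerator.from_hf_hub(name="mosaicml/mpt-7b")
    assert isinstance(generator, MPTGenerator)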
curated_transformers/generation/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,7 @@
     TopPTransform,
     VocabMaskTransform,
 )
+from .mpt import MPTGenerator
 from .stop_conditions import (
     CompoundStopCondition,
     EndOfSequenceCondition,
@@ -39,6 +40,7 @@
     "LlamaGenerator",
     "LogitsTransform",
     "MaxGeneratedPiecesCondition",
+    "MPTGenerator",
     "SampleGeneratorConfig",
     "StopCondition",
     "StringGenerator",

curated_transformers/generation/auto_generator.py

Lines changed: 4 additions & 2 deletions

@@ -4,18 +4,20 @@
 
 from ..models.auto_model import AutoModel
 from ..quantization.bnb.config import BitsAndBytesConfig
-from .default_generator import DefaultGenerator
 from .dolly_v2 import DollyV2Generator
 from .falcon import FalconGenerator
 from .generator_wrapper import GeneratorWrapper
 from .hf_hub import FromHFHub
+from .llama import LlamaGenerator
+from .mpt import MPTGenerator
 
 # For the time being, we enable support for a generator on a case-by-case basis.
 # In the future we might defer all unknown generators to DefaultGenerator.
 GENERATOR_MAP: Dict[str, Type[FromHFHub]] = {
     "dolly-v2": DollyV2Generator,
     "falcon": FalconGenerator,
-    "llama": DefaultGenerator,
+    "llama": LlamaGenerator,
+    "mpt": MPTGenerator,
 }
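How AutoGenerator maps a Hub checkpoint onto one of these entries is not part of this diff. Purely as an illustration of the intended behaviour (not the library's actual resolution code), a name-based lookup consistent with the test below might look like this; the `resolve_generator` helper is hypothetical:

    from typing import Type

    # Hypothetical illustration only: pick the generator whose key occurs in the
    # repository name, mirroring the checkpoints exercised in the test below.
    def resolve_generator(name: str) -> Type[FromHFHub]:
        for key, generator_cls in GENERATOR_MAP.items():
            if key in name.lower():
                return generator_cls
        raise ValueError(f"Unsupported generator for model {name}")

    resolve_generator("mosaicml/mpt-7b")                # MPTGenerator
    resolve_generator("openlm-research/open_llama_3b")  # LlamaGenerator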
curated_transformers/generation/mpt.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+from typing import TypeVar
+
+from ..models.llama import LlamaCausalLM
+from ..tokenizers.tokenizer import Tokenizer
+from .default_generator import DefaultGenerator
+from .hf_hub import FromHFHub
+
+
+class MPTGenerator(DefaultGenerator, FromHFHub):
+    """
+    Generator for MPT model variants.
+    """
+
+    def __init__(self, tokenizer: Tokenizer, causal_lm: LlamaCausalLM):
+        """
+        Construct an MPT generator.
+
+        :param tokenizer:
+            An MPT tokenizer.
+        :param causal_lm:
+            An MPT causal language model.
+        """
+        super().__init__(
+            tokenizer,
+            causal_lm,
+        )

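Because MPTGenerator mixes in FromHFHub, it can also be instantiated directly from a checkpoint instead of going through AutoGenerator. A minimal sketch; only `from_hf_hub(name=...)` is exercised by this commit's tests, so the device argument and the prompt-list call are assumptions about the generator wrapper interface:

    import torch

    from curated_transformers.generation import MPTGenerator

    generator = MPTGenerator.from_hf_hub(
        name="mosaicml/mpt-7b",
        device=torch.device("cuda:0"),  # assumed keyword; not shown in this diff
    )

    # Generator wrappers are callable (see the __call__ docs added below);
    # the exact call arguments are not part of this commit.
    answers = generator(["What is the difference between a llama and an alpaca?"])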
curated_transformers/models/mpt/_hf.py

Lines changed: 4 additions & 1 deletion

@@ -61,6 +61,10 @@ def convert_hf_state_dict(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
 
     out = {}
     for name, parameter in stripped_params.items():
+        # Input and output embeddings are tied in MPT.
+        if "lm_head" in name:
+            continue
+
         name = name.replace("transformer", "decoder")
         name = name.replace("blocks", "layers")
 
@@ -80,7 +84,6 @@ def convert_hf_state_dict(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
 
         # Embeddings
         name = re.sub(r"wte\.", r"embeddings.piece_embeddings.", name)
-        name = re.sub(r"lm_head\.", r"output_embeddings.", name)
 
         out[name] = parameter
 
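As a worked example of the substitutions above (the checkpoint keys on the left are assumptions about the Hugging Face MPT layout, not part of this diff):

    # "transformer.wte.weight"   -> "decoder.embeddings.piece_embeddings.weight"
    # "transformer.blocks.0...." -> "decoder.layers.0...."
    # "lm_head.weight"           -> dropped entirely; MPT ties input and output
    #                               embeddings, so the piece embeddings double as
    #                               the output projection (see causal_lm.py below).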
curated_transformers/models/mpt/causal_lm.py

Lines changed: 45 additions & 7 deletions

@@ -1,11 +1,15 @@
-from typing import Any, Mapping, Optional, Set, Type, TypeVar
+from typing import Any, List, Mapping, Optional, Set, Type, TypeVar
 
 import torch
+import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import Linear
+from torch.nn import Embedding
 
+from ...layers.attention import AttentionMask
+from ...layers.cache import KeyValueCache
 from ...quantization import Quantizable
 from ..hf_hub import FromHFHub
+from ..output import CausalLMOutputWithCache
 from ..transformer import TransformerCausalLM
 from ._hf import convert_hf_config, convert_hf_state_dict
 from .config import MPTConfig
@@ -38,11 +42,45 @@ def __init__(
         super().__init__()
 
         self.decoder = MPTDecoder(config, device=device)
-        self.output_embeddings = Linear(
-            in_features=config.layer.feedforward.hidden_width,
-            out_features=config.embedding.n_pieces,
-            bias=False,
-            device=device,
+
+        # Once we have proper support for tied weights, we will do something like:
+        #
+        # self.output_embeddings = Linear(
+        #     in_features=config.layer.feedforward.hidden_width,
+        #     out_features=config.embedding.n_pieces,
+        #     bias=False,
+        #     device=device,
+        # )
+        # self.output_embeddings.weights = self.decoder.embeddings.piece_embeddings.weights
+        #
+        # For now we'll work around this by using the piece embeddings directly.
+
+    def forward(
+        self,
+        piece_ids: Tensor,
+        attention_mask: AttentionMask,
+        cache: Optional[List[KeyValueCache]] = None,
+        positions: Optional[Tensor] = None,
+        store_cache: bool = False,
+    ) -> CausalLMOutputWithCache[KeyValueCache]:
+        # TODO: remove this forward method once we support weight tying.
+
+        decoder_output = self.decoder(
+            piece_ids,
+            attention_mask,
+            cache=cache,
+            store_cache=store_cache,
+            positions=positions,
+        )
+
+        assert isinstance(self.decoder.embeddings.piece_embeddings, Embedding)
+        output_embeddings = self.decoder.embeddings.piece_embeddings.weight
+
+        logits = F.linear(decoder_output.last_hidden_layer_state, output_embeddings)
+        return CausalLMOutputWithCache(
+            all_outputs=decoder_output.all_outputs,
+            cache=decoder_output.cache,
+            logits=logits,
         )
 
     @classmethod

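The workaround above works because projecting hidden states through the transposed embedding matrix yields exactly the logits a tied output layer would produce. A small, self-contained illustration with made-up sizes:

    import torch
    import torch.nn.functional as F

    hidden_width, n_pieces = 8, 32                    # made-up sizes
    embeddings = torch.randn(n_pieces, hidden_width)  # piece embedding matrix E
    hidden = torch.randn(2, 5, hidden_width)          # decoder output: (batch, seq, width)

    # F.linear(h, E) computes h @ E.T, i.e. the same logits as a bias-free
    # Linear layer whose weight is tied to E.
    logits = F.linear(hidden, embeddings)
    assert logits.shape == (2, 5, n_pieces)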
curated_transformers/models/transformer.py

Lines changed: 6 additions & 10 deletions

@@ -92,16 +92,12 @@ def forward(
             store_cache=store_cache,
             positions=positions,
         )
-        if torch.jit.is_tracing():
-            logits = self.output_embeddings(decoder_output[0][-1])
-            return decoder_output + (logits,)  # type: ignore[return-value]
-        else:
-            logits = self.output_embeddings(decoder_output.last_hidden_layer_state)
-            return CausalLMOutputWithCache(
-                all_outputs=decoder_output.all_outputs,
-                cache=decoder_output.cache,
-                logits=logits,
-            )
+        logits = self.output_embeddings(decoder_output.last_hidden_layer_state)
+        return CausalLMOutputWithCache(
+            all_outputs=decoder_output.all_outputs,
+            cache=decoder_output.cache,
+            logits=logits,
+        )
 
 
 class TransformerEncoder(EncoderModule):

curated_transformers/tests/generation/test_auto_generator.py

Lines changed: 18 additions & 2 deletions

@@ -1,17 +1,33 @@
 import pytest
 
 from curated_transformers.generation import AutoGenerator
-from curated_transformers.generation.default_generator import DefaultGenerator
 from curated_transformers.generation.dolly_v2 import DollyV2Generator
 from curated_transformers.generation.falcon import FalconGenerator
+from curated_transformers.generation.llama import LlamaGenerator
+from curated_transformers.generation.mpt import MPTGenerator
 
 
 @pytest.mark.slow
 def test_auto_generator():
     model_causallm_map = {
         "databricks/dolly-v2-3b": DollyV2Generator,
         "tiiuae/falcon-7b": FalconGenerator,
-        "openlm-research/open_llama_3b": DefaultGenerator,
+        "openlm-research/open_llama_3b": LlamaGenerator,
+    }
+
+    for name, generator_cls in model_causallm_map.items():
+        generator = AutoGenerator.from_hf_hub(name=name)
+        assert isinstance(generator, generator_cls)
+
+    with pytest.raises(ValueError, match="Unsupported generator"):
+        AutoGenerator.from_hf_hub(name="trl-internal-testing/tiny-random-GPT2Model")
+
+
+@pytest.mark.hf_head
+@pytest.mark.slow
+def test_auto_generator_hf_head():
+    model_causallm_map = {
+        "mosaicml/mpt-7b": MPTGenerator,
     }
 
     for name, generator_cls in model_causallm_map.items():

docs/source/generation.rst

Lines changed: 5 additions & 0 deletions

@@ -51,6 +51,11 @@ These classes provide the interface for performing text generation using causal
    :special-members: __call__
    :show-inheritance:
 
+.. autoclass:: curated_transformers.generation.MPTGenerator
+   :members:
+   :inherited-members:
+   :special-members: __call__
+   :show-inheritance:
 
 Downloading
 -----------

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ Generator wrappers:
 - Dolly v2
 - Falcon
 - Llama 1/2
+- MPT
 
 All types of models can be loaded from Hugging Face Hub.
