
Commit dd4a323

borzunov and mryab authored
Add Falcon support (#499)
This PR adds:

- Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B.
- CI tests for Falcon-RW-1B.
- A `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab).

Limitations:

- Backward pass support is broken for now; it will be fixed in #500.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
1 parent b4d822a commit dd4a323

File tree

9 files changed (+356, -15 lines)

.github/workflows/run-tests.yaml

Lines changed: 6 additions & 1 deletion
@@ -14,11 +14,13 @@ jobs:
           - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
     runs-on: ${{ matrix.os }}-latest
-    timeout-minutes: 15
+    timeout-minutes: 20
     steps:
       - name: Increase swap space
         if: ${{ matrix.os == 'ubuntu' }}
@@ -93,6 +95,9 @@ jobs:
 
           # [Step 2] Run PyTest
 
+          # Share disk cache between Petals servers, clients, and HF Transformers
+          export TRANSFORMERS_CACHE=~/.cache/petals
+
           # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
           export no_proxy=*
           export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
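For context, the same shared cache location can also be set from Python before importing Transformers. A minimal sketch, not part of this commit, assuming the standard TRANSFORMERS_CACHE environment variable honored by HF Transformers:

import os

# Point HF Transformers (and Petals, which builds on it) at the shared cache directory
# used in the CI workflow above. Set this before importing `transformers`.
os.environ["TRANSFORMERS_CACHE"] = os.path.expanduser("~/.cache/petals")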

src/petals/cli/run_server.py

Lines changed: 3 additions & 2 deletions
@@ -106,12 +106,13 @@ def main():
                         "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
-                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                         default='auto',
                         help='Expected server throughput (a float measured in RPS). '
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
-                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
+                             'If set to "dry_run", the script re-evaluates the throughput and exits.')
     parser.add_argument('--update_period', type=float, required=False, default=120,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
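For reference, a minimal standalone sketch (not from this commit) of how the `--throughput` value is parsed: the three keywords pass through as strings, anything else must convert to a float.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--throughput',
                    type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                    default='auto')

print(parser.parse_args(['--throughput', 'dry_run']).throughput)  # -> dry_run (kept as a string)
print(parser.parse_args(['--throughput', '123.4']).throughput)    # -> 123.4 (converted to float)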

src/petals/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 from petals.models.bloom import *
+from petals.models.falcon import *
 from petals.models.llama import *

src/petals/models/falcon/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+from petals.models.falcon.block import WrappedFalconBlock
+from petals.models.falcon.config import DistributedFalconConfig
+from petals.models.falcon.model import (
+    DistributedFalconForCausalLM,
+    DistributedFalconForSequenceClassification,
+    DistributedFalconModel,
+)
+from petals.utils.auto_config import register_model_classes
+
+register_model_classes(
+    config=DistributedFalconConfig,
+    model=DistributedFalconModel,
+    model_for_causal_lm=DistributedFalconForCausalLM,
+    model_for_sequence_classification=DistributedFalconForSequenceClassification,
+)
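Registering the classes here is what lets the Petals auto classes resolve Falcon checkpoints to the distributed implementations. A minimal usage sketch, not part of this diff, assuming the public AutoDistributedModelForCausalLM entry point and a hypothetical repo name chosen only for illustration:

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # dispatches to DistributedFalconForCausalLM via the registry above

model_name = "tiiuae/falcon-40b"  # hypothetical choice, any supported Falcon repo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)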

src/petals/models/falcon/block.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+"""
+Falcon intermediate layer
+Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
+See commit history for authorship.
+"""
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class WrappedFalconBlock(FalconDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[KVCache] = None,
+        use_cache: bool = False,
+        **kwargs
+    ):
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        if layer_past is not None:
+            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
+        past_length = 0 if layer_past is None else layer_past[0].shape[1]
+        seq_length_with_past = seq_length + past_length
+
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None and self.config.alibi:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            layer_past=layer_past,
+            use_cache=use_cache,
+            **kwargs
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        key_states = key_states.permute(0, 2, 1)
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+
+        if self.config.new_decoder_architecture:
+            key_states = self._expand_states(key_states)
+            value_states = self._expand_states(value_states)
+
+        return (key_states, value_states)
+
+    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        if self.config.new_decoder_architecture:
+            key_states = self._collapse_states(key_states)
+            value_states = self._collapse_states(value_states)
+
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+        key_states = key_states.permute(0, 2, 1)
+
+        return (key_states, value_states)
+
+    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
+        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
+        return state
+
+    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
+        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
+        return state
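To make the cache-layout conversion concrete, here is a small self-contained sketch (not part of the commit) that mimics the expand/collapse round trip with plain tensors; the head counts are made up purely for illustration:

import torch

batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim = 2, 8, 16, 5, 64
num_attention_heads = num_kv_heads * num_key_value_groups

# Compact Falcon-style cache entry: one slot per KV head
state = torch.randn(batch_size * num_kv_heads, seq_len, head_dim)

# Expand: repeat each KV head for every query-head group (view + expand defers the copy until reshape)
expanded = (
    state.view(batch_size, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, num_key_value_groups, -1, -1)
    .reshape(batch_size * num_attention_heads, seq_len, head_dim)
)

# Collapse: keep one representative per group to return to the compact layout
collapsed = (
    expanded.view(batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim)[:, :, 0]
    .reshape(batch_size * num_kv_heads, seq_len, head_dim)
)

assert torch.equal(collapsed, state)  # round trip is lossless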

src/petals/models/falcon/config.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.falcon import FalconConfig
+from transformers.models.falcon.modeling_falcon import FalconAttention
+
+from petals.client.config import ClientConfig
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.models.falcon.block import WrappedFalconBlock
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedFalconBlock
+    attn_class = FalconAttention
+    block_prefix = "transformer.h"
+
+    @property
+    def num_key_value_groups(self) -> int:
+        if self.new_decoder_architecture:
+            return self.num_attention_heads // self.num_kv_heads
+        if self.multi_query:
+            return self.num_attention_heads
+        return 1
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            dht_prefix = str(model_name_or_path)
+            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
+            dht_prefix = dht_prefix.replace(".", "-")
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        if config.pad_token_id is None:
+            config.pad_token_id = 0
+        return result
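As a quick illustration of the num_key_value_groups property, a standalone sketch (not part of the commit; the head counts below are hypothetical, chosen only to show the arithmetic for the three Falcon attention variants):

def num_key_value_groups(new_decoder_architecture: bool, multi_query: bool,
                         num_attention_heads: int, num_kv_heads: int) -> int:
    # Mirrors DistributedFalconConfig.num_key_value_groups
    if new_decoder_architecture:
        return num_attention_heads // num_kv_heads  # grouped-query attention: several query heads per KV head
    if multi_query:
        return num_attention_heads  # one shared KV head for all query heads
    return 1  # classic multi-head attention: one KV head per query head

print(num_key_value_groups(True, False, num_attention_heads=128, num_kv_heads=8))   # 16
print(num_key_value_groups(False, True, num_attention_heads=71, num_kv_heads=1))    # 71
print(num_key_value_groups(False, False, num_attention_heads=32, num_kv_heads=32))  # 1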

src/petals/models/falcon/model.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+from typing import Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.falcon import (
+    FalconForCausalLM,
+    FalconForSequenceClassification,
+    FalconModel,
+    FalconPreTrainedModel,
+)
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.falcon.config import DistributedFalconConfig
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
+    """FalconModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.h) == 0
+        config.num_hidden_layers = n_layer
+
+        self.h = RemoteSequential(config, dht=dht)
+
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        hidden_states = self.h(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=RemotePastKeyValues(),
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
+        return nn.Identity()
+
+
+class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
+    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig):
+        FalconPreTrainedModel.__init__(self, config)
+        self.transformer = DistributedFalconModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+
+class DistributedFalconForSequenceClassification(
+    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
+):
+    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig):
+        FalconPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.transformer = DistributedFalconModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
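A minimal client-side sketch of the intended usage (not part of this commit; the prompt is made up, and the repo name reuses the CI model above for illustration):

from transformers import AutoTokenizer
from petals.models.falcon import DistributedFalconForCausalLM

model_name = "petals-team/falcon-rw-1b"  # the repo used in CI above; any supported Falcon repo should work
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedFalconForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A quick test prompt:", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)  # generate() is provided by RemoteGenerationMixin
print(tokenizer.decode(outputs[0]))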

src/petals/server/server.py

Lines changed: 9 additions & 7 deletions
@@ -5,6 +5,7 @@
 import multiprocessing as mp
 import os
 import random
+import sys
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
@@ -186,10 +187,7 @@ def __init__(
             check_device_balance(self.tensor_parallel_devices)
 
         if quant_type is None:
-            if device.type == "cuda":
-                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
-            else:
-                quant_type = QuantType.NONE
+            quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
@@ -234,8 +232,9 @@ def __init__(
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
+        assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
+        if throughput in ["auto", "eval", "dry_run"]:
+            force_eval = throughput in ["eval", "dry_run"]
             throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
@@ -245,9 +244,12 @@ def __init__(
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 reachable_via_relay=reachable_via_relay,
-                force_eval=(throughput == "eval"),
+                force_eval=force_eval,
                 cache_dir=cache_dir,
             )
+            if throughput == "dry_run":
+                logger.info("Finished estimating throughput, exiting")
+                sys.exit(0)
         else:
             throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
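Putting the CLI and server changes together, a stripped-down sketch (not the actual server code) of how the throughput setting is now handled; the estimator argument stands in for get_server_throughput():

import sys

def resolve_throughput(throughput, estimate):
    # `throughput` is either a float (RPS) or one of the keywords accepted by run_server.py
    assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
    if throughput in ["auto", "eval", "dry_run"]:
        force_eval = throughput in ["eval", "dry_run"]  # re-measure instead of reusing the cached value
        throughput_info = estimate(force_eval=force_eval)
        if throughput == "dry_run":
            print("Finished estimating throughput, exiting")
            sys.exit(0)
    else:
        throughput_info = {"throughput": throughput}
    return throughput_info

# Example with an explicit float: the estimator is never called
print(resolve_throughput(12.5, estimate=lambda force_eval: {"throughput": 0.0}))  # {'throughput': 12.5}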
