
Commit 47d50e1

Improve default arguments for clients and servers (#530)
This PR updates multiple default arguments in clients and servers:

1. **The client defaults to `torch_dtype=torch.float32` instead of `torch_dtype="auto"`.** The old default was to load weights in the dtype they were saved in (usually bfloat16/float16), which caused issues when the client was run on a CPU (the default unless you call `.cuda()`). Specifically, bfloat16 is slow on most CPUs (unless the CPU supports AVX512), and float16 can't be run natively and leads to an exception. The old default was a legacy of the earliest Petals versions designed to run BLOOM, whose embeddings were so large that they didn't fit into RAM in float32 (e.g., in Colab). Newer models don't have this issue. The new default gives good speed on all CPUs and is consistent with PyTorch and HF Transformers. Also, the client now shows the "bfloat16 on non-AVX512 CPU" warning in all cases (previously, the warning was shown only if the machine had enough RAM to fit float32 weights, which could hide the crucial reason inference was slow). **Note:** This change is backward-incompatible, so we have to increase at least the minor package version (2.2.0 -> 2.3.0.dev0).

2. **The server uses a 2x smaller `--attn_cache_tokens` default.** The old default led to loading 39 (out of 80) or 78 (out of 80) blocks for popular models on some GPU types, which visibly slowed down inference due to an extra network hop. It also reserved too much cache, so inference slowed down well before the cache was actually filled. The new default leads to more efficient block layouts and makes the inference routing algorithm choose alternative paths through other servers when a particular server already has enough active inference sessions (i.e., its cache is full).

3. **The client's maximum number of retries can be limited by the `PETALS_MAX_RETRIES` env var.** This is used to limit `ClientConfig.max_retries` in tests, so we see tracebacks instead of retrying indefinitely in case of errors.
1 parent ae19b65 commit 47d50e1
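As a usage sketch of the first change (not part of the commit): with the new default, clients load weights in float32 unless a dtype is passed explicitly. The entry point and model name below are assumptions used only for illustration.

```python
import torch
from petals import AutoDistributedModelForCausalLM  # assumed standard Petals client entry point

# With this change, omitting torch_dtype loads weights in float32 (fast and safe on any CPU).
model = AutoDistributedModelForCausalLM.from_pretrained("petals-team/StableBeluga2")

# On a GPU, half precision can still be requested explicitly:
model = AutoDistributedModelForCausalLM.from_pretrained(
    "petals-team/StableBeluga2", torch_dtype=torch.bfloat16
).cuda()
```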

File tree: 7 files changed (+19, -22 lines)

.github/workflows/run-tests.yaml

Lines changed: 3 additions & 0 deletions
@@ -102,6 +102,9 @@ jobs:
           export no_proxy=*
           export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
 
+          # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
+          export PETALS_MAX_RETRIES=10
+
           pytest tests --durations=0 --durations-min=1.0 -v
 
           # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)

src/petals/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.2.0"
+__version__ = "2.3.0.dev0"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

src/petals/cli/run_server.py

Lines changed: 3 additions & 3 deletions
@@ -70,17 +70,17 @@ def main():
 
     parser.add_argument('--inference_max_length', type=int, default=None,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all operations (in total tokens)')
     parser.add_argument('--max_batch_size', type=int, default=None,
                         help='The total number of tokens in the same batch will not exceed this value. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
                         help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
     parser.add_argument('--attn_cache_tokens', type=int, default=None,
                         help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
 
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
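For convenience, the changed help strings can be summarized as below; this is a sketch restating the strings above, not code from the commit. Only the `--attn_cache_tokens` default actually changes, the other two lines are rewordings.

```python
# New server defaults per model family, restated from the updated help strings (in tokens).
# Only attn_cache_tokens was halved; inference_max_length and max_batch_size are unchanged.
NEW_DEFAULTS = {
    "multi-query attention (Llama 2, Falcon)": {
        "inference_max_length": 8192,
        "max_batch_size": 8192,
        "attn_cache_tokens": 16384,  # was 32768
    },
    "other models": {
        "inference_max_length": 2048,
        "max_batch_size": 2048,
        "attn_cache_tokens": 4096,  # was 8192
    },
}
```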

src/petals/client/config.py

Lines changed: 5 additions & 1 deletion
@@ -1,10 +1,14 @@
 import dataclasses
+import os
 from typing import Optional, Sequence, Union
 
 from hivemind import PeerID
 
 from petals.constants import PUBLIC_INITIAL_PEERS
 
+_max_retries = os.getenv("PETALS_MAX_RETRIES")
+DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None
+
 
 @dataclasses.dataclass
 class ClientConfig:
@@ -21,7 +25,7 @@ class ClientConfig:
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
     update_period: float = 60  # refresh DHT information once in this many seconds
 
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
     min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
     max_backoff: float = 60  # limit maximal sleep time between retries to this value
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
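A quick check of the new default's behavior (a sketch, not code from the commit): the env var is read once when `petals.client.config` is imported, so it must be set before importing Petals.

```python
import os

# Set before importing petals, since DEFAULT_MAX_RETRIES is computed at import time.
os.environ["PETALS_MAX_RETRIES"] = "10"

# Same parsing logic as in the diff above:
_max_retries = os.getenv("PETALS_MAX_RETRIES")
DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None

assert DEFAULT_MAX_RETRIES == 10  # with the env var unset, this stays None (retry indefinitely)
```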

src/petals/client/from_pretrained.py

Lines changed: 1 addition & 9 deletions
@@ -6,7 +6,6 @@
 from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union
 
-import torch
 from hivemind.utils.logging import get_logger
 from transformers import BloomPreTrainedModel, modeling_utils
 
@@ -22,21 +21,14 @@ def from_pretrained(
         model_name_or_path: Union[str, os.PathLike, None],
         *args,
         low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
         **kwargs,
     ):
         model_name_or_path = get_compatible_model_repo(model_name_or_path)
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"
 
         with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
-            return super().from_pretrained(
-                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
-            )
+            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
 
     from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
         "low_cpu_mem_usage(`bool`, *optional*)",

src/petals/client/lm_head.py

Lines changed: 5 additions & 7 deletions
@@ -1,8 +1,7 @@
 import dataclasses
 import platform
-from typing import Optional, Union
+from typing import Union
 
-import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -68,11 +67,10 @@ def chunked_forward(self, hidden_states):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
 
         if not self._bf16_warning_shown:
-            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
-                logger.warning(
-                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
-                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
-                )
+            logger.warning(
+                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
+            )
             self._bf16_warning_shown = True
 
         hidden_states = hidden_states.float()

src/petals/server/server.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def __init__(
 
         # For attention cache in GPU or RAM
        if attn_cache_tokens is None:
-            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+            attn_cache_tokens = 16384 if is_multiquery_attn else 4096
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
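To see how halving `attn_cache_tokens` translates into memory, here is a back-of-the-envelope sketch of the formula above, using Llama-2-70B-like values as assumptions (hidden_size=8192, 64 attention heads, 8 key/value heads, bfloat16 cache):

```python
# Rough per-block attention cache size, mirroring the computation in server.py above.
hidden_size = 8192
num_key_value_groups = 64 // 8  # attention heads per key/value head (grouped-query attention)
dtype_bytes = 2                 # bfloat16

def cache_bytes_per_block(attn_cache_tokens: int) -> int:
    cache_values_per_block = 2 * hidden_size * attn_cache_tokens  # keys + values
    cache_values_per_block //= num_key_value_groups               # KV heads are shared
    return cache_values_per_block * dtype_bytes

print(cache_bytes_per_block(32768) // 2**20, "MiB")  # old default: ~128 MiB per block
print(cache_bytes_per_block(16384) // 2**20, "MiB")  # new default: ~64 MiB per block
```

Under these assumed values, the per-block cache drops from roughly 128 MiB to 64 MiB, which is what lets a server of the same size host more blocks.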
