-
Notifications
You must be signed in to change notification settings · Fork 16
Refactor inference.py for LLM and RoBERTa support #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
6e5c0d2
5c0d9ec
d7730c8
819b147
13c0917
5895831
8b1d37e
238b05d
38fef7a
3b94a36
effb27b
600ba67
031abde
8657a11
0d042c6
3f13729
4e731d8
fd70377
7a5c9df
3ad7050
92d05ef
011ec33
5292128
17be9a7
6780770
0437c50
dfd6758
f7c458e
3434641
a543198
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
import argparse | ||
import os | ||
import torch | ||
|
||
# ============================================================== | ||
# Common utilities | ||
|
@@ -57,6 +59,8 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False): | |
def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False): | ||
if local_rank < 0: | ||
local_rank = rank | ||
|
||
# FIXME: local_size not in use ? | ||
if local_size < 0: | ||
local_size = world_size | ||
|
||
|
@@ -67,3 +71,60 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False | |
dprint(f"Detected running via torchrun") | ||
|
||
aiu_setup(rank, world_size) | ||
|
||
|
||
# ============================================================== | ||
# Environment variables utilities | ||
# ============================================================== | ||
def set_aiu_env_vars(args: argparse.Namespace) -> None: | ||
"""Set necessary environment variables for AIU""" | ||
|
||
_target_cache_size = max( | ||
int(args.max_new_tokens * 2), | ||
int(args.min_pad_length * 2.5), | ||
int(args.fixed_prompt_length * 2.5), | ||
) | ||
_prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length)) | ||
if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"): | ||
if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit: | ||
_prev = torch._dynamo.config.accumulated_cache_size_limit | ||
torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size | ||
dprint( | ||
"NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit " | ||
f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} " | ||
f"to accomodate prompt size of {_prompt_size} and decode tokens of " | ||
f"{args.max_new_tokens}" | ||
) | ||
|
||
if _target_cache_size > torch._dynamo.config.cache_size_limit: | ||
_prev = torch._dynamo.config.cache_size_limit | ||
torch._dynamo.config.cache_size_limit = _target_cache_size | ||
dprint( | ||
f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to " | ||
f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of " | ||
f"{_prompt_size} and decode tokens of {args.max_new_tokens}" | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is only needed if compile_dynamic is disabled, can we gate it? |
||
|
||
if not args.compile_dynamic: | ||
torch._dynamo.config.assume_static_by_default = True | ||
torch._dynamo.config.dynamic_shapes = False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is now deprecated in pytorch, only assume and automatic are needed |
||
torch._dynamo.config.automatic_dynamic_shapes = False | ||
|
||
# This should be set outside!!! | ||
os.environ.setdefault("SENCORES", "32") | ||
os.environ.setdefault("SENCORELETS", "2") | ||
os.environ.setdefault("DATA_PREC", "fp16") | ||
os.environ.setdefault("FLEX_OVERWRITE_NMB_FRAME", "1") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think some of these are already set by default on the e2e_stable image, can we check and remove the ones we don't need anymore? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. confirmed these are no longer needed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is |
||
os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") | ||
|
||
os.environ.setdefault("COMPILATION_MODE", "offline_decoder") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this one is only needed for decoder models, for roberta it will probably make it not work There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. confirmed it will not work for roberta, we need to set it depending on the kind of model |
||
|
||
if args.device_type == "aiu-senulator": | ||
os.environ["FLEX_COMPUTE"] = "SENULATOR" | ||
os.environ["FLEX_DEVICE"] = "MOCK" | ||
else: | ||
if "AIU_WORLD_RANK_0" not in os.environ: | ||
print("must set AIU_WORLD_RANK_0") | ||
exit() | ||
os.environ.setdefault("FLEX_COMPUTE", "SENTIENT") | ||
os.environ.setdefault("FLEX_DEVICE", "VFIO") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think VFIO is now PF or VF, depending on the AIU setup, we can set PF by default as most cards are running in PF mode There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. confirmed it is now PF for all clusters we have access to and will eventually be VF There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I set it to |
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer: maybe it's no longer needed? If you can't find any reference to it, feel free to delete.