Options for Stagger model loading for low memory systems #47
Conversation
Keeping this in Draft for now while a few teams get back to me on testing.
Force-pushed from 1af2202 to c5218b1.
I checked out this PR and am testing it with the latest FMS levels. Reasons for trying this:
For testing purposes: what are the differences between the two options?
We can improve the help text, but this is what it reports now:

So there are two different options to stagger processing of two different sections of the compile phase, in case you need to set different values for each part. The default value is `0`, which disables the staggering. Setting the value to `1` lets only one rank at a time run that section, and setting it to `N` lets `N` ranks run it at a time (a value equal to the world size is effectively the same as the default). If you are only worried about functional testing in a memory-constrained environment, and not about compile-time efficiency, then setting this value to `1` is the safest choice.
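To make those value semantics concrete, here is a small sketch (not code from the PR) of how the limit determines the number of "sets" of ranks, using the same `math.ceil(world_size / float(limit))` expression that appears in the PR's debug output; the world size of 16 is just an example:

```python
# Sketch only: how the stagger limit maps to "sets" of ranks that take turns.
import math

world_size = 16  # hypothetical number of ranks
for limit in (1, 4, 16):
    num_sets = math.ceil(world_size / float(limit))
    print(f"limit={limit}: {num_sets} set(s) of up to {limit} rank(s) at a time")

# limit=1:  16 sets -> fully serialized, one rank at a time
# limit=4:   4 sets -> four ranks at a time
# limit=16:  1 set  -> equal to world_size, effectively no staggering
```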
@jjhursey, thanks! This is very helpful.
Force-pushed from c5218b1 to 2fd81b0.
Both the x86 and Power test teams have confirmed that this addresses their needs for testing large models on low-memory systems.
@JRosenkranz @ani300 this is ready for review.
```python
extra_kwargs = {**padding_kwargs, "only_last_token": True}
max_new_tokens_warmup = max_new_tokens
if compile_dynamic_sendnn:
    max_new_tokens_warmup = 2

if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
```
It looks like this logic is called multiple times. Do you think it would make sense to put it in its own utility function? That way it can be re-used in the future.
Yeah, I can do that. I'll have an "enter" and "exit" version to place in the code.
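For illustration, here is a minimal sketch of what such paired helpers might look like, assuming ranks are grouped into consecutive sets of `limit` and serialized with `torch.distributed.barrier()`. The names follow the `stagger_enter`/`stagger_leave` excerpts later in this thread, but the bodies are an assumption, not the PR's actual implementation:

```python
import math

import torch.distributed as dist


def stagger_enter(limit: int):
    """Hold this rank at barriers until it is its set's turn to run the guarded section (sketch)."""
    world_size = dist.get_world_size()
    if limit <= 0 or limit == world_size:
        return  # staggering disabled, or every rank may proceed at once
    my_set = dist.get_rank() // limit  # assumption: consecutive ranks share a set
    for _ in range(my_set):
        dist.barrier()  # each completed barrier releases the next set


def stagger_leave(limit: int):
    """Participate in the remaining barriers so later sets can take their turn (sketch)."""
    world_size = dist.get_world_size()
    if limit <= 0 or limit == world_size:
        return
    num_sets = math.ceil(world_size / float(limit))
    my_set = dist.get_rank() // limit
    for _ in range(my_set, num_sets):
        dist.barrier()
```

Every rank passes through exactly `ceil(world_size / limit)` barriers in total, so the collectives stay matched while only one set at a time sits between enter and leave.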
I just pushed a commit for this change
Has this been tested with inference.py as well as test_decoders (multi-AIU / single AIU)? In theory, those should not change. Also, it might make sense to add an option for this in test_decoders for low-memory systems.
bot:test
Force-pushed from 5aaf975 to 14dfb4a.
I pushed a commit to consolidate the staggered enter/exit. I also rebased on main.
bot:test
Force-pushed from 14dfb4a to 130a407.
I pushed an update that adds the docstrings and fixes a DCO check.
```python
        torch.distributed.barrier()
        dprint(f"Stagger: Enter (Set: {_set+1} of {math.ceil(world_size / float(limit))})")


def stagger_leave(limit: int):
```
Would it make sense for this to be a `stagger_context` (given each `stagger_enter` needs to be paired with a `stagger_leave`)?

```python
with stagger_context(limit):
    model = get_model(...)
```
Yeah, that would be a good idea. I've not built a context like that in Python before, but I'm always willing to learn. I'll take a pass at it in the next couple of days.
I added a commit to convert this to a contextlib function. Take a look and let me know what you think.
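For reference, a minimal sketch of what a `contextlib`-based wrapper could look like, assuming the `stagger_enter`/`stagger_leave` helpers shown above; the name `stagger_region` is illustrative, and the commit's actual code may differ:

```python
import contextlib


@contextlib.contextmanager
def stagger_region(limit: int):
    """Pair stagger_enter with stagger_leave around the wrapped block (sketch)."""
    stagger_enter(limit)
    try:
        yield
    finally:
        # Participate in the remaining barriers even if the wrapped block raises.
        stagger_leave(limit)


# Usage, as suggested in the review above:
# with stagger_region(args.stagger_load):
#     model = get_model(...)
```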
* `--stagger_load` : (default: `0` off) Stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` : (default: `0` off) Stagger update_lazyhandle to avoid OOM issues on the host
* `--dist_timeout` : (default: either `10` for NCCL or `30` for others, set by PyTorch) torch distributed timeout in minutes

Signed-off-by: Joshua Hursey <jhursey@us.ibm.com>
Force-pushed from f60d30f to b7c22e0.
I recently updated my foundation-model-stack repo and the rest of the stack, and now I'm seeing a hang during the warmup. I'm not sure what part of the stack is causing that to break. It's not caused by this PR, but by a new synchronization in the stack that the PR's staggering now encloses in the warmup.
* `--stagger_load` : (default: `0` off) Stagger model loading to avoid OOM issues on the host
* `--stagger_update_lazyhandle` : (default: `0` off) Stagger update_lazyhandle to avoid OOM issues on the host
* `--dist_timeout` : (default: either `10` for NCCL or `30` for others, set by PyTorch) torch distributed timeout in minutes
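As a sketch of how a `--dist_timeout` value in minutes could be wired into process-group setup using PyTorch's standard `timeout` argument; the function name, argument names, and backend choice here are illustrative, not necessarily what inference.py does:

```python
import datetime

import torch.distributed as dist


def init_distributed(backend: str = "gloo", dist_timeout: int = 0):
    """Initialize torch.distributed, optionally overriding the collective timeout (sketch)."""
    kwargs = {}
    if dist_timeout > 0:
        # PyTorch's own defaults are 10 minutes for NCCL and 30 minutes for other backends.
        kwargs["timeout"] = datetime.timedelta(minutes=dist_timeout)
    dist.init_process_group(backend=backend, **kwargs)
```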