File tree Expand file tree Collapse file tree 1 file changed +11
-8
lines changed Expand file tree Collapse file tree 1 file changed +11
-8
lines changed Original file line number Diff line number Diff line change @@ -619,27 +619,30 @@ def weak_ref_tensors(
619
619
raise ValueError ("Invalid type for tensors" )
620
620
621
621
622
- def npu_stream_switch (target_stream : torch .npu .Stream , * , enabled : bool = True ):
622
+ def npu_stream_switch (target_stream : torch .npu .Stream ,
623
+ * ,
624
+ enabled : bool = True ):
623
625
"""
624
626
Switch to the target stream if enabled is True.
625
627
Otherwise, do nothing.
626
628
"""
627
629
if not enabled :
628
630
return nullcontext ()
631
+ assert target_stream is not None
629
632
return torch .npu .stream (target_stream )
630
633
631
634
632
- def npu_wait_stream (
633
- current_stream : torch .npu .Stream ,
634
- target_stream : torch .npu .Stream ,
635
- * ,
636
- enabled : bool = True
637
- ):
635
+ def npu_wait_stream (current_stream : torch .npu .Stream ,
636
+ target_stream : torch .npu .Stream ,
637
+ * ,
638
+ enabled : bool = True ):
638
639
"""
639
640
Make current stream wait for the target stream if enabled is True.
640
641
This operation will launch a record event on the target stream,
641
642
and launch a wait event on current stream, waitint for the record event.
642
643
Otherwise, do nothing.
643
644
"""
644
645
if enabled :
645
- current_stream .wait_stream (target_stream )
646
+ assert current_stream is not None
647
+ assert target_stream is not None
648
+ current_stream .wait_stream (target_stream )
You can’t perform that action at this time.
0 commit comments