File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 6969LLaDA
7070ep_execution_plan.md
7171parallel_state_redesign.md
72+ core. *
Original file line number Diff line number Diff line change @@ -54,16 +54,16 @@ def forward(
5454 mask : list [torch .Tensor ] | None = None ,
5555 ) -> torch .Tensor :
5656 if q .dim () == 2 :
57- q = rearrange (q , "s (nh hd) -> s nh hd" , ** self .q_shape ). contiguous ()
57+ q = rearrange (q , "s (nh hd) -> s nh hd" , ** self .q_shape )
5858 elif q .dim () == 3 :
59- q = q . contiguous ()
59+ q = q
6060 else :
6161 raise ValueError (f"Unsupported q ndim for Attention: { q .dim ()} " )
6262
6363 if k .dim () == 2 :
64- k = rearrange (k , "s (nkvh hd) -> s nkvh hd" , ** self .kv_shape ). contiguous ()
64+ k = rearrange (k , "s (nkvh hd) -> s nkvh hd" , ** self .kv_shape )
6565 elif k .dim () == 3 :
66- k = k . contiguous ()
66+ k = k
6767 else :
6868 raise ValueError (f"Unsupported k ndim for Attention: { k .dim ()} " )
6969
You can’t perform that action at this time.
0 commit comments