@@ -22,7 +22,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn, Tensor, tensor, is_tensor, stack
+from torch import nn, Tensor, tensor, is_tensor, cat, stack
 from torch.nn import Module, ModuleList, Linear
 
 from torch.utils.data import Dataset, DataLoader
@@ -808,8 +808,8 @@ def forward(
 
         if exists(cache):
             cached_k, cached_v = cache
-            k = torch.cat((cached_k, k), dim = -2)
-            v = torch.cat((cached_v, v), dim = -2)
+            k = cat((cached_k, k), dim = -2)
+            v = cat((cached_v, v), dim = -2)
 
         # maybe kv cache
 
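The hunk above is the kv-cache path: cached keys and values are concatenated with the freshly projected ones along dim = -2, the sequence axis. A minimal standalone sketch of that pattern follows; the (batch, heads, seq, dim_head) shapes are illustrative assumptions, not taken from this repository.

    # Sketch of the kv-cache concatenation pattern (shapes are assumptions):
    # keys/values are (batch, heads, seq, dim_head), so dim = -2 appends the
    # new token's key/value along the sequence axis.
    import torch
    from torch import cat

    cached_k = torch.randn(1, 8, 10, 64)   # keys for the 10 tokens already decoded
    cached_v = torch.randn(1, 8, 10, 64)
    k_new = torch.randn(1, 8, 1, 64)       # key for the single new token
    v_new = torch.randn(1, 8, 1, 64)

    k = cat((cached_k, k_new), dim = -2)   # (1, 8, 11, 64)
    v = cat((cached_v, v_new), dim = -2)
    assert k.shape == (1, 8, 11, 64) and v.shape == (1, 8, 11, 64)
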
@@ -1030,7 +1030,7 @@ def forward(
             skip = skips.pop()
 
             residual = x
-            x = torch.cat((x, skip), dim = -1)
+            x = cat((x, skip), dim = -1)
             x = skip_proj(x) + residual
 
             # attention and feedforward
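The hunk above is the skip-connection step: the popped skip tensor is concatenated with the current stream along the feature dimension (dim = -1), projected back to model width by skip_proj, and added to the residual. A small sketch under assumed dimensions (the 512-wide model and bias-free Linear are illustrative choices, not the repository's configuration):

    # Sketch of the skip-connection step (dimensions are assumptions):
    # concatenating along dim = -1 doubles the feature width, and skip_proj
    # maps it back so it can be added to the residual.
    import torch
    from torch import cat
    from torch.nn import Linear

    dim = 512
    skip_proj = Linear(dim * 2, dim, bias = False)

    x = torch.randn(2, 1024, dim)      # current token stream
    skip = torch.randn(2, 1024, dim)   # popped from the skip stack

    residual = x
    x = cat((x, skip), dim = -1)       # (2, 1024, dim * 2)
    x = skip_proj(x) + residual        # back to (2, 1024, dim)
    assert x.shape == (2, 1024, dim)
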
@@ -1695,7 +1695,7 @@ def generate_text_only(
 
             sample = gumbel_sample(logits, temperature = temperature, dim = -1)
 
-            out = torch.cat((out, sample), dim = -1)
+            out = cat((out, sample), dim = -1)
 
         return out[..., prompt_seq_len:]
 
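The hunk above is the text-only sampling loop: each newly sampled token id is appended to out along the last dimension, and the prompt is sliced off at the end with out[..., prompt_seq_len:]. A self-contained sketch of that loop, using torch.multinomial as a stand-in for the repository's gumbel_sample helper and an invented toy vocabulary size:

    # Standalone sketch of the decode loop (the tiny "model" and vocab size are
    # assumptions; torch.multinomial stands in for gumbel_sample).
    import torch
    from torch import cat

    vocab, prompt_seq_len, steps = 256, 4, 8
    out = torch.randint(0, vocab, (1, prompt_seq_len))          # prompt token ids

    for _ in range(steps):
        logits = torch.randn(1, vocab)                          # stand-in for model(out)[:, -1]
        sample = torch.multinomial(logits.softmax(dim = -1), 1) # (1, 1) sampled id
        out = cat((out, sample), dim = -1)                      # append along the sequence

    generated = out[..., prompt_seq_len:]                       # drop the prompt
    assert generated.shape == (1, steps)
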
@@ -2147,7 +2147,7 @@ def forward(
         precede_modality_tokens = len(modality_meta_info) + 2
         succeed_modality_tokens = 1
 
-        text_tensor = torch.cat((
+        text_tensor = cat((
             tensor_([self.meta_id]),
             modality_meta_info,
             tensor_([som_id]),
@@ -2200,12 +2200,12 @@ def inner(embed: Float['b n d'], need_splice = True) -> Float['...']:
 
                 batch_modality_pos_emb.append(pos_emb)
 
-            text.append(torch.cat(batch_text))
+            text.append(cat(batch_text))
 
             if need_axial_pos_emb:
-                modality_pos_emb.append(torch.cat(batch_modality_pos_emb, dim = -2))
+                modality_pos_emb.append(cat(batch_modality_pos_emb, dim = -2))
 
-            modality_tokens.append(torch.cat(batch_modality_tokens))
+            modality_tokens.append(cat(batch_modality_tokens))
             modality_positions.append(batch_modality_positions)
 
             modality_index += 1