Skip to content

Commit da541e6

Browse files
committed: "clean up ruff check"
Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
1 parent ff97f53 commit da541e6

File tree

7 files changed

+16
-8
lines changed

7 files changed

+16
-8
lines changed

aiu_fms_testing_utils/utils/paged.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
55
import torch
6+
import fms.utils.spyre.paged # noqa
67

78

89
def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):

scripts/generate_metrics.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,11 @@ def find_eos_index(reference_tokens, eos_token_id):
175175
return result
176176

177177

178-
def filter_before_eos(level_metric, filter_indexes):
178+
def filter_before_eos(metrics, filter_indexes):
179179
from itertools import groupby
180180

181181
filtered_results = [
182-
list(g)[: filter_indexes[k]]
183-
for k, g in groupby(level_metric, key=lambda x: x[0])
182+
list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
184183
]
185184
return [item for sublist in filtered_results for item in sublist]
186185

@@ -197,10 +196,10 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
197196
return input_ids, padding_kwargs
198197

199198

200-
def write_csv(metric, path, metric_name):
199+
def write_csv(metrics, path, metric_name):
201200
with open(path, "w") as f:
202201
f.write(f"{metric_name}\n")
203-
for t in metric:
202+
for t in metrics:
204203
f.write(f"{t[2].item()}\n")
205204
f.close()
206205

scripts/inference.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,15 @@
257257
if args.quantization == "gptq":
258258
if "aiu" in args.device_type:
259259
try:
260+
from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa
261+
260262
print("Loaded `aiu_addons` functionalities")
261263
except ImportError:
262264
raise ImportError("Failed to import GPTQ addons from fms-mo.")
263265
elif args.quantization == "int8":
264266
try:
267+
from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear # noqa
268+
265269
print("Loaded `aiu_addons` functionalities")
266270
except ImportError:
267271
raise ImportError("Failed to import INT8 addons from fms-mo.")
@@ -301,6 +305,8 @@
301305
device = torch.device(args.device_type, local_rank)
302306
torch.cuda.set_device(device)
303307
elif is_aiu_backend:
308+
from torch_sendnn import torch_sendnn # noqa
309+
304310
if not args.distributed:
305311
aiu_setup.aiu_setup(rank, world_size)
306312

scripts/roberta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from fms.models.hf import to_hf_api
3030

3131
# Import AIU Libraries
32+
from torch_sendnn import torch_sendnn # noqa
3233

3334
# ==============================================================
3435
# Main

scripts/small-toy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from fms.utils.tp_wrapping import apply_tp
1717

1818
# Import AIU Libraries
19+
from torch_sendnn import torch_sendnn # noqa
1920

2021

2122
# ==============================================================

scripts/validation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@
294294
aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
295295

296296
# Always initialize AIU in this script
297+
from torch_sendnn import torch_sendnn # noqa
297298

298299
if not args.distributed:
299300
aiu_setup.aiu_setup(rank, world_size)

tests/models/test_decoders.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,11 @@ def __find_eos_index(reference_tokens, eos_token_id, seq_length, max_new_tokens)
286286
return result
287287

288288

289-
def __filter_before_eos(level_metric, filter_indexes):
289+
def __filter_before_eos(metrics, filter_indexes):
290290
from itertools import groupby
291291

292292
filtered_results = [
293-
list(g)[: filter_indexes[k]]
294-
for k, g in groupby(level_metric, key=lambda x: x[0])
293+
list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
295294
]
296295
return [item for sublist in filtered_results for item in sublist]
297296

0 commit comments

Comments (0)