Commit e05b9c5

clean up ruff check
Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
1 parent 77e9ace commit e05b9c5

15 files changed: +47 / -57 lines changed
aiu_fms_testing_utils/testing/validation.py

Lines changed: 2 additions & 2 deletions

@@ -288,8 +288,8 @@ def extract_validation_information(
 
     if hasattr(post_iteration_hook, "extracted_logits"):
         validation_info = [
-            {"tokens": t.to("cpu"), "logits": l.to("cpu")}
-            for t, l in zip(
+            {"tokens": t.to("cpu"), "logits": logits.to("cpu")}
+            for t, logits in zip(
                 torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits)
             )
         ]
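
The rename from `l` to `logits` addresses Ruff's E741 check, which flags ambiguous single-character names such as `l`, `I`, and `O` because they are easily confused with the digits 1 and 0. A minimal standalone sketch of the fixed pattern (the tensors here are made-up stand-ins, not values from this repository):

    import torch

    tokens = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
    extracted_logits = [torch.tensor([0.1, 0.9]), torch.tensor([0.7, 0.3])]

    # A descriptive loop variable keeps E741 quiet and the comprehension readable.
    validation_info = [
        {"tokens": t.to("cpu"), "logits": logits.to("cpu")}
        for t, logits in zip(tokens, extracted_logits)
    ]
    print(validation_info)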

aiu_fms_testing_utils/utils/aiu_setup.py

Lines changed: 3 additions & 3 deletions

@@ -49,9 +49,9 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
     # ) # directory needs to exist
 
     if os.getenv("FLEX_COMPUTE") == "SENTIENT":
-        dprint(f"Sentient AIU: Enabled")
+        dprint("Sentient AIU: Enabled")
     else:
-        dprint(f"Sentient AIU: Disabled (Senulator)")
+        dprint("Sentient AIU: Disabled (Senulator)")
 
 
 # ==============================================================
@@ -67,6 +67,6 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = "12355"
     elif rank == 0 or verbose:
-        dprint(f"Detected running via torchrun")
+        dprint("Detected running via torchrun")
 
     aiu_setup(rank, world_size)
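
The `dprint(f"...")` to `dprint("...")` changes here (and in the scripts further down) clear Ruff's F541 check: an f-string with no placeholders is just a plain string literal written with misleading syntax. A two-line sketch of the distinction (the `backend` variable is illustrative only):

    backend = "Senulator"

    print("Sentient AIU: Enabled")  # no placeholders, so a plain literal (F541-clean)
    print(f"Sentient AIU: Disabled ({backend})")  # f-string warranted: it interpolates `backend`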

aiu_fms_testing_utils/utils/paged.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 import time
 from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
 import torch
-import fms.utils.spyre.paged
+import fms.utils.spyre.paged # noqa
 
 
 def adjust_inputs_to_batch(input_ids: torch.Tensor, **padding_kwargs):
@@ -293,7 +293,7 @@ def generate(
             v, _ = torch.topk(logits, top_k)
             logits[logits < v[:, [-1]]] = -float("inf")
 
-            probs = F.softmax(logits, dim=-1)
+            probs = F.softmax(logits, dim=-1) # noqa: F821
             next_val = torch.multinomial(probs, num_samples=1)
         else:
             next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t()
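
This file takes suppressions rather than code changes: the bare `# noqa` keeps the `fms.utils.spyre.paged` import that Ruff would otherwise report as unused (F401), presumably because the module is wanted for its import-time side effects, and `# noqa: F821` silences an undefined-name report on `F.softmax` in the sampling branch. A generic, runnable sketch of the two noqa forms (module choices here are illustrative only, not taken from the patched file):

    import torch
    import torch.nn.functional as F

    # Kept although unused; a bare "# noqa" would silence every rule on this line.
    import json  # noqa: F401

    def sample(logits: torch.Tensor) -> torch.Tensor:
        # With the explicit import above, F is defined; in the patched module it is
        # not, which is why the original line carries "# noqa: F821" instead.
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    print(sample(torch.randn(2, 5)))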

scripts/generate_metrics.py

Lines changed: 9 additions & 12 deletions

@@ -1,9 +1,6 @@
 import argparse
 import ast
-import json
 import os
-import random
-from typing import List, Optional, Tuple
 
 import torch
 from torch import distributed as dist
@@ -178,11 +175,11 @@ def find_eos_index(reference_tokens, eos_token_id):
     return result
 
 
-def filter_before_eos(l, filter_indexes):
+def filter_before_eos(metrics, filter_indexes):
     from itertools import groupby
 
     filtered_results = [
-        list(g)[: filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])
+        list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
     ]
     return [item for sublist in filtered_results for item in sublist]
 
@@ -199,10 +196,10 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     return input_ids, padding_kwargs
 
 
-def write_csv(l, path, metric):
+def write_csv(metrics, path, metric_name):
     with open(path, "w") as f:
-        f.write(f"{metric}\n")
-        for t in l:
+        f.write(f"{metric_name}\n")
+        for t in metrics:
             f.write(f"{t[2].item()}\n")
         f.close()
 
@@ -279,20 +276,20 @@ def write_csv(l, path, metric):
 if num_test_tokens_per_sequence is None:
     num_test_tokens_per_sequence = args.max_new_tokens
 
-cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(
+cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()( # noqa: E731
    r, t.softmax(dim=1).to(dtype=torch.float32)
 )
-prob_mean = lambda r, t: torch.mean(
+prob_mean = lambda r, t: torch.mean( # noqa: E731
     (
         r.softmax(dim=1).to(dtype=torch.float32)
         / t.softmax(dim=1).to(dtype=torch.float32)
     )
     - 1.0
 )
-prob_std = lambda r, t: torch.std(
+prob_std = lambda r, t: torch.std( # noqa: E731
     r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)
 )
-diff_mean = lambda r, t: torch.mean(
+diff_mean = lambda r, t: torch.mean( # noqa: E731
     torch.abs(
         r.softmax(dim=1).to(dtype=torch.float32)
         - t.softmax(dim=1).to(dtype=torch.float32)
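
The `# noqa: E731` comments keep the metric lambdas as written; E731 is the rule that discourages binding a lambda to a name and recommends a `def` instead. If the lambdas were ever rewritten rather than suppressed, a def-based equivalent of the first metric could look roughly like this (a sketch, not part of the commit):

    import torch

    def cross_entropy(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Same computation as the lambda: cross-entropy between reference logits r
        # and the softmaxed target distribution t, evaluated in float32.
        return torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))

    # Example with random logits for a batch of 4 over 10 classes.
    print(cross_entropy(torch.randn(4, 10), torch.randn(4, 10)))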

scripts/inference.py

Lines changed: 8 additions & 9 deletions

@@ -7,7 +7,6 @@
 from pathlib import Path
 import random
 import time
-import contextlib
 
 # Third Party
 from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
@@ -236,17 +235,17 @@
 if args.quantization == "gptq":
     if "aiu" in args.device_type:
         try:
-            from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
+            from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa
 
             print("Loaded `aiu_addons` functionalities")
-        except:
+        except ImportError:
             raise ImportError("Failed to import GPTQ addons from fms-mo.")
 elif args.quantization == "int8":
     try:
-        from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear
+        from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear # noqa
 
         print("Loaded `aiu_addons` functionalities")
-    except:
+    except ImportError:
         raise ImportError("Failed to import INT8 addons from fms-mo.")
 
 # this is a test model config
@@ -284,7 +283,7 @@
     device = torch.device(args.device_type, local_rank)
     torch.cuda.set_device(device)
 elif is_aiu_backend:
-    from torch_sendnn import torch_sendnn
+    from torch_sendnn import torch_sendnn # noqa
 
     if not args.distributed:
         aiu_setup.aiu_setup(rank, world_size)
@@ -617,7 +616,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
 
 if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len:
     dprint(
-        f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
+        "One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
     )
     exit(1)
 prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length)
@@ -734,7 +733,7 @@ def infer(use_cache, do_sample, warmup):
 ] # True/False are identical with greedy iff `torch.use_deterministic_algorithms(True)`
 
 if args.compile:
-    dprint(f"compilation warmup")
+    dprint("compilation warmup")
     pt_compile_model_time = time.time()
     if args.device_type == "aiu": # only run warmup for AIU, no need for senulator
         warmup_model(
@@ -756,7 +755,7 @@ def infer(use_cache, do_sample, warmup):
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
-dprint(f"generating output")
+dprint("generating output")
 
 for sample, cache in itertools.product(do_sample, use_cache):
     for _ in range(args.iters):
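
Replacing the bare `except:` with `except ImportError:` resolves Ruff's E722 check. A bare except also traps `KeyboardInterrupt` and `SystemExit`, so catching only the exception the optional import can actually raise is both lint-clean and safer. A small runnable sketch of the pattern (the module name is hypothetical):

    def load_optional_backend() -> bool:
        try:
            # Hypothetical module name, used only to illustrate the pattern.
            from some_optional_backend import adapter  # noqa: F401
        except ImportError:
            # Only a failed import lands here; Ctrl-C and interpreter shutdown
            # propagate instead of being masked, unlike with a bare `except:`.
            return False
        return True

    print("backend available:", load_optional_backend())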

scripts/roberta.py

Lines changed: 6 additions & 6 deletions

@@ -29,7 +29,7 @@
 from fms.models.hf import to_hf_api
 
 # Import AIU Libraries
-from torch_sendnn import torch_sendnn
+from torch_sendnn import torch_sendnn # noqa
 
 # ==============================================================
 # Main
@@ -99,7 +99,7 @@
 # Create the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Creating the model...")
+    dprint("Creating the model...")
 # model_name = "roberta-base"
 # model_name = "deepset/roberta-base-squad2-distilled"
 model_name = "FacebookAI/roberta-base"
@@ -128,7 +128,7 @@
 # Compile the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Compiling the model...")
+    dprint("Compiling the model...")
 the_compiled_model = torch.compile(hf_model_fms, backend=dynamo_backend)
 the_compiled_model.eval() # inference only mode
 torch.set_grad_enabled(False)
@@ -149,15 +149,15 @@
 
 # First run will create compiled artifacts
 if 0 == world_rank:
-    dprint(f"Running model: First Time...")
+    dprint("Running model: First Time...")
 unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer)
 the_output = unmasker(prompt)
 if 0 == world_rank:
     dprint(f"Answer: ({the_output[0]['score']:6.5f}) {the_output[0]['sequence']}")
 
 # Second run will be faster as it uses the cached artifacts
 if 0 == world_rank:
-    dprint(f"Running model: Second Time...")
+    dprint("Running model: Second Time...")
 unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer)
 the_output = unmasker(prompt)
 if 0 == world_rank:
@@ -167,7 +167,7 @@
 # Cleanup
 # -------------
 if 0 == world_rank:
-    dprint(f"Done")
+    dprint("Done")
 if is_distributed:
     torch.distributed.barrier()
     torch.distributed.destroy_process_group()

scripts/small-toy.py

Lines changed: 6 additions & 6 deletions

@@ -16,7 +16,7 @@
 from fms.utils.tp_wrapping import apply_tp
 
 # Import AIU Libraries
-from torch_sendnn import torch_sendnn
+from torch_sendnn import torch_sendnn # noqa
 
 
 # ==============================================================
@@ -136,7 +136,7 @@ def forward(self, x):
 # Create the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Creating the model...")
+    dprint("Creating the model...")
 the_model = ToyModelFM()
 if is_distributed:
     # Create a Tensor Parallel version of the model
@@ -146,7 +146,7 @@ def forward(self, x):
 # Compile the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Compiling the model...")
+    dprint("Compiling the model...")
 the_compiled_model = torch.compile(the_model, backend=dynamo_backend)
 the_compiled_model.eval() # inference only mode
 torch.set_grad_enabled(False)
@@ -164,19 +164,19 @@ def forward(self, x):
 
 # First run will create compiled artifacts
 if 0 == world_rank:
-    dprint(f"Running model: First Time...")
+    dprint("Running model: First Time...")
 the_outputs = the_compiled_model(the_inputs)
 
 # Second run will be faster as it uses the cached artifacts
 if 0 == world_rank:
-    dprint(f"Running model: Second Time...")
+    dprint("Running model: Second Time...")
 the_outputs = the_compiled_model(the_inputs)
 
 # -------------
 # Cleanup
 # -------------
 if 0 == world_rank:
-    dprint(f"Done")
+    dprint("Done")
 if is_distributed:
     torch.distributed.barrier()
     torch.distributed.destroy_process_group()

scripts/validation.py

Lines changed: 4 additions & 5 deletions

@@ -19,7 +19,6 @@
     LogitsExtractorHook,
     capture_level_1_metrics,
     extract_validation_information,
-    StaticTokenInjectorHook,
     GoldenTokenHook,
     filter_failed_level_1_cases,
     validate_level_0,
@@ -264,7 +263,7 @@
 if args.quantization == "gptq":
     try:
         # validation script always loads AIU addon
-        from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
+        from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa: F401
 
         print("Loaded `aiu_addons` functionalities")
 
@@ -295,7 +294,7 @@
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
 
 # Always initialize AIU in this script
-from torch_sendnn import torch_sendnn
+from torch_sendnn import torch_sendnn # noqa
 
 if not args.distributed:
     aiu_setup.aiu_setup(rank, world_size)
@@ -595,7 +594,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
 
 if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len:
     dprint(
-        f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
+        "One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
     )
     exit(1)
 prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length)
@@ -645,7 +644,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
     tokenizer,
 )
 
-val_tokens = [torch.tensor(l) for l in validation_info.get_info("tokens")]
+val_tokens = [torch.tensor(_) for _ in validation_info.get_info("tokens")]
 max_val_len = max([prompt.size(0) for prompt in val_tokens])
 val_num_gen_tokens = int(args.max_new_tokens)
 if max_allowed_length is not None:
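
The comprehension change in the last hunk also targets E741, but it renames `l` to `_`, a name that conventionally marks an unused value even though it is used inside `torch.tensor(_)`. A descriptive name would satisfy E741 equally well; a sketch of that alternative (not what the commit does, and the token lists are made-up stand-ins):

    import torch

    # Stand-in for validation_info.get_info("tokens").
    token_lists = [[1, 2, 3], [4, 5, 6, 7]]

    val_tokens = [torch.tensor(token_ids) for token_ids in token_lists]
    max_val_len = max(prompt.size(0) for prompt in val_tokens)
    print(max_val_len)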

tests/models/conftest.py

Lines changed: 0 additions & 1 deletion

@@ -2,7 +2,6 @@
 
 from aiu_fms_testing_utils.utils.aiu_setup import aiu_setup, rank, world_size
 import os
-import pytest
 
 
 def pytest_sessionstart(session):

tests/models/test_decoders.py

Lines changed: 3 additions & 3 deletions

@@ -28,7 +28,7 @@
 import os
 
 try:
-    from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
+    from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear # noqa
 
     GPTQ_ENABLED = True
 except ImportError:
@@ -278,11 +278,11 @@ def __find_eos_index(reference_tokens, eos_token_id, seq_length, max_new_tokens)
     return result
 
 
-def __filter_before_eos(l, filter_indexes):
+def __filter_before_eos(metrics, filter_indexes):
     from itertools import groupby
 
     filtered_results = [
-        list(g)[: filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])
+        list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
     ]
     return [item for sublist in filtered_results for item in sublist]
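
The GPTQ import in this test exists only to probe whether fms-mo's AIU addons are installed, so the bound names go unused and need the `# noqa`. An alternative availability check that avoids importing names at all could use `importlib.util.find_spec` (a sketch, not what the test does):

    import importlib.util

    def addon_available(name: str) -> bool:
        # find_spec raises ModuleNotFoundError when a parent package is missing,
        # so treat that case the same as "not installed".
        try:
            return importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            return False

    GPTQ_ENABLED = addon_available("fms_mo.aiu_addons.gptq")
    print("GPTQ addons available:", GPTQ_ENABLED)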
