
Commit 2deefd3

committed
clean up ruff check
Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
1 parent 77e9ace commit 2deefd3

File tree

15 files changed (+40, -59 lines)


aiu_fms_testing_utils/testing/validation.py

Lines changed: 2 additions & 2 deletions
@@ -288,8 +288,8 @@ def extract_validation_information(

     if hasattr(post_iteration_hook, "extracted_logits"):
         validation_info = [
-            {"tokens": t.to("cpu"), "logits": l.to("cpu")}
-            for t, l in zip(
+            {"tokens": t.to("cpu"), "logits": logits.to("cpu")}
+            for t, logits in zip(
                 torch.unbind(result), torch.unbind(post_iteration_hook.extracted_logits)
             )
         ]
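The rename from `l` to `logits` most likely addresses ruff's E741 check (ambiguous variable name), since `l` is easily misread as `1` or `I`. A minimal sketch of the pattern, using made-up values for illustration:

# Hypothetical snippet illustrating ruff rule E741 (ambiguous variable name).
values = [0.2, 0.5, 0.3]

# Flagged by E741: `l` is easily confused with `1` or `I`.
# l = [v * 2 for v in values]

# Preferred: a descriptive name keeps the linter quiet and the code readable.
scaled = [v * 2 for v in values]
print(scaled)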

aiu_fms_testing_utils/utils/aiu_setup.py

Lines changed: 3 additions & 3 deletions
@@ -49,9 +49,9 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
     # ) # directory needs to exist

     if os.getenv("FLEX_COMPUTE") == "SENTIENT":
-        dprint(f"Sentient AIU: Enabled")
+        dprint("Sentient AIU: Enabled")
     else:
-        dprint(f"Sentient AIU: Disabled (Senulator)")
+        dprint("Sentient AIU: Disabled (Senulator)")


 # ==============================================================
@@ -67,6 +67,6 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = "12355"
     elif rank == 0 or verbose:
-        dprint(f"Detected running via torchrun")
+        dprint("Detected running via torchrun")

     aiu_setup(rank, world_size)
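Dropping the `f` prefix from strings with no placeholders lines up with ruff's F541 check (f-string without any placeholders). A small before/after sketch, using plain `print` as a stand-in for `dprint`:

# Hypothetical example of ruff rule F541 (f-string without placeholders).
mode = "SENTIENT"

# Flagged: the f-prefix does nothing because there are no {} substitutions.
# print(f"Sentient AIU: Enabled")

# Fine: plain string when nothing is interpolated, f-string when something is.
print("Sentient AIU: Enabled")
print(f"FLEX_COMPUTE mode: {mode}")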

aiu_fms_testing_utils/utils/paged.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import time
 from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
 import torch
-import fms.utils.spyre.paged
+import torch.nn.functional as F


 def adjust_inputs_to_batch(input_ids: torch.Tensor, **padding_kwargs):
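Swapping out the unreferenced `fms.utils.spyre.paged` import appears aimed at ruff's F401 check (imported but unused). A minimal illustration with stdlib modules, names chosen only for the example:

# Hypothetical sketch of ruff rule F401 (imported but unused).
# import json          # would be flagged: never referenced below
import math            # kept: actually used


def circle_area(radius: float) -> float:
    return math.pi * radius ** 2


print(circle_area(2.0))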

scripts/generate_metrics.py

Lines changed: 9 additions & 12 deletions
@@ -1,9 +1,6 @@
 import argparse
 import ast
-import json
 import os
-import random
-from typing import List, Optional, Tuple

 import torch
 from torch import distributed as dist
@@ -178,11 +175,11 @@ def find_eos_index(reference_tokens, eos_token_id):
     return result


-def filter_before_eos(l, filter_indexes):
+def filter_before_eos(metrics, filter_indexes):
     from itertools import groupby

     filtered_results = [
-        list(g)[: filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])
+        list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
     ]
     return [item for sublist in filtered_results for item in sublist]

@@ -199,10 +196,10 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     return input_ids, padding_kwargs


-def write_csv(l, path, metric):
+def write_csv(metrics, path, metric_name):
     with open(path, "w") as f:
-        f.write(f"{metric}\n")
-        for t in l:
+        f.write(f"{metric_name}\n")
+        for t in metrics:
             f.write(f"{t[2].item()}\n")
         f.close()

@@ -279,20 +276,20 @@ def write_csv(l, path, metric):
     if num_test_tokens_per_sequence is None:
         num_test_tokens_per_sequence = args.max_new_tokens

-    cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(
+    cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(  # noqa: E731
         r, t.softmax(dim=1).to(dtype=torch.float32)
     )
-    prob_mean = lambda r, t: torch.mean(
+    prob_mean = lambda r, t: torch.mean(  # noqa: E731
         (
             r.softmax(dim=1).to(dtype=torch.float32)
             / t.softmax(dim=1).to(dtype=torch.float32)
        )
         - 1.0
     )
-    prob_std = lambda r, t: torch.std(
+    prob_std = lambda r, t: torch.std(  # noqa: E731
         r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)
     )
-    diff_mean = lambda r, t: torch.mean(
+    diff_mean = lambda r, t: torch.mean(  # noqa: E731
         torch.abs(
             r.softmax(dim=1).to(dtype=torch.float32)
             - t.softmax(dim=1).to(dtype=torch.float32)
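The `# noqa: E731` markers suppress ruff's E731 check ("do not assign a lambda expression, use a def") where the lambdas are kept as-is. A short sketch of both options under that rule, with made-up tensors:

# Hypothetical sketch of ruff rule E731 and the two common responses.
import torch

r = torch.randn(4, 8)
t = torch.randn(4, 8)

# Option 1: keep the lambda and silence the rule explicitly.
prob_mean = lambda r, t: torch.mean(r.softmax(dim=1) / t.softmax(dim=1) - 1.0)  # noqa: E731


# Option 2: rewrite as a def, which E731 prefers.
def prob_mean_fn(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.mean(r.softmax(dim=1) / t.softmax(dim=1) - 1.0)


print(prob_mean(r, t), prob_mean_fn(r, t))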

scripts/inference.py

Lines changed: 5 additions & 12 deletions
@@ -7,7 +7,6 @@
 from pathlib import Path
 import random
 import time
-import contextlib

 # Third Party
 from aiu_fms_testing_utils.utils import aiu_setup, warmup_model
@@ -236,17 +235,13 @@
 if args.quantization == "gptq":
     if "aiu" in args.device_type:
         try:
-            from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
-
             print("Loaded `aiu_addons` functionalities")
-        except:
+        except ImportError:
             raise ImportError("Failed to import GPTQ addons from fms-mo.")
 elif args.quantization == "int8":
     try:
-        from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear
-
         print("Loaded `aiu_addons` functionalities")
-    except:
+    except ImportError:
         raise ImportError("Failed to import INT8 addons from fms-mo.")

 # this is a test model config
@@ -284,8 +279,6 @@
     device = torch.device(args.device_type, local_rank)
     torch.cuda.set_device(device)
 elif is_aiu_backend:
-    from torch_sendnn import torch_sendnn
-
     if not args.distributed:
         aiu_setup.aiu_setup(rank, world_size)

@@ -617,7 +610,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):

 if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len:
     dprint(
-        f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
+        "One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
     )
     exit(1)
 prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length)
@@ -734,7 +727,7 @@ def infer(use_cache, do_sample, warmup):
 ]  # True/False are identical with greedy iff `torch.use_deterministic_algorithms(True)`

 if args.compile:
-    dprint(f"compilation warmup")
+    dprint("compilation warmup")
     pt_compile_model_time = time.time()
     if args.device_type == "aiu":  # only run warmup for AIU, no need for senulator
         warmup_model(
@@ -756,7 +749,7 @@ def infer(use_cache, do_sample, warmup):
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

-dprint(f"generating output")
+dprint("generating output")

 for sample, cache in itertools.product(do_sample, use_cache):
     for _ in range(args.iters):
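Replacing the bare `except:` with `except ImportError:` addresses ruff's E722 check (do not use bare except), so unrelated failures such as KeyboardInterrupt are no longer swallowed. A minimal sketch of the idiom around an optional dependency, with a placeholder module name:

# Hypothetical sketch of ruff rule E722 (no bare except) around an optional import.
try:
    import some_optional_addon  # placeholder name, not a real dependency
    ADDON_ENABLED = True
except ImportError:
    # Only import failures are caught; anything else still propagates.
    ADDON_ENABLED = False

print(f"addon enabled: {ADDON_ENABLED}")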

scripts/roberta.py

Lines changed: 5 additions & 6 deletions
@@ -29,7 +29,6 @@
 from fms.models.hf import to_hf_api

 # Import AIU Libraries
-from torch_sendnn import torch_sendnn

 # ==============================================================
 # Main
@@ -99,7 +98,7 @@
 # Create the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Creating the model...")
+    dprint("Creating the model...")
 # model_name = "roberta-base"
 # model_name = "deepset/roberta-base-squad2-distilled"
 model_name = "FacebookAI/roberta-base"
@@ -128,7 +127,7 @@
 # Compile the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Compiling the model...")
+    dprint("Compiling the model...")
 the_compiled_model = torch.compile(hf_model_fms, backend=dynamo_backend)
 the_compiled_model.eval()  # inference only mode
 torch.set_grad_enabled(False)
@@ -149,15 +148,15 @@

 # First run will create compiled artifacts
 if 0 == world_rank:
-    dprint(f"Running model: First Time...")
+    dprint("Running model: First Time...")
 unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer)
 the_output = unmasker(prompt)
 if 0 == world_rank:
     dprint(f"Answer: ({the_output[0]['score']:6.5f}) {the_output[0]['sequence']}")

 # Second run will be faster as it uses the cached artifacts
 if 0 == world_rank:
-    dprint(f"Running model: Second Time...")
+    dprint("Running model: Second Time...")
 unmasker = pipeline("fill-mask", model=the_compiled_model, tokenizer=tokenizer)
 the_output = unmasker(prompt)
 if 0 == world_rank:
@@ -167,7 +166,7 @@
 # Cleanup
 # -------------
 if 0 == world_rank:
-    dprint(f"Done")
+    dprint("Done")
 if is_distributed:
     torch.distributed.barrier()
     torch.distributed.destroy_process_group()

scripts/small-toy.py

Lines changed: 5 additions & 6 deletions
@@ -16,7 +16,6 @@
 from fms.utils.tp_wrapping import apply_tp

 # Import AIU Libraries
-from torch_sendnn import torch_sendnn


 # ==============================================================
@@ -136,7 +135,7 @@ def forward(self, x):
 # Create the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Creating the model...")
+    dprint("Creating the model...")
 the_model = ToyModelFM()
 if is_distributed:
     # Create a Tensor Parallel version of the model
@@ -146,7+145,7 @@ def forward(self, x):
 # Compile the model
 # -------------
 if 0 == world_rank:
-    dprint(f"Compiling the model...")
+    dprint("Compiling the model...")
 the_compiled_model = torch.compile(the_model, backend=dynamo_backend)
 the_compiled_model.eval()  # inference only mode
 torch.set_grad_enabled(False)
@@ -164,19 +163,19 @@ def forward(self, x):

 # First run will create compiled artifacts
 if 0 == world_rank:
-    dprint(f"Running model: First Time...")
+    dprint("Running model: First Time...")
 the_outputs = the_compiled_model(the_inputs)

 # Second run will be faster as it uses the cached artifacts
 if 0 == world_rank:
-    dprint(f"Running model: Second Time...")
+    dprint("Running model: Second Time...")
 the_outputs = the_compiled_model(the_inputs)

 # -------------
 # Cleanup
 # -------------
 if 0 == world_rank:
-    dprint(f"Done")
+    dprint("Done")
 if is_distributed:
     torch.distributed.barrier()
     torch.distributed.destroy_process_group()

scripts/validation.py

Lines changed: 3 additions & 5 deletions
@@ -19,7 +19,6 @@
     LogitsExtractorHook,
     capture_level_1_metrics,
     extract_validation_information,
-    StaticTokenInjectorHook,
     GoldenTokenHook,
     filter_failed_level_1_cases,
     validate_level_0,
@@ -264,7 +263,7 @@
 if args.quantization == "gptq":
     try:
         # validation script always loads AIU addon
-        from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
+        from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear  # noqa: F401

         print("Loaded `aiu_addons` functionalities")

@@ -295,7 +294,6 @@
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())

 # Always initialize AIU in this script
-from torch_sendnn import torch_sendnn

 if not args.distributed:
     aiu_setup.aiu_setup(rank, world_size)
@@ -595,7 +593,7 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):

 if args.fixed_prompt_length != 0 and args.fixed_prompt_length < max_len:
     dprint(
-        f"One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
+        "One or more prompts require truncation. Truncation has been disabled as fixed_prompt_length has been set."
     )
     exit(1)
 prompts = truncate_prompts_to_max_length(prompts, max_len, max_allowed_length)
@@ -645,7 +643,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
     tokenizer,
 )

-val_tokens = [torch.tensor(l) for l in validation_info.get_info("tokens")]
+val_tokens = [torch.tensor(_) for _ in validation_info.get_info("tokens")]
 max_val_len = max([prompt.size(0) for prompt in val_tokens])
 val_num_gen_tokens = int(args.max_new_tokens)
 if max_allowed_length is not None:

tests/models/conftest.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@

 from aiu_fms_testing_utils.utils.aiu_setup import aiu_setup, rank, world_size
 import os
-import pytest


 def pytest_sessionstart(session):

tests/models/test_decoders.py

Lines changed: 3 additions & 3 deletions
@@ -28,7 +28,7 @@
 import os

 try:
-    from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear
+    from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear  # noqa

     GPTQ_ENABLED = True
 except ImportError:
@@ -278,11 +278,11 @@ def __find_eos_index(reference_tokens, eos_token_id, seq_length, max_new_tokens)
     return result


-def __filter_before_eos(l, filter_indexes):
+def __filter_before_eos(metrics, filter_indexes):
     from itertools import groupby

     filtered_results = [
-        list(g)[: filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])
+        list(g)[: filter_indexes[k]] for k, g in groupby(metrics, key=lambda x: x[0])
     ]
     return [item for sublist in filtered_results for item in sublist]

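The `# noqa` on the fms_mo import is most likely there because the import is needed for its registration side effects even though the imported names are never referenced directly, which would otherwise trip F401 (imported but unused). A small sketch of that idiom using a stdlib module:

# Hypothetical sketch: an import kept only for its side effects, with a
# targeted noqa so ruff's F401 (imported but unused) stays quiet.
try:
    import readline  # noqa: F401  # enables line editing for input() where available
    LINE_EDITING = True
except ImportError:
    LINE_EDITING = False

print(f"line editing available: {LINE_EDITING}")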
