Skip to content

Commit d605712

Browse files
authored
Merge pull request #62 from SachinVarghese/fix-checks
Lint and type check fixes
2 parents a8faad7 + 40d9af9 commit d605712

File tree

7 files changed: 438 additions and 1005 deletions.

7 files changed: 438 additions and 1005 deletions.

inference_perf/datagen/base.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from pydantic import BaseModel
1515
from inference_perf.config import APIType
1616
from abc import ABC, abstractmethod
17-
from typing import Generator, Tuple, Optional, List
17+
from typing import Generator, Optional, List
1818

1919

2020
class CompletionData(BaseModel):
@@ -38,9 +38,11 @@ class InferenceData(BaseModel):
3838

3939
class DataGenerator(ABC):
4040
"""Abstract base class for data generators."""
41+
4142
apiType: APIType
4243

4344
"""Abstract base class for data generators."""
45+
4446
def __init__(self, apiType: APIType) -> None:
4547
if apiType not in self.get_supported_apis():
4648
raise Exception(f"Unsupported API type {apiType}")

inference_perf/datagen/hf_sharegpt_datagen.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ def get_data(self) -> Generator[InferenceData, None, None]:
6262
except (KeyError, TypeError) as e:
6363
print(f"Skipping invalid completion data: {e}")
6464
continue
65-
elif self.APIType == APIType.Chat:
65+
elif self.apiType == APIType.Chat:
6666
yield InferenceData(
6767
type=APIType.Chat,
6868
chat=ChatCompletionData(
@@ -73,4 +73,4 @@ def get_data(self) -> Generator[InferenceData, None, None]:
7373
),
7474
)
7575
else:
76-
raise Exception("Unsupported API type")
76+
raise Exception("Unsupported API type")

inference_perf/datagen/mock_datagen.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from typing import Generator, List
1616
from inference_perf.config import APIType
1717

18+
1819
class MockDataGenerator(DataGenerator):
1920
def __init__(self, apiType: APIType) -> None:
2021
super().__init__(apiType)
@@ -28,9 +29,6 @@ def get_data(self) -> Generator[InferenceData, None, None]:
2829
while True:
2930
i += 1
3031
if self.apiType == APIType.Completion:
31-
yield InferenceData(
32-
data=CompletionData(prompt="text" + str(i))
33-
)
32+
yield InferenceData(data=CompletionData(prompt="text" + str(i)))
3433
else:
3534
raise Exception("Unsupported API type")
36-

inference_perf/main.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from inference_perf.loadgen import LoadGenerator
1515
from inference_perf.config import DataGenType
16-
from inference_perf.datagen import MockDataGenerator, HFShareGPTDataGenerator
16+
from inference_perf.datagen import DataGenerator, MockDataGenerator, HFShareGPTDataGenerator
1717
from inference_perf.client import ModelServerClient, vLLMModelServerClient
1818
from inference_perf.reportgen import ReportGenerator, MockReportGenerator
1919
from inference_perf.metrics import MockMetricsClient
@@ -48,9 +48,11 @@ def main_cli() -> None:
4848

4949
# Define DataGenerator
5050
if config.data:
51-
datagen = MockDataGenerator(config.vllm.api)
51+
datagen: DataGenerator
5252
if config.data.type == DataGenType.ShareGPT:
5353
datagen = HFShareGPTDataGenerator(config.vllm.api)
54+
else:
55+
datagen = MockDataGenerator(config.vllm.api)
5456
else:
5557
raise Exception("data config missing")
5658

inference_perf/utils/custom_tokenizer.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,12 @@
1717

1818
class CustomTokenizer:
1919
def __init__(self, tokenizer_id: str, token: Optional[str], trust_remote_code: Optional[bool]):
20-
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, token=token, trust_remote_code=trust_remote_code)
20+
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
21+
tokenizer_id, token=token, trust_remote_code=trust_remote_code
22+
)
2123

2224
def count_tokens(self, text: str) -> int:
23-
if not text:
25+
if text == "":
2426
return 0
2527
return len(self.tokenizer(text).input_ids)
2628

0 commit comments

Comments
 (0)