import argparse
import os
import tomllib
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from itertools import islice

import torch
import torchaudio
from datasets import Audio, load_dataset
from transformers import AutoModel

# from torchmetrics.text import WordErrorRate
@dataclass
class ModelConfig:
    """Settings describing which speech-to-text model to load."""

    # Hugging Face Hub identifier handed to AutoModel.from_pretrained.
    model_id: str = field(default="ai4bharat/indic-conformer-600m-multilingual")
    # Sample rate (Hz) audio is resampled to before it reaches the model.
    expected_sampling_rate: int = field(default=16000)
@dataclass
class DataloaderConfig:
    """Options forwarded to ``datasets.load_dataset`` when building the dataset."""

    # Hub dataset id; DatasetHandler picks the dataset implementation from it.
    dataset_name: str = field(default="ai4bharat/IndicVoices")
    # Dataset configuration name, passed positionally after dataset_name.
    subset: str = field(default="hindi")
    # Split to read, passed as ``split=``.
    split: str = field(default="train")
    # Passed as ``streaming=``; True streams samples instead of downloading.
    streaming: bool = field(default=True)
@dataclass
class EvaluationConfig:
    """Settings that control the evaluation run."""

    # Sample budget for evaluation.
    # NOTE(review): confirm the evaluation loop actually consumes this.
    num_samples: int = field(default=100)
@dataclass
class STTConfig:
    """Top-level configuration bundling model, dataloader and evaluation settings."""

    model: ModelConfig = field(default_factory=ModelConfig)
    dataloader: DataloaderConfig = field(default_factory=DataloaderConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)

    def set_key_val_by_path(self, key: str, value_str: str):
        """Override one leaf setting addressed by a dotted path.

        The text value is converted to the type of the value currently stored,
        e.g. ``set_key_val_by_path("evaluation.num_samples", "10")`` stores ``10``.

        Args:
            key: Dotted attribute path such as ``"dataloader.streaming"``.
            value_str: New value as text (e.g. taken from a command line).

        Returns:
            ``(qualified_name, old_value, new_value)`` where ``qualified_name``
            is ``"<OwningClassName>.<field_name>"``.

        Raises:
            AttributeError: if a component of ``key`` does not exist.
            ValueError: if the current value is ``None``, if a bool field gets
                something other than "true"/"false" (case-insensitive), or if
                the type conversion itself fails (e.g. ``int("abc")``).
        """

        def typecast_val(name: str, val: str, target_type: type):
            # bool("false") is True, so booleans need explicit parsing.
            # Raise (not assert) so validation survives `python -O`.
            if target_type is bool:
                lowered = val.lower()
                if lowered not in ("true", "false"):
                    raise ValueError(f"Cannot assign {val} to {name}(bool)")
                return lowered == "true"
            return target_type(val)

        *parent_path, field_name = key.split(".")
        node = self
        for part in parent_path:
            node = getattr(node, part)
        old_value = getattr(node, field_name)
        if old_value is None:
            # The target type is inferred from the current value, so None is ambiguous.
            raise ValueError("We don't yet support setting keys with old_value=None")

        # Named to avoid shadowing the module-level `dataclasses.field` import.
        qualified_name = type(node).__name__ + "." + field_name
        new_value = typecast_val(qualified_name, value_str, type(old_value))
        setattr(node, field_name, new_value)
        return qualified_name, old_value, new_value

    @classmethod
    def from_toml(cls, path: str):
        """Build a config from a TOML file.

        The ``[model]``, ``[dataloader]`` and ``[evaluation]`` tables populate
        the matching sub-configs. A missing table yields that sub-config's
        defaults. Other top-level tables are ignored.

        Raises:
            ValueError: if a table contains a key that is not a field of its
                sub-config, or if one of those sections is not a table.
        """
        with open(path, "rb") as f:
            toml_cfg = tomllib.load(f)

        def from_dict(dc_cls, field_to_value):
            if field_to_value is None:
                return dc_cls()
            if not isinstance(field_to_value, dict):
                raise ValueError(
                    f"Expected a table for {dc_cls.__name__}, "
                    f"got {type(field_to_value).__name__}"
                )
            known_names = {f.name for f in dc_cls.__dataclass_fields__.values()}
            for key in field_to_value:
                if key not in known_names:
                    raise ValueError(f"Unknown field: {dc_cls.__name__}.{key}")
            return dc_cls(**field_to_value)

        return cls(
            model=from_dict(ModelConfig, toml_cfg.get("model")),
            dataloader=from_dict(DataloaderConfig, toml_cfg.get("dataloader")),
            evaluation=from_dict(EvaluationConfig, toml_cfg.get("evaluation")),
        )
class BaseDataset(ABC):
    """Abstract base for evaluation datasets.

    Inheriting from ``ABC`` makes ``@abstractmethod`` enforceable. Without it,
    the decorators are inert: the base class, or a subclass missing a method,
    could be instantiated and would silently return ``None`` from
    ``__iter__``/``__len__``.
    """

    def __init__(self, dataloader_config: DataloaderConfig, expected_sampling_rate: int):
        # Stored for subclasses: how to load the data and the target sample rate.
        self.dataloader_config = dataloader_config
        self.expected_sampling_rate = expected_sampling_rate

    @abstractmethod
    def __iter__(self):
        """Yield evaluation samples (presumably ``(waveform, text)`` pairs; confirm per subclass)."""

    @abstractmethod
    def __len__(self):
        """Return the number of samples in the dataset."""
class IndicVoicesDataset(BaseDataset):
    """The ``ai4bharat/IndicVoices`` dataset loaded via Hugging Face ``datasets``.

    Iteration yields ``(wav, text)`` where ``wav`` is a float32 tensor built
    from the clip's ``"array"`` with a leading dimension added by
    ``unsqueeze(0)``. It is resampled to ``expected_sampling_rate`` whenever
    the clip's own rate differs.
    """

    def __init__(self, dataloader_config: DataloaderConfig, expected_sampling_rate: int):
        super().__init__(dataloader_config, expected_sampling_rate)
        self.dataset = load_dataset(
            dataloader_config.dataset_name,
            dataloader_config.subset,
            split=dataloader_config.split,
            streaming=dataloader_config.streaming,
        )
        # Resample transforms keyed by source sample rate. A Resample object
        # precomputes its filter kernel, so reuse one per distinct rate
        # instead of rebuilding it for every clip.
        self._resamplers = {}

    def _resampler_for(self, orig_rate):
        """Return a cached Resample transform from ``orig_rate`` to the target rate."""
        resampler = self._resamplers.get(orig_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_rate, self.expected_sampling_rate)
            self._resamplers[orig_rate] = resampler
        return resampler

    def __iter__(self):
        for sample in self.dataset:
            # The audio column of this dataset is named "audio_filepath".
            clip = sample["audio_filepath"]
            wav = torch.tensor(clip["array"], dtype=torch.float32).unsqueeze(0)
            orig_rate = clip["sampling_rate"]
            if orig_rate != self.expected_sampling_rate:
                wav = self._resampler_for(orig_rate)(wav)
            yield wav, sample["text"]  # other text columns: 'verbatim', 'normalized'

    def __len__(self):
        """Return the number of samples.

        Raises:
            TypeError: if the underlying dataset has no length. This is
                expected with ``streaming=True``.
        """
        try:
            return len(self.dataset)
        except TypeError as err:
            raise TypeError(
                "len() is unavailable for this dataset "
                f"(dataloader.streaming={self.dataloader_config.streaming}); "
                "iterate it instead or bound the sample count explicitly"
            ) from err
class DatasetHandler:
    """Pick and wrap the dataset implementation named by the dataloader config."""

    # Maps DataloaderConfig.dataset_name to the class that loads it;
    # register new datasets here.
    _DATASET_CLASSES = {
        "ai4bharat/IndicVoices": IndicVoicesDataset,
    }

    def __init__(self, dataloader_config: DataloaderConfig, expected_sampling_rate: int):
        """Instantiate the dataset class registered for ``dataloader_config.dataset_name``.

        Raises:
            ValueError: if no dataset class is registered under that name.
        """
        try:
            dataset_cls = self._DATASET_CLASSES[dataloader_config.dataset_name]
        except KeyError:
            raise ValueError(f"Unknown dataset: {dataloader_config.dataset_name}") from None
        self.dataset: BaseDataset = dataset_cls(dataloader_config, expected_sampling_rate)

    def __iter__(self):
        # Delegate to the wrapped dataset via the builtin protocol.
        return iter(self.dataset)

    def __len__(self):
        return len(self.dataset)
def evaluate(config: STTConfig):
    """Transcribe dataset clips with the configured model.

    At most ``config.evaluation.num_samples`` clips are processed. Before this
    fix, the setting was ignored and the loop stopped after the first clip.
    Prints ``done`` when finished.
    """
    model = AutoModel.from_pretrained(config.model.model_id, trust_remote_code=True)
    dataset = DatasetHandler(config.dataloader, config.model.expected_sampling_rate)
    # TODO: score transcriptions against ground_truth, e.g. with
    # torchmetrics.text.WordErrorRate, once that dependency is usable here.
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for wav, ground_truth in islice(dataset, config.evaluation.num_samples):
            # Language code "hi" and CTC decoding are hard-coded.
            # NOTE(review): the language code should probably follow
            # config.dataloader.subset.
            transcription_ctc = model(wav, "hi", "ctc")
    print('done')
def main():
    """CLI entry point: load the TOML config given by --config and run evaluation."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    config_path = cli.parse_args().config
    evaluate(STTConfig.from_toml(config_path))


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
🐛 Bug
Error thrown:
To Reproduce
Code
Environment
absl-py==2.3.1
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
asttokens==3.0.0
attrs==25.3.0
audioread==3.0.1
certifi==2025.7.14
cffi==1.17.1
charset-normalizer==3.4.2
comm==0.2.3
contourpy==1.3.3
cycler==0.12.1
datasets==4.0.0
debugpy==1.8.16
decorator==5.2.1
dill==0.3.8
executing==2.2.0
filelock==3.18.0
fonttools==4.59.1
frozenlist==1.7.0
fsspec==2025.3.0
grpcio==1.73.1
hf-transfer==0.1.9
hf-xet==1.1.5
huggingface-hub==0.33.4
idna==3.10
iniconfig==2.1.0
ipykernel==6.30.1
ipython==9.4.0
ipython-pygments-lexers==1.1.1
ipywidgets==8.1.7
jedi==0.19.2
jinja2==3.1.6
joblib==1.5.1
jupyter-client==8.6.3
jupyter-core==5.8.1
jupyterlab-widgets==3.0.15
kiwisolver==1.4.9
lazy-loader==0.4
librosa==0.11.0
llvmlite==0.44.0
markdown==3.8.2
markupsafe==3.0.2
matplotlib==3.10.5
matplotlib-inline==0.1.7
mpmath==1.3.0
msgpack==1.1.1
multidict==6.6.4
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.5
numba==0.61.2
numpy==2.2.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
onnxruntime==1.23.1
orjson==3.11.2
packaging==25.0
pandas==2.3.1
parso==0.8.4
pexpect==4.9.0
pillow==11.3.0
pip==24.0
platformdirs==4.3.8
pluggy==1.6.0
pooch==1.8.2
prompt-toolkit==3.0.51
propcache==0.3.2
protobuf==6.31.1
psutil==7.0.0
ptyprocess==0.7.0
pure-eval==0.2.3
pyarrow==21.0.0
pycparser==2.22
pygments==2.19.2
pyjwt==2.10.1
pyparsing==3.2.3
pytest==8.4.1
python-dateutil==2.9.0.post0
pytz==2025.2
pyyaml==6.0.2
pyzmq==27.0.1
regex==2024.11.6
requests==2.32.4
safetensors==0.5.3
scikit-learn==1.7.1
scipy==1.16.1
setuptools==80.9.0
six==1.17.0
soundfile==0.13.1
soxr==0.5.0.post1
stack-data==0.6.3
sympy==1.14.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
threadpoolctl==3.6.0
tokenizers==0.21.2
torch==2.7.0
torchaudio==2.7.0
torchcodec==0.3.0
tornado==6.5.2
tqdm==4.67.1
traitlets==5.14.3
transformers==4.52.4
triton==3.3.0
typing-extensions==4.14.1
tzdata==2025.2
urllib3==2.5.0
uv==0.7.9
wcwidth==0.2.13
werkzeug==3.1.3
widgetsnbextension==4.0.14
xxhash==3.5.0
yarl==1.20.1
The code sample above runs correctly, but after installing
`lightning-utilities==0.15.2` and `torchmetrics==1.8.2` it starts to crash.