Skip to content

Installing torchmetrics causes GIL error even without usage #3296

@sarang-c

Description

@sarang-c

🐛 Bug

Error thrown:

Fatal Python error: PyGILState_Release: thread state 0x7a0b1400fe30 must be current when releasing
Python runtime state: finalizing (tstate=0x0000000000ba6ac8)

Thread 0x00007a0ecbe5c080 (most recent call first):
  <no Python frame>

To Reproduce

Code
import argparse
import os
import tomllib
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from itertools import islice

from transformers import AutoModel
import torch
import torchaudio
from datasets import load_dataset, Audio
# from torchmetrics.text import WordErrorRate

@dataclass
class ModelConfig:
    """Settings describing which speech model to load and the audio rate it is fed."""

    # Identifier handed to ``AutoModel.from_pretrained`` by ``evaluate``.
    model_id: str = "ai4bharat/indic-conformer-600m-multilingual"
    # Rate every clip is resampled to before inference
    # (presumably in Hz -- TODO confirm against the model card).
    expected_sampling_rate: int = 16000

@dataclass
class DataloaderConfig:
    """Settings describing which dataset to load and how to load it."""

    # Dataset identifier; also used by ``DatasetHandler`` to pick the wrapper class.
    dataset_name: str = "ai4bharat/IndicVoices"
    # Passed to ``load_dataset`` as its second positional argument.
    subset: str = "hindi"
    # Passed to ``load_dataset`` as ``split=``.
    split: str = "train"
    # Passed to ``load_dataset`` as ``streaming=``.
    streaming: bool = True

@dataclass
class EvaluationConfig:
    """Settings for the evaluation run."""

    # NOTE(review): presumably the number of samples to score; the visible
    # ``evaluate`` function does not read this value -- confirm intended use.
    num_samples: int = 100

@dataclass
class STTConfig:
    """Top-level configuration grouping the model, dataloader and evaluation sections."""

    model: ModelConfig = field(default_factory=ModelConfig)
    dataloader: DataloaderConfig = field(default_factory=DataloaderConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)

    def set_key_val_by_path(self, key: str, value_str: str):
        """Overwrite one setting addressed by a dotted attribute path.

        Args:
            key: Dotted path to the attribute, e.g. ``"model.model_id"``.
            value_str: New value as text. It is converted with the type of
                the value currently stored under ``key``.

        Returns:
            A tuple ``(qualified_name, old_value, new_value)`` where
            ``qualified_name`` is ``"<OwnerClassName>.<attribute>"``.

        Raises:
            AttributeError: If a segment of ``key`` does not exist.
            ValueError: If the current value is ``None`` (its type cannot be
                inferred), or if a boolean setting is given anything other
                than ``"true"``/``"false"`` (case-insensitive).
        """
        # str.split always yields at least one element, so the leaf is always bound.
        *parent_names, leaf_name = key.split(".")
        node = self
        for name in parent_names:
            node = getattr(node, name)

        old_value = getattr(node, leaf_name)
        if old_value is None:
            raise ValueError("We don't yet support setting keys with old_value=None")

        qualified_name = type(node).__name__ + "." + leaf_name
        target_type = type(old_value)
        if target_type is bool:
            # bool("false") is True, so booleans are parsed by name instead.
            # Raised explicitly (not asserted) so the check survives ``python -O``.
            lowered = value_str.lower()
            if lowered not in ("true", "false"):
                raise ValueError(f"Cannot assign {value_str} to {qualified_name}(bool)")
            new_value = lowered == "true"
        else:
            new_value = target_type(value_str)

        setattr(node, leaf_name, new_value)
        return qualified_name, old_value, new_value

    @classmethod
    def from_toml(cls, path: str):
        """Build a configuration from a TOML file.

        The ``model``, ``dataloader`` and ``evaluation`` tables are read;
        a table that is absent falls back to that section's defaults.

        Args:
            path: Location of the TOML file.

        Returns:
            A new instance of ``cls`` populated from the file.

        Raises:
            ValueError: If a table contains a key that is not a field of the
                corresponding section dataclass.
        """
        # tomllib.load requires a binary file object.
        with open(path, "rb") as f:
            toml_cfg = tomllib.load(f)

        def build_section(dc_cls, values):
            """Instantiate ``dc_cls`` from ``values``, rejecting unknown keys."""
            if values is None:
                return dc_cls()

            known_names = dc_cls.__dataclass_fields__
            for key in values:
                if key not in known_names:
                    raise ValueError(f"Unknown field: {dc_cls.__name__}.{key}")

            return dc_cls(**values)

        return cls(
            model=build_section(ModelConfig, toml_cfg.get("model")),
            dataloader=build_section(DataloaderConfig, toml_cfg.get("dataloader")),
            evaluation=build_section(EvaluationConfig, toml_cfg.get("evaluation")),
        )

class BaseDataset(ABC):
    """Interface for dataset wrappers that can be iterated and measured.

    Deriving from ``ABC`` is what makes the ``@abstractmethod`` markers
    effective: a subclass that does not override ``__iter__`` and ``__len__``
    cannot be instantiated.
    """

    def __init__(self, dataloader_config: DataloaderConfig, expected_sampling_rate: int):
        """Store the loader settings and the target sampling rate.

        Args:
            dataloader_config: Settings describing which dataset to load.
            expected_sampling_rate: Rate that subclasses should deliver audio at.
        """
        self.dataloader_config = dataloader_config
        self.expected_sampling_rate = expected_sampling_rate

    @abstractmethod
    def __iter__(self):
        """Iterate over the dataset's samples."""

    @abstractmethod
    def __len__(self):
        """Return the number of samples in the dataset."""

class IndicVoicesDataset(BaseDataset):
    """Dataset wrapper that yields ``(waveform, text)`` pairs at the expected rate."""

    def __init__(self, dataloader_config: DataloaderConfig, expected_sampling_rate: int):
        """Load the dataset described by ``dataloader_config``.

        Args:
            dataloader_config: Supplies the dataset name, subset, split and
                streaming flag forwarded to ``load_dataset``.
            expected_sampling_rate: Rate every yielded waveform is resampled to.
        """
        super().__init__(dataloader_config, expected_sampling_rate)
        self.dataset = load_dataset(
            dataloader_config.dataset_name,
            dataloader_config.subset,
            split=dataloader_config.split,
            streaming=dataloader_config.streaming,
        )

    def __iter__(self):
        """Yield ``(wav, text)`` for each sample.

        ``wav`` is a float32 tensor built from the clip's ``"array"`` entry
        with one leading dimension added; it is resampled when the clip's
        rate differs from ``expected_sampling_rate``.
        """
        # One Resample transform per distinct source rate, reused across clips
        # so the transform is not rebuilt for every sample.
        resamplers = {}
        for sample in self.dataset:
            clip = sample["audio_filepath"]
            wav = torch.tensor(clip["array"], dtype=torch.float32).unsqueeze(0)
            source_rate = clip["sampling_rate"]
            if source_rate != self.expected_sampling_rate:
                resampler = resamplers.get(source_rate)
                if resampler is None:
                    resampler = torchaudio.transforms.Resample(source_rate, self.expected_sampling_rate)
                    resamplers[source_rate] = resampler
                wav = resampler(wav)
            yield wav, sample["text"]  # , sample['verbatim'], sample['normalized']

    def __len__(self):
        """Return ``len()`` of the underlying dataset object."""
        # NOTE(review): a dataset loaded with streaming=True may not support
        # len() -- confirm before relying on this.
        return len(self.dataset)

class DatasetHandler:
    """Chooses the dataset wrapper for a configured dataset name and delegates to it."""

    def __init__(self, dataloader_config: DataloaderConfig, expected_sampling_rate: int):
        """Create the wrapper matching ``dataloader_config.dataset_name``.

        Raises:
            ValueError: If the dataset name is not one this handler knows.
        """
        # Guard first: only "ai4bharat/IndicVoices" is supported.
        if dataloader_config.dataset_name != "ai4bharat/IndicVoices":
            raise ValueError(f"Unknown dataset: {dataloader_config.dataset_name}")
        self.dataset: BaseDataset = IndicVoicesDataset(dataloader_config, expected_sampling_rate)

    def __iter__(self):
        """Delegate iteration to the wrapped dataset."""
        wrapped = self.dataset
        return wrapped.__iter__()

    def __len__(self):
        """Delegate the length query to the wrapped dataset."""
        wrapped = self.dataset
        return wrapped.__len__()

def evaluate(config: STTConfig):
    """Load the model, transcribe the first dataset sample, then print ``done``.

    Args:
        config: Supplies the model id, the dataloader settings and the
            expected sampling rate.
    """
    # NOTE(review): trust_remote_code=True lets the model repository supply
    # its own code -- only use with sources you trust.
    stt_model = AutoModel.from_pretrained(config.model.model_id, trust_remote_code=True)
    samples = DatasetHandler(config.dataloader, config.model.expected_sampling_rate)
    # wer_metric = WordErrorRate()
    # Only the first sample is processed.
    for wav, ground_truth in islice(samples, 1):
        # using only ctc for now.
        transcription_ctc = stt_model(wav, "hi", "ctc")
        # calculate WER
        # wer_score = wer_metric([ground_truth], [transcription_ctc]).item()
        # print(f"WER: {wer_score}")
    print('done')

def main():
    """Read the required ``--config`` path from the command line and run the evaluation."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    config_path = cli.parse_args().config
    evaluate(STTConfig.from_toml(config_path))


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
Environment

absl-py==2.3.1
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
asttokens==3.0.0
attrs==25.3.0
audioread==3.0.1
certifi==2025.7.14
cffi==1.17.1
charset-normalizer==3.4.2
comm==0.2.3
contourpy==1.3.3
cycler==0.12.1
datasets==4.0.0
debugpy==1.8.16
decorator==5.2.1
dill==0.3.8
executing==2.2.0
filelock==3.18.0
fonttools==4.59.1
frozenlist==1.7.0
fsspec==2025.3.0
grpcio==1.73.1
hf-transfer==0.1.9
hf-xet==1.1.5
huggingface-hub==0.33.4
idna==3.10
iniconfig==2.1.0
ipykernel==6.30.1
ipython==9.4.0
ipython-pygments-lexers==1.1.1
ipywidgets==8.1.7
jedi==0.19.2
jinja2==3.1.6
joblib==1.5.1
jupyter-client==8.6.3
jupyter-core==5.8.1
jupyterlab-widgets==3.0.15
kiwisolver==1.4.9
lazy-loader==0.4
librosa==0.11.0
llvmlite==0.44.0
markdown==3.8.2
markupsafe==3.0.2
matplotlib==3.10.5
matplotlib-inline==0.1.7
mpmath==1.3.0
msgpack==1.1.1
multidict==6.6.4
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.5
numba==0.61.2
numpy==2.2.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
onnxruntime==1.23.1
orjson==3.11.2
packaging==25.0
pandas==2.3.1
parso==0.8.4
pexpect==4.9.0
pillow==11.3.0
pip==24.0
platformdirs==4.3.8
pluggy==1.6.0
pooch==1.8.2
prompt-toolkit==3.0.51
propcache==0.3.2
protobuf==6.31.1
psutil==7.0.0
ptyprocess==0.7.0
pure-eval==0.2.3
pyarrow==21.0.0
pycparser==2.22
pygments==2.19.2
pyjwt==2.10.1
pyparsing==3.2.3
pytest==8.4.1
python-dateutil==2.9.0.post0
pytz==2025.2
pyyaml==6.0.2
pyzmq==27.0.1
regex==2024.11.6
requests==2.32.4
safetensors==0.5.3
scikit-learn==1.7.1
scipy==1.16.1
setuptools==80.9.0
six==1.17.0
soundfile==0.13.1
soxr==0.5.0.post1
stack-data==0.6.3
sympy==1.14.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
threadpoolctl==3.6.0
tokenizers==0.21.2
torch==2.7.0
torchaudio==2.7.0
torchcodec==0.3.0
tornado==6.5.2
tqdm==4.67.1
traitlets==5.14.3
transformers==4.52.4
triton==3.3.0
typing-extensions==4.14.1
tzdata==2025.2
urllib3==2.5.0
uv==0.7.9
wcwidth==0.2.13
werkzeug==3.1.3
widgetsnbextension==4.0.14
xxhash==3.5.0
yarl==1.20.1

The above code sample runs correctly, but after installing lightning-utilities==0.15.2 and torchmetrics==1.8.2 it starts to crash with the error above, even though torchmetrics is never imported.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug / fix (Something isn't working) · help wanted (Extra attention is needed)

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions