huggingface asr integration #35

Open · wants to merge 3 commits into main
90 changes: 90 additions & 0 deletions examples/huggingface/reverb_config.py
@@ -0,0 +1,90 @@
# Following https://huggingface.co/docs/transformers/en/custom_models
import math
from typing import List

import yaml
from pyctcdecode import build_ctcdecoder
from transformers import PretrainedConfig


def cmvn(means: List[float], variance: List[float], count: int):
    """Compute CMVN parameters from accumulated statistics.

    Converts per-dimension sum and squared-sum statistics into means and
    inverse standard deviations. The input lists are modified in place.

    Returns:
        a list [means, istds] of per-dimension means and inverse
        standard deviations
    """
    for i in range(len(means)):
        means[i] /= count
        # Var[x] = E[x^2] - E[x]^2, floored to keep the sqrt well-defined
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    return [means, variance]


class ReverbConfig(PretrainedConfig):
    model_type = "reverb_asr"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Set default special tokens if not provided
        if not hasattr(self, 'special_tokens'):
            self.special_tokens = {
                "<blank>": 0,
                "<sos>": 2,
                "<eos>": 2,
                "<unk>": 1,
            }

        # Calculate CMVN if the required stats are provided
        if hasattr(self, 'cmvn_mean_stat') and hasattr(self, 'cmvn_var_stat') \
                and hasattr(self, 'cmvn_frame_num'):
            self.cmvn_mean, self.cmvn_istd = cmvn(
                self.cmvn_mean_stat,
                self.cmvn_var_stat,
                self.cmvn_frame_num,
            )

        # Set default ratio if not provided
        if not hasattr(self, 'inputs_to_logits_ratio'):
            self.inputs_to_logits_ratio = 1

        # Tokenizer configuration
        if not hasattr(self, 'tokenizer_path'):
            self.tokenizer_path = "path/to/tokenizer.model"
        if not hasattr(self, 'units_path'):
            self.units_path = "path/to/units.txt"
        if not hasattr(self, 'decoder_beam_width'):
            self.decoder_beam_width = 8
        if not hasattr(self, 'decoder_token_min_logp'):
            self.decoder_token_min_logp = -10

        # Load units and build decoder
        self._load_units_and_build_decoder()

    def _load_units_and_build_decoder(self):
        """Load units from file and build the CTC decoder."""
        decoder_ids = []
        with open(self.units_path, 'r') as units_file:
            for line in units_file:
                parts = line.split()
                # Skip blank or whitespace-only lines
                if not parts:
                    continue
                token = parts[0]
                # pyctcdecode expects the blank label to be the empty string
                if token == '<blank>':
                    token = ''
                decoder_ids.append(token)
        self.decoder = build_ctcdecoder(decoder_ids)

    @classmethod
    def from_yaml_file(cls, yaml_file: str) -> "ReverbConfig":
        """Load a ReverbConfig from a YAML file.

        Args:
            yaml_file: Path to the YAML file containing the configuration

        Returns:
            A ReverbConfig instance loaded from the file
        """
        with open(yaml_file, 'r') as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)
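
Reviewer note: a quick sanity check of the cmvn() arithmetic on toy statistics (illustrative only; the numbers below are made up, not from the diff). Two frames with values 1 and 3 in a single dimension give sum 4 and squared sum 10:

    mean_stat, var_stat, count = [4.0], [10.0], 2
    means, istds = cmvn(mean_stat, var_stat, count)
    # mean = 4/2 = 2; var = 10/2 - 2^2 = 1; istd = 1/sqrt(1) = 1
    assert means == [2.0] and istds == [1.0]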
48 changes: 48 additions & 0 deletions examples/huggingface/reverb_config.yaml
@@ -0,0 +1,48 @@
model_type: reverb_asr
input_dim: 80
output_dim: 10001
cmvn_mean_stat: [33596438528.0, 35418329088.0, 39182106624.0, 41983324160.0, 44419112960.0, 46015381504.0, 46934564864.0, 47058870272.0, 47288012800.0, 47522979840.0, 48491438080.0, 49308729344.0, 50230493184.0, 50796900352.0, 51020386304.0, 51297456128.0, 51333586944.0, 51126181888.0, 51455569920.0, 50636410880.0, 49947033600.0, 50365546496.0, 49383075840.0, 49540546560.0, 49066065920.0, 49236889600.0, 48820707328.0, 49071112192.0, 48968024064.0, 49024458752.0, 49202397184.0, 49374433280.0, 49620660224.0, 49947111424.0, 50326310912.0, 50717818880.0, 51046891520.0, 51345678336.0, 51655733248.0, 51505459200.0, 51813666816.0, 51577262080.0, 51776524288.0, 51754237952.0, 51918598144.0, 52158758912.0, 52405276672.0, 52596776960.0, 52639731712.0, 52631220224.0, 52443103232.0, 52315619328.0, 52219695104.0, 52178399232.0, 52083040256.0, 52064792576.0, 51980918784.0, 51824164864.0, 51550973952.0, 51002216448.0, 50422747136.0, 49847754752.0, 49474338816.0, 48997863424.0, 48617009152.0, 48309174272.0, 48084140032.0, 48095608832.0, 47965765632.0, 47909335040.0, 47780065280.0, 47762370560.0, 47757099008.0, 47731314688.0, 47574110208.0, 47336361984.0, 47009054720.0, 46283513856.0, 44821860352.0, 42771775488.0]
cmvn_var_stat: [360475131904.0, 401487724544.0, 484368646144.0, 548414357504.0, 608912080896.0, 651613241344.0, 678013698048.0, 683624693760.0, 689524047872.0, 695375822848.0, 722376851456.0, 746773872640.0, 774244204544.0, 791678353408.0, 798920015872.0, 807307444224.0, 808713453568.0, 802957754368.0, 812319899648.0, 788076953600.0, 767619497984.0, 777970712576.0, 748566544384.0, 751065628672.0, 736340869120.0, 739872473088.0, 727466704896.0, 734006083584.0, 731017904128.0, 732582576128.0, 737590444032.0, 742469861376.0, 749455671296.0, 758746972160.0, 769666121728.0, 781107331072.0, 790730506240.0, 799342002176.0, 808164917248.0, 803454713856.0, 812040585216.0, 804632395776.0, 809866821632.0, 808861499392.0, 813548044288.0, 820701954048.0, 828343779328.0, 834335604736.0, 835754590208.0, 835251011584.0, 829192929280.0, 824705744896.0, 821224734720.0, 819399753728.0, 816182853632.0, 815243788288.0, 812578177024.0, 807846281216.0, 799796035584.0, 784661544960.0, 770915631104.0, 756696285184.0, 746462183424.0, 734193254400.0, 724980072448.0, 717529612288.0, 711156563968.0, 710358204416.0, 706386919424.0, 704228884480.0, 700537110528.0, 699519008768.0, 699025129472.0, 698035535872.0, 693109391360.0, 686047887360.0, 676213948416.0, 655917645824.0, 616676458496.0, 563932168192.0]
cmvn_frame_num: 3519342927
encoder: conformer
encoder_activation_type: swish
encoder_attention_dropout_rate: 0.1
encoder_attention_heads: 8
encoder_causal: true
encoder_cnn_module_kernel: 31
encoder_cnn_module_norm: layer_norm
encoder_dropout_rate: 0.1
encoder_input_layer: conv2d
encoder_linear_units: 2048
encoder_normalize_before: true
encoder_num_blocks: 18
encoder_num_langs: 2
encoder_output_size: 640
encoder_pos_enc_layer_type: rel_pos
encoder_positional_dropout_rate: 0.1
encoder_selfattention_layer_type: rel_selfattn
encoder_use_cnn_module: true
encoder_use_dynamic_chunk: true
decoder: lslbitransformer
decoder_attention_heads: 8
decoder_dropout_rate: 0.1
decoder_linear_units: 2048
decoder_num_blocks: 6
decoder_num_langs: 2
decoder_positional_dropout_rate: 0.1
decoder_r_num_blocks: 6
decoder_self_attention_dropout_rate: 0.1
decoder_src_attention_dropout_rate: 0.1
ctc_blank_id: 0
ctc_weight: 0.3
lsm_weight: 0.1
reverse_weight: 0.3
special_tokens:
  "<blank>": 0
  "<sos>": 2
  "<eos>": 2
  "<unk>": 1
tokenizer_path: "hf-reverb/tk.model"
units_path: "hf-reverb/tk.units.txt"
decoder_beam_width: 8
decoder_token_min_logp: -10
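
As a usage sketch (assuming this YAML is saved as examples/huggingface/reverb_config.yaml and reverb_config.py is importable), the config can be loaded with the snippet below. Note that units_path must point at an existing units file ("hf-reverb/tk.units.txt" here), because the constructor builds the pyctcdecode decoder eagerly:

    from reverb_config import ReverbConfig

    config = ReverbConfig.from_yaml_file("examples/huggingface/reverb_config.yaml")
    print(config.model_type)     # reverb_asr
    print(config.cmvn_mean[:2])  # per-dimension means derived from the stats
    print(config.decoder)        # pyctcdecode decoder built from units_path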
98 changes: 98 additions & 0 deletions examples/huggingface/reverb_hf.py
@@ -0,0 +1,98 @@
# Following https://huggingface.co/docs/transformers/en/custom_models

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import LanguageSpecificBiTransformerDecoder
from wenet.transformer.encoder import ConformerEncoder

from reverb_config import ReverbConfig

class ReverbModel(PreTrainedModel):
    config_class = ReverbConfig
    main_input_name = "input_features"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Global CMVN layer built from the precomputed training-set statistics
        global_cmvn = GlobalCMVN(
            torch.Tensor(config.cmvn_mean),
            torch.Tensor(config.cmvn_istd),
        )
        encoder = ConformerEncoder(
            config.input_dim,
            global_cmvn=global_cmvn,
            activation_type=config.encoder_activation_type,
            attention_dropout_rate=config.encoder_attention_dropout_rate,
            attention_heads=config.encoder_attention_heads,
            causal=config.encoder_causal,
            cnn_module_kernel=config.encoder_cnn_module_kernel,
            cnn_module_norm=config.encoder_cnn_module_norm,
            dropout_rate=config.encoder_dropout_rate,
            input_layer=config.encoder_input_layer,
            linear_units=config.encoder_linear_units,
            normalize_before=config.encoder_normalize_before,
            num_blocks=config.encoder_num_blocks,
            num_langs=config.encoder_num_langs,
            output_size=config.encoder_output_size,
            pos_enc_layer_type=config.encoder_pos_enc_layer_type,
            positional_dropout_rate=config.encoder_positional_dropout_rate,
            selfattention_layer_type=config.encoder_selfattention_layer_type,
            use_cnn_module=config.encoder_use_cnn_module,
            use_dynamic_chunk=config.encoder_use_dynamic_chunk,
        )

        decoder = LanguageSpecificBiTransformerDecoder(
            config.output_dim,
            config.encoder_output_size,
            attention_heads=config.decoder_attention_heads,
            dropout_rate=config.decoder_dropout_rate,
            linear_units=config.decoder_linear_units,
            num_blocks=config.decoder_num_blocks,
            num_langs=config.decoder_num_langs,
            positional_dropout_rate=config.decoder_positional_dropout_rate,
            r_num_blocks=config.decoder_r_num_blocks,
            self_attention_dropout_rate=config.decoder_self_attention_dropout_rate,
            src_attention_dropout_rate=config.decoder_src_attention_dropout_rate,
        )

        ctc = CTC(
            config.output_dim,
            config.encoder_output_size,
            config.ctc_blank_id,
        )

        self.model = ASRModel(
            vocab_size=config.output_dim,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            special_tokens=config.special_tokens,
            ctc_weight=config.ctc_weight,
            lsm_weight=config.lsm_weight,
            reverse_weight=config.reverse_weight,
        )
        # Enable language-specific layers in both encoder and decoder
        self.model.lsl_enc = True
        self.model.lsl_dec = True

    def forward(
        self,
        input_features=None,
        feats_lengths=None,
        labels=None,
        labels_lengths=None,
        **kwargs,
    ):
        output = self.model.hf_forward(
            input_features,
            feats_lengths=feats_lengths,
            labels=labels,
            labels_lengths=labels_lengths,
        )
        # Expose the CTC posteriors as logits so downstream decoders can use them
        return Seq2SeqLMOutput(
            logits=output['ctc_probs'],
            loss=output['loss'],
        )
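
A minimal smoke test for reviewers (a sketch only, assuming the wenet fork from this PR is installed, the YAML above is on disk, and hf_forward is the method added in this PR):

    import torch
    from reverb_config import ReverbConfig
    from reverb_hf import ReverbModel

    config = ReverbConfig.from_yaml_file("examples/huggingface/reverb_config.yaml")
    model = ReverbModel(config).eval()

    # One dummy utterance: 2051 padded frames of 80-dim fbank features
    feats = torch.zeros(1, 2051, config.input_dim)
    feats_lengths = torch.tensor([1200])
    with torch.no_grad():
        out = model(input_features=feats, feats_lengths=feats_lengths)
    print(out.logits.shape)  # (batch, subsampled frames, output_dim)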
132 changes: 132 additions & 0 deletions examples/huggingface/reverb_processor.py
@@ -0,0 +1,132 @@
from typing import List, Optional, Union

import numpy as np
import sentencepiece as spm
import torch
from torchaudio.compliance import kaldi
from transformers import BatchFeature, PreTrainedTokenizer, SequenceFeatureExtractor
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ReverbFeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        frame_length=25,
        frame_shift=10,
        chunk_length=15,
        padding_value=0.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            return_attention_mask=False,
            **kwargs,
        )
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.chunk_length = chunk_length
        # Maximum number of frames per padded chunk
        self.max_chunk_size = 2051
        self._processor_class = "CTCWithLM"

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        device: Optional[str] = "cpu",
        sampling_rate: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            # kaldi.fbank expects (channel, time), so add a channel axis
            raw_speech = [np.asarray([speech], dtype=np.float32) for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        if not is_batched:
            raw_speech = [np.asarray([raw_speech])]

        fbank_speech, feats_lengths = [], []
        for waveform in raw_speech:
            fbank_speech.append(
                kaldi.fbank(
                    torch.tensor(waveform),
                    num_mel_bins=self.feature_size,
                    frame_length=self.frame_length,
                    frame_shift=self.frame_shift,
                    dither=0.0,
                    energy_floor=0.0,
                    sample_frequency=self.sampling_rate,
                )
            )
            feats_lengths.append(fbank_speech[-1].shape[0])
        fbank_speech = BatchFeature({
            "input_features": fbank_speech,
            "feats_lengths": feats_lengths,
        })
        # Pad every utterance to a fixed number of frames
        padded = self.pad(
            fbank_speech,
            padding="max_length",
            max_length=self.max_chunk_size,
        )
        return padded


class ReverbTokenizer(PreTrainedTokenizer):
    def __init__(
        self,
        model: str,
        **kwargs,
    ):
        # Wrap the SentencePiece model used to train the ASR system
        self.tokenizer = spm.SentencePieceProcessor(model)

    def encode(
        self,
        text,
        **kwargs,
    ):
        return self.tokenizer.encode(text)

    def decode(
        self,
        token_ids,
        **kwargs,
    ):
        # Drop zero ids (padding/blank) before detokenizing; assumes
        # token_ids is a numpy array of token ids
        return self.tokenizer.decode(token_ids[token_ids.nonzero()[0]].tolist())
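
Finally, an end-to-end decoding sketch that ties the three files together (illustrative only; assumes a 16 kHz mono sample.wav, the soundfile package, and that out.logits carries per-frame CTC log-probabilities as in the forward() above):

    import soundfile as sf
    import torch
    from reverb_config import ReverbConfig
    from reverb_hf import ReverbModel
    from reverb_processor import ReverbFeatureExtractor

    config = ReverbConfig.from_yaml_file("examples/huggingface/reverb_config.yaml")
    model = ReverbModel(config).eval()
    extractor = ReverbFeatureExtractor(feature_size=config.input_dim)

    audio, sr = sf.read("sample.wav")
    batch = extractor(audio, sampling_rate=sr)
    feats = torch.stack([torch.as_tensor(f) for f in batch["input_features"]])
    lengths = torch.tensor(batch["feats_lengths"])

    with torch.no_grad():
        out = model(input_features=feats, feats_lengths=lengths)

    # Beam-search the CTC posteriors with the pyctcdecode decoder from the config
    text = config.decoder.decode(
        out.logits[0].numpy(),
        beam_width=config.decoder_beam_width,
        token_min_logp=config.decoder_token_min_logp,
    )
    print(text)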