Commit 3502ddb

Add support for Ultravox (#1207)
* Add support for ultravox
* WhisperFeatureExtractor: support specifying max length
* Add more whisper feature extraction unit tests
* Fix links from merge
1 parent 161237b commit 3502ddb

8 files changed: +296 −51 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -414,6 +414,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
 1. **[Table Transformer](https://huggingface.co/docs/transformers/model_doc/table-transformer)** (from Microsoft Research) released with the paper [PubTables-1M: Towards Comprehensive Table Extraction From Unstructured Documents](https://arxiv.org/abs/2110.00061) by Brandon Smock, Rohith Pesala, Robin Abraham.
 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
+1. **Ultravox** (from Fixie.ai) released with the repository [fixie-ai/ultravox](https://github.com/fixie-ai/ultravox) by the Fixie.ai team.
 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
 1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@
 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
 1. **[Table Transformer](https://huggingface.co/docs/transformers/model_doc/table-transformer)** (from Microsoft Research) released with the paper [PubTables-1M: Towards Comprehensive Table Extraction From Unstructured Documents](https://arxiv.org/abs/2110.00061) by Brandon Smock, Rohith Pesala, Robin Abraham.
 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
+1. **Ultravox** (from Fixie.ai) released with the repository [fixie-ai/ultravox](https://github.com/fixie-ai/ultravox) by the Fixie.ai team.
 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
 1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.

src/configs.js

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ function getNormalizedConfig(config) {
         case 'florence2':
         case 'llava_onevision':
         case 'idefics3':
+        case 'ultravox':
         case 'smolvlm':
             // @ts-expect-error TS2339
             init_normalized_config = getNormalizedConfig(config.text_config);
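For readers unfamiliar with getNormalizedConfig: adding 'ultravox' to this fall-through means the normalized configuration (layer counts, head dimensions, cache shapes) is read from the nested text decoder config rather than the top-level config. A minimal sketch of the idea; the config shape and field values below are assumptions for illustration, not taken from an actual Ultravox config.json:

// Hypothetical, trimmed-down config for illustration only.
const ultravoxConfig = {
    model_type: 'ultravox',
    audio_config: { /* Whisper-style audio encoder settings (assumed) */ },
    text_config: { model_type: 'llama', num_hidden_layers: 16, num_attention_heads: 32, hidden_size: 2048 },
};

// With the new case, getNormalizedConfig(ultravoxConfig) recurses into
// ultravoxConfig.text_config, so generation-time parameters are derived
// from the underlying language model.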

src/models.js

Lines changed: 186 additions & 41 deletions
@@ -133,6 +133,7 @@ const MODEL_TYPES = {
     Musicgen: 7,
     MultiModality: 8,
     Phi3V: 9,
+    AudioTextToText: 10,
 }
 //////////////////////////////////////////////////
 
@@ -549,7 +550,7 @@ async function encoderForward(self, model_inputs) {
         const dims = encoderFeeds.pixel_values.dims;
         encoderFeeds.pixel_mask = ones([dims[0], dims[2], dims[3]]);
     }
-
+
     return await sessionRun(session, encoderFeeds);
 }
 
@@ -587,58 +588,98 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
 
 
 
-function default_merge_input_ids_with_image_features({
-    image_token_id,
+function default_merge_input_ids_with_features({
+    modality_token_id,
     inputs_embeds,
-    image_features,
+    modality_features,
     input_ids,
     attention_mask,
 }) {
-    const image_tokens = input_ids.tolist().map(ids =>
+    const token_positions = input_ids.tolist().map(ids =>
         ids.reduce((acc, x, idx) => {
-            if (x == image_token_id) acc.push(idx);
+            if (x == modality_token_id) acc.push(idx);
             return acc;
         }, [])
     );
-    const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
-    const n_image_features = image_features.dims[0];
-    if (n_image_tokens !== n_image_features) {
-        throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
+    const n_tokens = token_positions.reduce((acc, x) => acc + x.length, 0);
+    const n_features = modality_features.dims[0];
+    if (n_tokens !== n_features) {
+        throw new Error(`Number of tokens and features do not match: tokens: ${n_tokens}, features ${n_features}`);
     }
 
     // Equivalent to performing a masked_scatter
     let img = 0;
-    for (let i = 0; i < image_tokens.length; ++i) {
-        const tokens = image_tokens[i];
+    for (let i = 0; i < token_positions.length; ++i) {
+        const tokens = token_positions[i];
         const embeds = inputs_embeds[i];
         for (let j = 0; j < tokens.length; ++j) {
-            embeds[tokens[j]].data.set(image_features[img++].data)
+            embeds[tokens[j]].data.set(modality_features[img++].data)
         }
     }
     return { inputs_embeds, attention_mask }
 }
 
 
-/**
- * Forward pass of an image-text-to-text model.
- * @param {Object} self The image-text-to-text model model.
- * @param {Object} model_inputs The input data to be used for the forward pass.
- * @param {Tensor} [model_inputs.input_ids=null]
- * @param {Tensor} [model_inputs.attention_mask=null]
- * @param {Tensor} [model_inputs.pixel_values=null]
- * @param {Tensor} [model_inputs.position_ids=null]
- * @param {Tensor} [model_inputs.inputs_embeds=null]
- * @param {Tensor} [model_inputs.past_key_values=null]
- * @param {Object} [model_inputs.generation_config=null]
- * @param {Object} [model_inputs.logits_processor=null]
+function default_merge_input_ids_with_image_features({
+    image_token_id,
+    inputs_embeds,
+    image_features,
+    input_ids,
+    attention_mask,
+}) {
+    return default_merge_input_ids_with_features({
+        modality_token_id: image_token_id,
+        inputs_embeds,
+        modality_features: image_features,
+        input_ids,
+        attention_mask,
+    })
+}
+
+function default_merge_input_ids_with_audio_features({
+    audio_token_id,
+    inputs_embeds,
+    audio_features,
+    input_ids,
+    attention_mask,
+}) {
+    return default_merge_input_ids_with_features({
+        modality_token_id: audio_token_id,
+        inputs_embeds,
+        modality_features: audio_features,
+        input_ids,
+        attention_mask,
+    })
+}
+
+/**
+ * Abstract forward pass function for image-text-to-text or audio-text-to-text models.
+ * @param {Object} self The model object.
+ * @param {Object} params Additional parameters.
+ * @param {Function} [params.encode_function] The function to encode the modality values.
+ * @param {Function} [params.merge_function] The function to merge the modality features with the input embeddings.
+ * @param {string} [params.modality_input_name] The modality input name.
+ * @param {string} [params.modality_output_name] The modality output name.
+ * @param {Tensor} [params.input_ids=null]
+ * @param {Tensor} [params.attention_mask=null]
+ * @param {Tensor} [params.position_ids=null]
+ * @param {Tensor} [params.inputs_embeds=null]
+ * @param {Tensor} [params.past_key_values=null]
+ * @param {Object} [params.generation_config=null]
+ * @param {Object} [params.logits_processor=null]
  * @returns {Promise<Tensor>} The model's output tensor
  * @private
  */
-async function imageTextToTextForward(self, {
+async function genericTextToTextForward(self, {
+    // Generic parameters:
+    encode_function,
+    merge_function,
+    modality_input_name,
+    modality_output_name,
+
     // Produced by the tokenizer/processor:
     input_ids = null,
     attention_mask = null,
-    pixel_values = null,
 
     // Used during generation:
     position_ids = null,
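The loop above is the JavaScript counterpart of a masked_scatter: for every occurrence of the placeholder token, one row of the modality features is copied into the corresponding position of the input embeddings. A toy walkthrough with plain arrays (the token ids and sizes are made up for illustration):

// Suppose the placeholder (modality) token id is 99 and the batch holds one sequence.
// input_ids:         [[ 5, 99, 99, 7 ]]   -> placeholder positions: [[1, 2]]
// modality_features: 2 rows, one per placeholder token (dims[0] === 2).
//
// After the merge, inputs_embeds contains:
//   position 0: text embedding of token 5   (unchanged)
//   position 1: modality_features[0]        (scattered in)
//   position 2: modality_features[1]        (scattered in)
//   position 3: text embedding of token 7   (unchanged)
//
// If the number of placeholder tokens and feature rows differ, the function throws.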
@@ -649,27 +690,31 @@ async function imageTextToTextForward(self, {
     generation_config = null,
     logits_processor = null,
 
-    // TODO: needed?
+    // Additional parameters
     ...kwargs
 }) {
-
+    const modality_values = kwargs[modality_input_name];
     if (!inputs_embeds) {
-        // 1. Extract the input embeddings
+        // 1. Extract the text embeddings.
        inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
 
-        // 2. Possibly, merge text and images
-        if (pixel_values && input_ids.dims[1] !== 1) {
-            const image_features = await self.encode_image({ pixel_values, ...kwargs });
-
-            ({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
-                image_features,
+        // 2. Possibly, merge text and modality values
+        if (modality_values && input_ids.dims[1] !== 1) {
+            const modality_features = await encode_function({
+                // Pass the modality values under its expected key.
+                // The caller knows whether this is audio or image.
+                [modality_input_name]: modality_values,
+                ...kwargs
+            });
+            ({ inputs_embeds, attention_mask } = merge_function({
+                [modality_output_name]: modality_features,
                 inputs_embeds,
                 input_ids,
                 attention_mask,
             }));
 
-        } else if (past_key_values && pixel_values && input_ids.dims[1] === 1) {
-            // This is the case when we are generating with cache
+        } else if (past_key_values && modality_values && input_ids.dims[1] === 1) {
+            // This branch handles the cache case.
            const target_length = input_ids.dims[1]; // always 1
            const past_length = Object.values(past_key_values)[0].dims.at(-2);
 
@@ -690,6 +735,7 @@ async function imageTextToTextForward(self, {
         }
     }
 
+    // 3. Call the decoder forward using the updated inputs.
     const outputs = await decoderForward(self, {
         inputs_embeds,
         past_key_values,
@@ -701,6 +747,40 @@ async function imageTextToTextForward(self, {
     return outputs;
 }
 
+/**
+ * Forward pass of an audio-text-to-text model.
+ * @param {Object} self The audio-text-to-text model.
+ * @param {Object} params The inputs for the audio-text-to-text forward pass.
+ * @returns {Promise<Tensor>} The model's output tensor.
+ * @private
+ */
+async function audioTextToTextForward(self, params) {
+    return await genericTextToTextForward(self, {
+        ...params,
+        modality_input_name: 'audio_values',
+        modality_output_name: 'audio_features',
+        encode_function: self.encode_audio.bind(self),
+        merge_function: self._merge_input_ids_with_audio_features.bind(self),
+    });
+}
+
+/**
+ * Forward pass of an image-text-to-text model.
+ * @param {Object} self The image-text-to-text model.
+ * @param {Object} params The inputs for the image-text-to-text forward pass.
+ * @returns {Promise<Tensor>} The model's output tensor.
+ * @private
+ */
+async function imageTextToTextForward(self, params) {
+    return await genericTextToTextForward(self, {
+        ...params,
+        modality_input_name: 'pixel_values',
+        modality_output_name: 'image_features',
+        encode_function: self.encode_image.bind(self),
+        merge_function: self._merge_input_ids_with_image_features.bind(self),
+    });
+}
+
 /**
  * Helper function to perform the following:
  * ```python
@@ -814,7 +894,7 @@ function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_in
     };
 }
 
-function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
+function multimodal_text_to_text_prepare_inputs_for_generation(self, ...args) {
     if (self.config.is_encoder_decoder) {
         return encoder_decoder_prepare_inputs_for_generation(self, ...args);
     } else {
@@ -918,11 +998,16 @@ export class PreTrainedModel extends Callable {
             case MODEL_TYPES.ImageTextToText:
                 this.can_generate = true;
                 this._forward = imageTextToTextForward;
-                this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
+                this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
+                break;
+            case MODEL_TYPES.AudioTextToText:
+                this.can_generate = true;
+                this._forward = audioTextToTextForward;
+                this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
                 break;
             case MODEL_TYPES.Phi3V:
                 this.can_generate = true;
-                this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
+                this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
                 break;
 
             case MODEL_TYPES.MultiModality:
@@ -1061,6 +1146,19 @@ export class PreTrainedModel extends Callable {
                 }, options),
             ]);
 
+        } else if (modelType === MODEL_TYPES.AudioTextToText) {
+            const sessions = {
+                embed_tokens: 'embed_tokens',
+                audio_encoder: 'audio_encoder',
+                decoder_model_merged: 'decoder_model_merged',
+            }
+            info = await Promise.all([
+                constructSessions(pretrained_model_name_or_path, sessions, options),
+                getOptionalConfigs(pretrained_model_name_or_path, {
+                    generation_config: 'generation_config.json',
+                }, options),
+            ]);
+
         } else if (modelType === MODEL_TYPES.Musicgen) {
             info = await Promise.all([
                 constructSessions(pretrained_model_name_or_path, {
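So an AudioTextToText model is exported as three ONNX sub-models, and one generation step wires them together roughly as follows (a conceptual summary of the code in this commit, not part of the diff itself):

// 'embed_tokens'         : input_ids    -> inputs_embeds    (PreTrainedModel.encode_text)
// 'audio_encoder'        : audio_values -> audio_features   (PreTrainedModel.encode_audio)
//  merge                 : audio_features overwrite inputs_embeds at the audio placeholder positions
// 'decoder_model_merged' : inputs_embeds (+ past_key_values) -> logits   (decoderForward)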
@@ -1878,6 +1976,11 @@ export class PreTrainedModel extends Callable {
         // text_inputs === { input_ids, attention_mask }
         return (await sessionRun(this.sessions['embed_tokens'], { input_ids })).inputs_embeds;
     }
+
+    async encode_audio({ audio_values }) {
+        // audio_inputs === { audio_values }
+        return (await sessionRun(this.sessions['audio_encoder'], { audio_values })).audio_features;
+    }
 }
 
 //////////////////////////////////////////////////
@@ -6971,6 +7074,34 @@ export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel { }
 export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel { }
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+export class UltravoxPreTrainedModel extends PreTrainedModel {
+    forward_params = [
+        'input_ids',
+        'attention_mask',
+        'position_ids',
+        'audio_values',
+        'past_key_values',
+    ];
+}
+
+export class UltravoxModel extends UltravoxPreTrainedModel {
+
+    _merge_input_ids_with_audio_features(kwargs) {
+        const audio_hidden_size = kwargs.audio_features.dims.at(-1);
+        const reshaped_audio_features = kwargs.audio_features.view(-1, audio_hidden_size);
+
+        return default_merge_input_ids_with_audio_features({
+            // @ts-ignore
+            audio_token_id: this.config.ignore_index,
+            ...kwargs,
+            audio_features: reshaped_audio_features,
+        })
+    }
+}
+//////////////////////////////////////////////////
+
+
 
 //////////////////////////////////////////////////
 // AutoModels, used to simplify construction of PreTrainedModels
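The view(-1, audio_hidden_size) call flattens whatever leading dimensions the audio encoder produces, so each resulting row fills exactly one audio placeholder token. A rough sketch of the shape bookkeeping (the 3D input shape is an assumption about the encoder output, not something this diff shows):

// audio_features (assumed): [num_audio_chunks, num_frames, audio_hidden_size]
// after .view(-1, hidden):  [num_audio_chunks * num_frames, audio_hidden_size]
//
// default_merge_input_ids_with_audio_features then requires input_ids to contain
// exactly num_audio_chunks * num_frames placeholder tokens (here, config.ignore_index),
// and scatters one row into the embedding at each placeholder position.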
@@ -7337,6 +7468,11 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
     ['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]],
 ]);
 
+const MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
+    ['ultravox', ['UltravoxModel', UltravoxModel]],
+]);
+
+
 const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
     ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
 ]);
@@ -7480,6 +7616,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
     [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText],
+    [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.AudioTextToText],
     [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -7771,6 +7908,14 @@ export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
 }
 
+export class AutoModelForImageTextToText extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES];
+}
+
+export class AutoModelForAudioTextToText extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES];
+}
+
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////

src/models/processors.js

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ export * from './qwen2_vl/processing_qwen2_vl.js';
 export * from './sam/processing_sam.js';
 export * from './smolvlm/processing_smolvlm.js';
 export * from './speecht5/processing_speecht5.js';
+export * from './ultravox/processing_ultravox.js';
 export * from './wav2vec2/processing_wav2vec2.js';
 export * from './wav2vec2_with_lm/processing_wav2vec2_with_lm.js';
 export * from './whisper/processing_whisper.js';
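Not shown in this excerpt are the new src/models/ultravox/ processing and configuration files or the WhisperFeatureExtractor changes, but the pieces above are enough to outline how the model is meant to be used. The sketch below is illustrative only: the checkpoint id, the UltravoxProcessor class name, the processor call signature, and the audio placeholder token are assumptions about code not included in this diff.

import { UltravoxProcessor, UltravoxModel, read_audio } from '@huggingface/transformers';

// Hypothetical ONNX checkpoint id; substitute a real converted Ultravox repo.
const model_id = 'onnx-community/ultravox-example-ONNX';

const processor = await UltravoxProcessor.from_pretrained(model_id);
const model = await UltravoxModel.from_pretrained(model_id, {
    // One dtype per exported sub-model (see the AudioTextToText session setup above).
    dtype: { embed_tokens: 'q8', audio_encoder: 'q4', decoder_model_merged: 'q4' },
});

// 16 kHz mono audio, as expected by the Whisper-style feature extractor.
const audio = await read_audio('audio.wav', 16000);

// The processor is assumed to expand an audio placeholder token in the prompt.
const messages = [
    { role: 'user', content: 'Transcribe this audio: <|audio|>' },
];
const prompt = processor.tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    tokenize: false,
});

const inputs = await processor(prompt, audio);
const output_ids = await model.generate({ ...inputs, max_new_tokens: 128 });

// Strip the prompt tokens and decode only the newly generated text.
const answer = processor.tokenizer.batch_decode(
    output_ids.slice(null, [inputs.input_ids.dims.at(-1), null]),
    { skip_special_tokens: true },
);
console.log(answer[0]);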
