@@ -3,59 +3,15 @@ import Hub
3
3
import MLX
4
4
import MLXNN
5
5
import MLXRandom
6
+ import Vocos
6
7
7
- // utilities
8
-
9
- func lensToMask( t: MLXArray , length: Int ? = nil ) -> MLXArray {
10
- let maxLength = length ?? t. max ( keepDims: false ) . item ( Int . self)
11
- let seq = MLXArray ( 0 ..< maxLength)
12
- let expandedSeq = seq. expandedDimensions ( axis: 0 )
13
- let expandedT = t. expandedDimensions ( axis: 1 )
14
- return MLX . less ( expandedSeq, expandedT)
15
- }
16
-
17
- func padToLength( _ t: MLXArray , length: Int , value: Float ? = nil ) -> MLXArray {
18
- let ndim = t. ndim
19
-
20
- guard let seqLen = t. shape. last, length > seqLen else {
21
- return t [ 0 ... , . ellipsis]
22
- }
23
-
24
- let paddingValue = MLXArray ( value ?? 0.0 )
25
-
26
- let padded : MLXArray
27
- switch ndim {
28
- case 1 :
29
- padded = MLX . padded ( t, widths: [ . init( ( 0 , length - seqLen) ) ] , value: paddingValue)
30
- case 2 :
31
- padded = MLX . padded ( t, widths: [ . init( ( 0 , 0 ) ) , . init( ( 0 , length - seqLen) ) ] , value: paddingValue)
32
- case 3 :
33
- padded = MLX . padded ( t, widths: [ . init( ( 0 , 0 ) ) , . init( ( 0 , length - seqLen) ) , . init( ( 0 , 0 ) ) ] , value: paddingValue)
34
- default :
35
- fatalError ( " Unsupported padding dims: \( ndim) " )
36
- }
37
-
38
- return padded [ 0 ... , . ellipsis]
39
- }
40
-
41
- func padSequence( _ t: [ MLXArray ] , paddingValue: Float = 0 ) -> MLXArray {
42
- let maxLen = t. map { $0. shape. last ?? 0 } . max ( ) ?? 0
43
- let t = MLX . stacked ( t, axis: 0 )
44
- return padToLength ( t, length: maxLen, value: paddingValue)
45
- }
46
-
47
- func listStrToIdx( _ text: [ String ] , vocabCharMap: [ String : Int ] , paddingValue: Int = - 1 ) -> MLXArray {
48
- let listIdxTensors = text. map { str in str. map { char in vocabCharMap [ String ( char) , default: 0 ] } }
49
- let mlxArrays = listIdxTensors. map { MLXArray ( $0) }
50
- let paddedText = padSequence ( mlxArrays, paddingValue: Float ( paddingValue) )
51
- return paddedText. asType ( . int32)
52
- }
53
-
54
- // MARK: -
8
+ // MARK: - F5TTS
55
9
56
10
public class F5TTS : Module {
57
11
enum F5TTSError : Error {
58
12
case unableToLoadModel
13
+ case unableToLoadReferenceAudio
14
+ case unableToDetermineDuration
59
15
}
60
16
61
17
public let melSpec : MelSpec
@@ -100,20 +56,20 @@ public class F5TTS: Module {
100
56
return MLX . stacked ( ys, axis: 0 )
101
57
}
102
58
103
- public func sample(
59
+ private func sample(
104
60
cond: MLXArray ,
105
61
text: [ String ] ,
106
62
duration: Any ,
107
63
lens: MLXArray ? = nil ,
108
64
steps: Int = 32 ,
109
- cfgStrength: Float = 2.0 ,
110
- swayCoef: Float ? = - 1.0 ,
65
+ cfgStrength: Double = 2.0 ,
66
+ swayCoef: Double ? = - 1.0 ,
111
67
seed: Int ? = nil ,
112
68
maxDuration: Int = 4096 ,
113
69
vocoder: ( ( MLXArray ) -> MLXArray ) ? = nil ,
114
70
noRefAudio: Bool = false ,
115
71
editMask: MLXArray ? = nil ,
116
- progressHandler: ( ( Float ) -> Void ) ? = nil
72
+ progressHandler: ( ( Double ) -> Void ) ? = nil
117
73
) -> ( MLXArray , MLXArray ) {
118
74
MLX . eval ( self . parameters ( ) )
119
75
@@ -183,7 +139,7 @@ public class F5TTS: Module {
183
139
mask: mask
184
140
)
185
141
186
- progressHandler ? ( t )
142
+ progressHandler ? ( Double ( t ) )
187
143
188
144
return pred + ( pred - nullPred) * cfgStrength
189
145
}
@@ -218,13 +174,82 @@ public class F5TTS: Module {
218
174
219
175
return ( out, trajectory)
220
176
}
177
+
178
+ public func generate(
179
+ text: String ,
180
+ referenceAudioURL: URL ? = nil ,
181
+ referenceAudioText: String ? = nil ,
182
+ duration: TimeInterval ? = nil ,
183
+ cfg: Double = 2.0 ,
184
+ sway: Double = - 1.0 ,
185
+ speed: Double = 1.0 ,
186
+ seed: Int ? = nil ,
187
+ progressHandler: ( ( Double ) -> Void ) ? = nil
188
+ ) async throws -> MLXArray {
189
+ print ( " Loading Vocos model... " )
190
+ let vocos = try await Vocos . fromPretrained ( repoId: " lucasnewman/vocos-mel-24khz-mlx " )
191
+
192
+ // load the reference audio + text
193
+
194
+ var audio : MLXArray
195
+ let referenceText : String
196
+
197
+ if let referenceAudioURL {
198
+ audio = try F5TTS . loadAudioArray ( url: referenceAudioURL)
199
+ referenceText = referenceAudioText ?? " "
200
+ } else {
201
+ let refAudioAndCaption = try F5TTS . referenceAudio ( )
202
+ ( audio, referenceText) = refAudioAndCaption
203
+ }
204
+
205
+ let refAudioDuration = Double ( audio. shape [ 0 ] ) / Double( F5TTS . sampleRate)
206
+ print ( " Using reference audio with duration: \( refAudioDuration) " )
207
+
208
+ // use a heuristic to determine the duration if not provided
209
+
210
+ var generatedDuration = duration
211
+ if generatedDuration == nil {
212
+ generatedDuration = F5TTS . estimatedDuration ( refAudio: audio, refText: referenceText, text: text)
213
+ }
214
+
215
+ guard let generatedDuration else {
216
+ throw F5TTSError . unableToDetermineDuration
217
+ }
218
+ print ( " Using generated duration: \( generatedDuration) " )
219
+
220
+ // generate the audio
221
+
222
+ let normalizedAudio = F5TTS . normalizeAudio ( audio: audio)
223
+
224
+ let processedText = referenceText + " " + text
225
+ let frameDuration = Int ( ( refAudioDuration + generatedDuration) * F5TTS. framesPerSecond)
226
+ print ( " Generating \( generatedDuration) seconds ( \( frameDuration) total frames) of audio... " )
227
+
228
+ let ( outputAudio, _) = self . sample (
229
+ cond: normalizedAudio. expandedDimensions ( axis: 0 ) ,
230
+ text: [ processedText] ,
231
+ duration: frameDuration,
232
+ steps: 32 ,
233
+ cfgStrength: cfg,
234
+ swayCoef: sway,
235
+ seed: seed,
236
+ vocoder: vocos. decode
237
+ ) { progress in
238
+ print ( " Generation progress: \( progress) " )
239
+ }
240
+
241
+ let generatedAudio = outputAudio [ audio. shape [ 0 ] ... ]
242
+ return generatedAudio
243
+ }
221
244
}
222
245
223
- // MARK: -
246
+ // MARK: - Pretrained Models
224
247
225
248
public extension F5TTS {
226
- static func fromPretrained( repoId: String ) async throws -> F5TTS {
227
- let modelDirectoryURL = try await Hub . snapshot ( from: repoId, matching: [ " *.safetensors " , " *.txt " ] )
249
+ static func fromPretrained( repoId: String , downloadProgress: ( ( Progress ) -> Void ) ? = nil ) async throws -> F5TTS {
250
+ let modelDirectoryURL = try await Hub . snapshot ( from: repoId, matching: [ " *.safetensors " , " *.txt " ] ) { progress in
251
+ downloadProgress ? ( progress)
252
+ }
228
253
return try self . fromPretrained ( modelDirectoryURL: modelDirectoryURL)
229
254
}
230
255
@@ -273,3 +298,97 @@ public extension F5TTS {
273
298
return f5tts
274
299
}
275
300
}
301
+
302
+ // MARK: - Utilities
303
+
304
+ public extension F5TTS {
305
+ static var sampleRate : Int = 24000
306
+ static var hopLength : Int = 256
307
+ static var framesPerSecond : Double = . init( sampleRate) / Double( hopLength)
308
+
309
+ static func loadAudioArray( url: URL ) throws -> MLXArray {
310
+ return try AudioUtilities . loadAudioFile ( url: url)
311
+ }
312
+
313
+ static func referenceAudio( ) throws -> ( MLXArray , String ) {
314
+ guard let url = Bundle . module. url ( forResource: " test_en_1_ref_short " , withExtension: " wav " ) else {
315
+ throw F5TTSError . unableToLoadReferenceAudio
316
+ }
317
+
318
+ return try (
319
+ self . loadAudioArray ( url: url) ,
320
+ " Some call me nature, others call me mother nature. "
321
+ )
322
+ }
323
+
324
+ static func normalizeAudio( audio: MLXArray , targetRMS: Double = 0.1 ) -> MLXArray {
325
+ let rms = Double ( audio. square ( ) . mean ( ) . sqrt ( ) . item ( Float . self) )
326
+ if rms < targetRMS {
327
+ return audio * targetRMS / rms
328
+ }
329
+ return audio
330
+ }
331
+
332
+ static func estimatedDuration( refAudio: MLXArray , refText: String , text: String , speed: Double = 1.0 ) -> TimeInterval {
333
+ let refDurationInFrames = refAudio. shape [ 0 ] / self . hopLength
334
+ let pausePunctuation = " 。,、;:?! "
335
+ let refTextLength = refText. utf8. count + 3 * pausePunctuation. utf8. count
336
+ let genTextLength = text. utf8. count + 3 * pausePunctuation. utf8. count
337
+
338
+ let refAudioToTextRatio = Double ( refDurationInFrames) / Double( refTextLength)
339
+ let textLength = Double ( genTextLength) / speed
340
+ let estimatedDurationInFrames = Int ( refAudioToTextRatio * textLength)
341
+
342
+ let estimatedDuration = TimeInterval ( estimatedDurationInFrames) / Self. framesPerSecond
343
+ print ( " Using duration of \( estimatedDuration) seconds ( \( estimatedDurationInFrames) frames) for generated speech. " )
344
+
345
+ return estimatedDuration
346
+ }
347
+ }
348
+
349
+ // MLX utilities
350
+
351
+ func lensToMask( t: MLXArray , length: Int ? = nil ) -> MLXArray {
352
+ let maxLength = length ?? t. max ( keepDims: false ) . item ( Int . self)
353
+ let seq = MLXArray ( 0 ..< maxLength)
354
+ let expandedSeq = seq. expandedDimensions ( axis: 0 )
355
+ let expandedT = t. expandedDimensions ( axis: 1 )
356
+ return MLX . less ( expandedSeq, expandedT)
357
+ }
358
+
359
+ func padToLength( _ t: MLXArray , length: Int , value: Float ? = nil ) -> MLXArray {
360
+ let ndim = t. ndim
361
+
362
+ guard let seqLen = t. shape. last, length > seqLen else {
363
+ return t [ 0 ... , . ellipsis]
364
+ }
365
+
366
+ let paddingValue = MLXArray ( value ?? 0.0 )
367
+
368
+ let padded : MLXArray
369
+ switch ndim {
370
+ case 1 :
371
+ padded = MLX . padded ( t, widths: [ . init( ( 0 , length - seqLen) ) ] , value: paddingValue)
372
+ case 2 :
373
+ padded = MLX . padded ( t, widths: [ . init( ( 0 , 0 ) ) , . init( ( 0 , length - seqLen) ) ] , value: paddingValue)
374
+ case 3 :
375
+ padded = MLX . padded ( t, widths: [ . init( ( 0 , 0 ) ) , . init( ( 0 , length - seqLen) ) , . init( ( 0 , 0 ) ) ] , value: paddingValue)
376
+ default :
377
+ fatalError ( " Unsupported padding dims: \( ndim) " )
378
+ }
379
+
380
+ return padded [ 0 ... , . ellipsis]
381
+ }
382
+
383
+ func padSequence( _ t: [ MLXArray ] , paddingValue: Float = 0 ) -> MLXArray {
384
+ let maxLen = t. map { $0. shape. last ?? 0 } . max ( ) ?? 0
385
+ let t = MLX . stacked ( t, axis: 0 )
386
+ return padToLength ( t, length: maxLen, value: paddingValue)
387
+ }
388
+
389
+ func listStrToIdx( _ text: [ String ] , vocabCharMap: [ String : Int ] , paddingValue: Int = - 1 ) -> MLXArray {
390
+ let listIdxTensors = text. map { str in str. map { char in vocabCharMap [ String ( char) , default: 0 ] } }
391
+ let mlxArrays = listIdxTensors. map { MLXArray ( $0) }
392
+ let paddedText = padSequence ( mlxArrays, paddingValue: Float ( paddingValue) )
393
+ return paddedText. asType ( . int32)
394
+ }
0 commit comments