@@ -7,7 +7,74 @@ import Vocos
7
7
8
8
// MARK: - F5TTS
9
9
10
/// Integrates `dy/dt = fun(t, y)` across the time grid `t` using the explicit (forward) Euler method.
///
/// `fun` is evaluated once per interval between consecutive grid points.
///
/// - Parameters:
///   - fun: Derivative function, invoked as `fun(time, state)`.
///   - y0: Initial state, paired with the first entry of `t`.
///   - t: Time points to integrate over; assumed to be one-dimensional (TODO confirm against callers).
/// - Returns: Every visited state stacked along a new leading axis, starting with `y0`.
///   A grid with fewer than two points performs no steps and yields only `y0`.
func odeint_euler(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
    // Pull the whole grid to the host once rather than calling `item` twice per step.
    let times = t.asArray(Float.self)

    var ys = [y0]
    ys.reserveCapacity(max(times.count, 1))
    var yCurrent = y0

    // Pairing the grid with itself shifted by one yields zero iterations for an empty or
    // single-point grid, where a `0 ..< (count - 1)` range would trap on a negative bound.
    for (tCurrent, tNext) in zip(times, times.dropFirst()) {
        let dt = tNext - tCurrent

        let k = fun(tCurrent, yCurrent)
        let yNext = yCurrent + dt * k

        ys.append(yNext)
        yCurrent = yNext
    }

    return MLX.stacked(ys, axis: 0)
}
27
+
28
/// Integrates `dy/dt = fun(t, y)` across the time grid `t` using the explicit midpoint method.
///
/// `fun` is evaluated twice per interval: once at the start of the interval and once at its
/// midpoint, using a half Euler step to estimate the midpoint state.
///
/// - Parameters:
///   - fun: Derivative function, invoked as `fun(time, state)`.
///   - y0: Initial state, paired with the first entry of `t`.
///   - t: Time points to integrate over; assumed to be one-dimensional (TODO confirm against callers).
/// - Returns: Every visited state stacked along a new leading axis, starting with `y0`.
///   A grid with fewer than two points performs no steps and yields only `y0`.
func odeint_midpoint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
    // Pull the whole grid to the host once rather than calling `item` twice per step.
    let times = t.asArray(Float.self)

    var ys = [y0]
    ys.reserveCapacity(max(times.count, 1))
    var yCurrent = y0

    // Pairing the grid with itself shifted by one yields zero iterations for an empty or
    // single-point grid, where a `0 ..< (count - 1)` range would trap on a negative bound.
    for (tCurrent, tNext) in zip(times, times.dropFirst()) {
        let dt = tNext - tCurrent
        let halfStep = 0.5 * dt

        // Slope at the start of the interval, used only to reach the midpoint.
        let k1 = fun(tCurrent, yCurrent)
        let mid = yCurrent + halfStep * k1

        // Slope at the midpoint drives the full step.
        let k2 = fun(tCurrent + halfStep, mid)
        let yNext = yCurrent + dt * k2

        ys.append(yNext)
        yCurrent = yNext
    }

    return MLX.stacked(ys, axis: 0)
}
48
+
49
/// Integrates `dy/dt = fun(t, y)` across the time grid `t` using the classic fourth-order
/// Runge-Kutta method.
///
/// `fun` is evaluated four times per interval (start, two midpoint estimates, end) and the
/// slopes are combined with weights 1, 2, 2, 1 over 6.
///
/// - Parameters:
///   - fun: Derivative function, invoked as `fun(time, state)`.
///   - y0: Initial state, paired with the first entry of `t`.
///   - t: Time points to integrate over; assumed to be one-dimensional (TODO confirm against callers).
/// - Returns: Every visited state stacked along a new leading axis, starting with `y0`.
///   A grid with fewer than two points performs no steps and yields only `y0`.
func odeint_rk4(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
    // Pull the whole grid to the host once rather than calling `item` twice per step.
    let times = t.asArray(Float.self)

    var ys = [y0]
    ys.reserveCapacity(max(times.count, 1))
    var yCurrent = y0

    // Pairing the grid with itself shifted by one yields zero iterations for an empty or
    // single-point grid, where a `0 ..< (count - 1)` range would trap on a negative bound.
    for (tCurrent, tNext) in zip(times, times.dropFirst()) {
        let dt = tNext - tCurrent
        let halfStep = 0.5 * dt
        let tMid = tCurrent + halfStep

        let k1 = fun(tCurrent, yCurrent)
        let k2 = fun(tMid, yCurrent + halfStep * k1)
        let k3 = fun(tMid, yCurrent + halfStep * k2)
        let k4 = fun(tCurrent + dt, yCurrent + dt * k3)

        let yNext = yCurrent + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

        ys.append(yNext)
        yCurrent = yNext
    }

    // Axis spelled out to match the Euler and midpoint solvers.
    return MLX.stacked(ys, axis: 0)
}
70
+
10
71
public class F5TTS : Module {
72
/// Integration scheme used for the ODE solve during sampling.
///
/// Each case is dispatched to the solver function of the matching name
/// (`odeint_euler`, `odeint_midpoint`, `odeint_rk4`). Raw values are the case names.
public enum ODEMethod: String {
    case euler, midpoint, rk4
}
77
+
11
78
enum F5TTSError : Error {
12
79
case unableToLoadModel
13
80
case unableToLoadReferenceAudio
@@ -38,40 +105,18 @@ public class F5TTS: Module {
38
105
super. init ( )
39
106
}
40
107
41
- private func odeint( fun: ( Float , MLXArray ) -> MLXArray , y0: MLXArray , t: MLXArray ) -> MLXArray {
42
- var ys = [ y0]
43
- var yCurrent = y0
44
-
45
- for i in 0 ..< ( t. shape [ 0 ] - 1 ) {
46
- let tCurrent = t [ i] . item ( Float . self)
47
- let dt = t [ i + 1 ] . item ( Float . self) - tCurrent
48
-
49
- let k1 = fun ( tCurrent, yCurrent)
50
- let mid = yCurrent + 0.5 * dt * k1
51
-
52
- let k2 = fun ( tCurrent + 0.5 * dt, mid)
53
- let yNext = yCurrent + dt * k2
54
-
55
- ys. append ( yNext)
56
- yCurrent = yNext
57
- }
58
-
59
- return MLX . stacked ( ys, axis: 0 )
60
- }
61
-
62
108
private func sample(
63
109
cond: MLXArray ,
64
110
text: [ String ] ,
65
111
duration: Int ? = nil ,
66
112
lens: MLXArray ? = nil ,
67
- steps: Int = 32 ,
113
+ steps: Int = 8 ,
114
+ method: ODEMethod = . rk4,
68
115
cfgStrength: Double = 2.0 ,
69
116
swayCoef: Double ? = - 1.0 ,
70
117
seed: Int ? = nil ,
71
118
maxDuration: Int = 4096 ,
72
119
vocoder: ( ( MLXArray ) -> MLXArray ) ? = nil ,
73
- noRefAudio: Bool = false ,
74
- editMask: MLXArray ? = nil ,
75
120
progressHandler: ( ( Double ) -> Void ) ? = nil
76
121
) throws -> ( MLXArray , MLXArray ) {
77
122
MLX . eval ( self . parameters ( ) )
@@ -96,9 +141,6 @@ public class F5TTS: Module {
96
141
lens = MLX . maximum ( textLens, lens)
97
142
98
143
var condMask = lensToMask ( t: lens)
99
- if let editMask = editMask {
100
- condMask = condMask & editMask
101
- }
102
144
103
145
// duration
104
146
var resolvedDuration : MLXArray ? = ( duration != nil ) ? MLXArray ( duration!) : nil
@@ -125,10 +167,6 @@ public class F5TTS: Module {
125
167
126
168
let mask : MLXArray ? = ( batch > 1 ) ? lensToMask ( t: duration) : nil
127
169
128
- if noRefAudio {
129
- cond = MLX . zeros ( like: cond)
130
- }
131
-
132
170
// neural ode
133
171
134
172
let fn : ( Float , MLXArray ) -> MLXArray = { t, x in
@@ -169,7 +207,7 @@ public class F5TTS: Module {
169
207
170
208
var y0 : [ MLXArray ] = [ ]
171
209
for dur in duration {
172
- if let seed = seed {
210
+ if let seed {
173
211
MLXRandom . seed ( UInt64 ( seed) )
174
212
}
175
213
let noise = MLXRandom . normal ( [ dur. item ( Int . self) , self . numChannels] )
@@ -183,11 +221,17 @@ public class F5TTS: Module {
183
221
t = t + coef * ( MLX . cos ( MLXArray ( . pi) / 2 * t) - 1 + t)
184
222
}
185
223
186
- let trajectory = self . odeint ( fun: fn, y0: y0Padded, t: t)
224
+ let odeintFn = switch method {
225
+ case . euler: odeint_euler
226
+ case . midpoint: odeint_midpoint
227
+ case . rk4: odeint_rk4
228
+ }
229
+
230
+ let trajectory = odeintFn ( fn, y0Padded, t)
187
231
let sampled = trajectory [ - 1 ]
188
232
var out = MLX . where ( condMask, cond, sampled)
189
233
190
- if let vocoder = vocoder {
234
+ if let vocoder {
191
235
out = vocoder ( out)
192
236
}
193
237
out. eval ( )
@@ -200,6 +244,8 @@ public class F5TTS: Module {
200
244
referenceAudioURL: URL ? = nil ,
201
245
referenceAudioText: String ? = nil ,
202
246
duration: TimeInterval ? = nil ,
247
+ steps: Int = 8 ,
248
+ method: ODEMethod = . rk4,
203
249
cfg: Double = 2.0 ,
204
250
sway: Double = - 1.0 ,
205
251
speed: Double = 1.0 ,
@@ -234,7 +280,8 @@ public class F5TTS: Module {
234
280
cond: normalizedAudio. expandedDimensions ( axis: 0 ) ,
235
281
text: [ processedText] ,
236
282
duration: nil ,
237
- steps: 32 ,
283
+ steps: steps,
284
+ method: method,
238
285
cfgStrength: cfg,
239
286
swayCoef: sway,
240
287
seed: seed,
@@ -339,7 +386,7 @@ public extension F5TTS {
339
386
static var framesPerSecond : Double = . init( sampleRate) / Double( hopLength)
340
387
341
388
/// Loads the audio file at `url` as an `MLXArray`.
///
/// Thin wrapper that forwards to `AudioUtilities.loadAudioFile(url:)`.
///
/// - Parameter url: Location of the audio file to load.
/// - Returns: The array produced by `AudioUtilities.loadAudioFile(url:)`.
/// - Throws: Any error thrown by `AudioUtilities.loadAudioFile(url:)`.
static func loadAudioArray(url: URL) throws -> MLXArray {
    let samples: MLXArray = try AudioUtilities.loadAudioFile(url: url)
    return samples
}
344
391
345
392
static func referenceAudio( ) throws -> ( MLXArray , String ) {
0 commit comments