Skip to content

Commit 3689910

Browse files
committed
Add rk4 sampling and use it by default.
1 parent a772687 commit 3689910

File tree

3 files changed

+97
-42
lines changed

3 files changed

+97
-42
lines changed

Package.resolved

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
"kind" : "remoteSourceControl",
66
"location" : "https://github.com/maiqingqiang/Jinja",
77
"state" : {
8-
"revision" : "b435eb62b0d3d5f34167ec70a128355486981712",
9-
"version" : "1.0.5"
8+
"revision" : "6dbe4c449469fb586d0f7339f900f0dd4d78b167",
9+
"version" : "1.0.6"
1010
}
1111
},
1212
{
1313
"identity" : "mlx-swift",
1414
"kind" : "remoteSourceControl",
1515
"location" : "https://github.com/ml-explore/mlx-swift",
1616
"state" : {
17-
"revision" : "78a7cfe6701d6e9c88e9d4a0d1f7990af84b2146",
18-
"version" : "0.18.0"
17+
"revision" : "70dbb62128a5a1471a5ab80363430adb33470cab",
18+
"version" : "0.21.2"
1919
}
2020
},
2121
{
@@ -41,8 +41,8 @@
4141
"kind" : "remoteSourceControl",
4242
"location" : "https://github.com/huggingface/swift-transformers",
4343
"state" : {
44-
"revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
45-
"version" : "0.1.13"
44+
"revision" : "d42fdae473c49ea216671da8caae58e102d28709",
45+
"version" : "0.1.14"
4646
}
4747
},
4848
{

Sources/F5TTS/F5TTS.swift

+83-36
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,74 @@ import Vocos
77

88
// MARK: - F5TTS
99

10+
func odeint_euler(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
11+
var ys = [y0]
12+
var yCurrent = y0
13+
14+
for i in 0..<(t.shape[0] - 1) {
15+
let tCurrent = t[i].item(Float.self)
16+
let dt = t[i + 1].item(Float.self) - tCurrent
17+
18+
let k = fun(tCurrent, yCurrent)
19+
let yNext = yCurrent + dt * k
20+
21+
ys.append(yNext)
22+
yCurrent = yNext
23+
}
24+
25+
return MLX.stacked(ys, axis: 0)
26+
}
27+
28+
func odeint_midpoint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
29+
var ys = [y0]
30+
var yCurrent = y0
31+
32+
for i in 0..<(t.shape[0] - 1) {
33+
let tCurrent = t[i].item(Float.self)
34+
let dt = t[i + 1].item(Float.self) - tCurrent
35+
36+
let k1 = fun(tCurrent, yCurrent)
37+
let mid = yCurrent + 0.5 * dt * k1
38+
39+
let k2 = fun(tCurrent + 0.5 * dt, mid)
40+
let yNext = yCurrent + dt * k2
41+
42+
ys.append(yNext)
43+
yCurrent = yNext
44+
}
45+
46+
return MLX.stacked(ys, axis: 0)
47+
}
48+
49+
func odeint_rk4(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
50+
var ys = [y0]
51+
var yCurrent = y0
52+
53+
for i in 0..<(t.shape[0] - 1) {
54+
let tCurrent = t[i].item(Float.self)
55+
let dt = t[i + 1].item(Float.self) - tCurrent
56+
57+
let k1 = fun(tCurrent, yCurrent)
58+
let k2 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k1)
59+
let k3 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k2)
60+
let k4 = fun(tCurrent + dt, yCurrent + dt * k3)
61+
62+
let yNext = yCurrent + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
63+
64+
ys.append(yNext)
65+
yCurrent = yNext
66+
}
67+
68+
return MLX.stacked(ys)
69+
}
70+
1071
public class F5TTS: Module {
72+
public enum ODEMethod: String {
73+
case euler
74+
case midpoint
75+
case rk4
76+
}
77+
1178
enum F5TTSError: Error {
1279
case unableToLoadModel
1380
case unableToLoadReferenceAudio
@@ -38,40 +105,18 @@ public class F5TTS: Module {
38105
super.init()
39106
}
40107

41-
private func odeint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray {
42-
var ys = [y0]
43-
var yCurrent = y0
44-
45-
for i in 0..<(t.shape[0] - 1) {
46-
let tCurrent = t[i].item(Float.self)
47-
let dt = t[i + 1].item(Float.self) - tCurrent
48-
49-
let k1 = fun(tCurrent, yCurrent)
50-
let mid = yCurrent + 0.5 * dt * k1
51-
52-
let k2 = fun(tCurrent + 0.5 * dt, mid)
53-
let yNext = yCurrent + dt * k2
54-
55-
ys.append(yNext)
56-
yCurrent = yNext
57-
}
58-
59-
return MLX.stacked(ys, axis: 0)
60-
}
61-
62108
private func sample(
63109
cond: MLXArray,
64110
text: [String],
65111
duration: Int? = nil,
66112
lens: MLXArray? = nil,
67-
steps: Int = 32,
113+
steps: Int = 8,
114+
method: ODEMethod = .rk4,
68115
cfgStrength: Double = 2.0,
69116
swayCoef: Double? = -1.0,
70117
seed: Int? = nil,
71118
maxDuration: Int = 4096,
72119
vocoder: ((MLXArray) -> MLXArray)? = nil,
73-
noRefAudio: Bool = false,
74-
editMask: MLXArray? = nil,
75120
progressHandler: ((Double) -> Void)? = nil
76121
) throws -> (MLXArray, MLXArray) {
77122
MLX.eval(self.parameters())
@@ -96,9 +141,6 @@ public class F5TTS: Module {
96141
lens = MLX.maximum(textLens, lens)
97142

98143
var condMask = lensToMask(t: lens)
99-
if let editMask = editMask {
100-
condMask = condMask & editMask
101-
}
102144

103145
// duration
104146
var resolvedDuration: MLXArray? = (duration != nil) ? MLXArray(duration!) : nil
@@ -125,10 +167,6 @@ public class F5TTS: Module {
125167

126168
let mask: MLXArray? = (batch > 1) ? lensToMask(t: duration) : nil
127169

128-
if noRefAudio {
129-
cond = MLX.zeros(like: cond)
130-
}
131-
132170
// neural ode
133171

134172
let fn: (Float, MLXArray) -> MLXArray = { t, x in
@@ -169,7 +207,7 @@ public class F5TTS: Module {
169207

170208
var y0: [MLXArray] = []
171209
for dur in duration {
172-
if let seed = seed {
210+
if let seed {
173211
MLXRandom.seed(UInt64(seed))
174212
}
175213
let noise = MLXRandom.normal([dur.item(Int.self), self.numChannels])
@@ -183,11 +221,17 @@ public class F5TTS: Module {
183221
t = t + coef * (MLX.cos(MLXArray(.pi) / 2 * t) - 1 + t)
184222
}
185223

186-
let trajectory = self.odeint(fun: fn, y0: y0Padded, t: t)
224+
let odeintFn = switch method {
225+
case .euler: odeint_euler
226+
case .midpoint: odeint_midpoint
227+
case .rk4: odeint_rk4
228+
}
229+
230+
let trajectory = odeintFn(fn, y0Padded, t)
187231
let sampled = trajectory[-1]
188232
var out = MLX.where(condMask, cond, sampled)
189233

190-
if let vocoder = vocoder {
234+
if let vocoder {
191235
out = vocoder(out)
192236
}
193237
out.eval()
@@ -200,6 +244,8 @@ public class F5TTS: Module {
200244
referenceAudioURL: URL? = nil,
201245
referenceAudioText: String? = nil,
202246
duration: TimeInterval? = nil,
247+
steps: Int = 8,
248+
method: ODEMethod = .rk4,
203249
cfg: Double = 2.0,
204250
sway: Double = -1.0,
205251
speed: Double = 1.0,
@@ -234,7 +280,8 @@ public class F5TTS: Module {
234280
cond: normalizedAudio.expandedDimensions(axis: 0),
235281
text: [processedText],
236282
duration: nil,
237-
steps: 32,
283+
steps: steps,
284+
method: method,
238285
cfgStrength: cfg,
239286
swayCoef: sway,
240287
seed: seed,
@@ -339,7 +386,7 @@ public extension F5TTS {
339386
static var framesPerSecond: Double = .init(sampleRate) / Double(hopLength)
340387

341388
static func loadAudioArray(url: URL) throws -> MLXArray {
342-
return try AudioUtilities.loadAudioFile(url: url)
389+
try AudioUtilities.loadAudioFile(url: url)
343390
}
344391

345392
static func referenceAudio() throws -> (MLXArray, String) {

Sources/f5-tts-generate/GenerateCommand.swift

+8
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ struct GenerateAudio: AsyncParsableCommand {
2424
@Option(name: .long, help: "Output path for the generated audio")
2525
var outputPath: String = "output.wav"
2626

27+
@Option(name: .long, help: "The number of steps to use for ODE sampling")
28+
var steps: Int = 8
29+
30+
@Option(name: .long, help: "Method to use for ODE sampling. Options are 'euler', 'midpoint', and 'rk4'.")
31+
var method: String = "rk4"
32+
2733
@Option(name: .long, help: "Strength of classifier free guidance")
2834
var cfg: Double = 2.0
2935

@@ -49,6 +55,8 @@ struct GenerateAudio: AsyncParsableCommand {
4955
referenceAudioURL: refAudioPath != nil ? URL(filePath: refAudioPath!) : nil,
5056
referenceAudioText: refAudioText,
5157
duration: duration,
58+
steps: steps,
59+
method: F5TTS.ODEMethod(rawValue: method)!,
5260
cfg: cfg,
5361
sway: sway,
5462
speed: speed,

0 commit comments

Comments
 (0)