Skip to content

Commit 26d2535

Browse files
committed
Add generation tool.
1 parent 847fb3c commit 26d2535

File tree

5 files changed

+138
-6
lines changed

5 files changed

+138
-6
lines changed

Package.resolved

+9
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@
4444
"revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
4545
"version" : "0.1.13"
4646
}
47+
},
48+
{
49+
"identity" : "vocos-swift",
50+
"kind" : "remoteSourceControl",
51+
"location" : "https://github.com/lucasnewman/vocos-swift.git",
52+
"state" : {
53+
"revision" : "021e7af9d0c0aff9f7b62bf9839c37554287f3af",
54+
"version" : "0.0.1"
55+
}
4756
}
4857
],
4958
"version" : 2

Package.swift

+15-2
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ let package = Package(
1010
.library(
1111
name: "F5TTS",
1212
targets: ["F5TTS"]
13-
),
13+
)
1414
],
1515
dependencies: [
1616
.package(url: "https://github.com/ml-explore/mlx-swift", from: "0.18.0"),
1717
.package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"),
18+
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"),
19+
.package(url: "https://github.com/lucasnewman/vocos-swift.git", from: "0.0.1")
1820
],
1921
targets: [
2022
.target(
@@ -26,12 +28,23 @@ let package = Package(
2628
.product(name: "MLXFFT", package: "mlx-swift"),
2729
.product(name: "MLXLinalg", package: "mlx-swift"),
2830
.product(name: "MLXRandom", package: "mlx-swift"),
29-
.product(name: "Transformers", package: "swift-transformers"),
31+
.product(name: "Transformers", package: "swift-transformers")
3032
],
3133
path: "Sources/F5TTS",
3234
resources: [
3335
.copy("mel_filters.npy"),
36+
.copy("test_en_1_ref_short.wav")
3437
]
3538
),
39+
.executableTarget(
40+
name: "f5-tts-generate",
41+
dependencies: [
42+
"F5TTS",
43+
.product(name: "Vocos", package: "vocos-swift"),
44+
.product(name: "ArgumentParser", package: "swift-argument-parser"),
45+
.product(name: "MLX", package: "mlx-swift"),
46+
],
47+
path: "Sources/f5-tts-generate"
48+
)
3649
]
3750
)

Sources/F5TTS/CFM.swift

+3-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ func padToLength(_ t: MLXArray, length: Int, value: Float? = nil) -> MLXArray {
4141
/// Pads every array in `t` to a common length and stacks the results along a new leading axis.
///
/// The common length is the largest trailing-axis size found in `t` (0 when `t` is empty).
/// Each element is padded *before* stacking: stacking requires equally shaped inputs, so
/// stacking first would fail for exactly the ragged inputs this helper exists to handle.
///
/// - Parameters:
///   - t: The arrays to batch together.
///   - paddingValue: Fill value forwarded to `padToLength`.
/// - Returns: The padded arrays stacked along axis 0.
// NOTE(review): assumes `padToLength` pads along the trailing axis, matching how `maxLen`
// is measured here — confirm against its implementation.
func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray {
    let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 0
    let padded = t.map { padToLength($0, length: maxLen, value: paddingValue) }
    return MLX.stacked(padded, axis: 0)
}
4746

@@ -108,7 +107,7 @@ public class F5TTS: Module {
108107
lens: MLXArray? = nil,
109108
steps: Int = 32,
110109
cfgStrength: Float = 2.0,
111-
swaySamplingCoef: Float? = -1.0,
110+
swayCoef: Float? = -1.0,
112111
seed: Int? = nil,
113112
maxDuration: Int = 4096,
114113
vocoder: ((MLXArray) -> MLXArray)? = nil,
@@ -203,12 +202,12 @@ public class F5TTS: Module {
203202

204203
var t = MLXArray.linspace(Float32(0.0), Float32(1.0), count: steps)
205204

206-
if let coef = swaySamplingCoef {
205+
if let coef = swayCoef {
207206
t = t + coef * (MLX.cos(MLXArray(.pi) / 2 * t) - 1 + t)
208207
}
209208

210209
let trajectory = self.odeint(fun: fn, y0: y0Padded, t: t)
211-
var sampled = trajectory[-1]
210+
let sampled = trajectory[-1]
212211
var out = MLX.where(condMask, cond, sampled)
213212

214213
if let vocoder = vocoder {

Sources/F5TTS/test_en_1_ref_short.wav

250 KB
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import ArgumentParser
2+
import MLX
3+
import F5TTS
4+
import Foundation
5+
import Vocos
6+
7+
@main
8+
struct GenerateAudio: AsyncParsableCommand {
9+
@Argument(help: "Text to generate speech from")
10+
var text: String
11+
12+
@Option(name: .long, help: "Duration of the generated audio in seconds")
13+
var duration: Double?
14+
15+
@Option(name: .long, help: "Path to the reference audio file")
16+
var refAudioPath: String?
17+
18+
@Option(name: .long, help: "Text spoken in the reference audio")
19+
var refAudioText: String?
20+
21+
@Option(name: .long, help: "Model name to use")
22+
var model: String = "lucasnewman/f5-tts-mlx"
23+
24+
@Option(name: .long, help: "Output path for the generated audio")
25+
var outputPath: String = "output.wav"
26+
27+
@Option(name: .long, help: "Strength of classifier free guidance")
28+
var cfg: Float = 2.0
29+
30+
@Option(name: .long, help: "Coefficient for sway sampling")
31+
var sway: Float = -1.0
32+
33+
@Option(name: .long, help: "Speed factor for the duration heuristic")
34+
var speed: Float = 1.0
35+
36+
@Option(name: .long, help: "Seed for noise generation")
37+
var seed: Int?
38+
39+
func run() async throws {
40+
let sampleRate = 24_000
41+
let hopLength = 256
42+
let framesPerSec = Double(sampleRate) / Double(hopLength)
43+
let targetRMS: Float = 0.1
44+
45+
let f5tts = try await F5TTS.fromPretrained(repoId: model)
46+
let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx")
47+
48+
var audio: MLXArray
49+
let referenceText: String
50+
51+
if let refPath = refAudioPath {
52+
audio = try AudioUtilities.loadAudioFile(url: URL(filePath: refPath))
53+
referenceText = refAudioText ?? "Some call me nature, others call me mother nature."
54+
} else if let refURL = Bundle.main.url(forResource: "test_en_1_ref_short", withExtension: "wav") {
55+
audio = try AudioUtilities.loadAudioFile(url: refURL)
56+
referenceText = "Some call me nature, others call me mother nature."
57+
} else {
58+
fatalError("No reference audio file specified.")
59+
}
60+
61+
let rms = audio.square().mean().sqrt().item(Float.self)
62+
if rms < targetRMS {
63+
audio = audio * targetRMS / rms
64+
}
65+
66+
// use a heuristic to determine the duration if not provided
67+
let refAudioDuration = Double(audio.shape[0]) / framesPerSec
68+
var generatedDuration = duration
69+
70+
if generatedDuration == nil {
71+
let refAudioLength = audio.shape[0] / hopLength
72+
let pausePunctuation = "。,、;:?!"
73+
let refTextLength = referenceText.utf8.count + 3 * pausePunctuation.utf8.count
74+
let genTextLength = text.utf8.count + 3 * pausePunctuation.utf8.count
75+
76+
let durationInFrames = refAudioLength + Int((Double(refAudioLength) / Double(refTextLength)) * (Double(genTextLength) / Double(speed)))
77+
let estimatedDuration = Double(durationInFrames - refAudioLength) / framesPerSec
78+
79+
print("Using duration of \(estimatedDuration) seconds for generated speech.")
80+
generatedDuration = estimatedDuration
81+
}
82+
83+
guard let generatedDuration else {
84+
fatalError("Unable to determine duration.")
85+
}
86+
87+
let processedText = referenceText + " " + text
88+
let frameDuration = Int((refAudioDuration + generatedDuration) * framesPerSec)
89+
print("Generating \(frameDuration) frames of audio...")
90+
91+
let startTime = Date()
92+
93+
let (outputAudio, _) = f5tts.sample(
94+
cond: audio.expandedDimensions(axis: 0),
95+
text: [processedText],
96+
duration: frameDuration,
97+
steps: 32,
98+
cfgStrength: cfg,
99+
swayCoef: sway,
100+
seed: seed,
101+
vocoder: vocos.decode
102+
)
103+
104+
let generatedAudio = outputAudio[audio.shape[0]...]
105+
106+
let elapsedTime = Date().timeIntervalSince(startTime)
107+
print("Generated \(Double(generatedAudio.count) / Double(sampleRate)) seconds of audio in \(elapsedTime) seconds.")
108+
109+
try AudioUtilities.saveAudioFile(url: URL(filePath: outputPath), samples: generatedAudio)
110+
}
111+
}

0 commit comments

Comments
 (0)