Skip to content

Commit 3553c91

Browse files
committed
Add simpler generation interface.
1 parent 26d2535 commit 3553c91

File tree

6 files changed

+215
-134
lines changed

6 files changed

+215
-134
lines changed

Package.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ let package = Package(
3232
],
3333
path: "Sources/F5TTS",
3434
resources: [
35-
.copy("mel_filters.npy"),
36-
.copy("test_en_1_ref_short.wav")
35+
.copy("Resources/test_en_1_ref_short.wav"),
36+
.copy("Resources/mel_filters.npy")
3737
]
3838
),
3939
.executableTarget(

README.md

+20-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11

2-
# F5 TTS for Swift (WIP)
2+
# F5 TTS for Swift
33

44
Implementation of [F5-TTS](https://arxiv.org/abs/2410.06885) in Swift, using the [MLX Swift](https://github.com/ml-explore/mlx-swift) framework.
55

66
You can listen to a [sample here](https://s3.amazonaws.com/lucasnewman.datasets/f5tts/sample.wav) that was generated in ~11 seconds on an M3 Max MacBook Pro.
77

88
See the [Python repository](https://github.com/lucasnewman/f5-tts-mlx) for additional details on the model architecture.
9+
910
This repository is based on the original PyTorch implementation available [here](https://github.com/SWivid/F5-TTS).
1011

1112

@@ -19,21 +20,29 @@ A pretrained model is available [on Huggingface](https://hf.co/lucasnewman/f5-tt
1920
## Usage
2021

2122
```swift
22-
import Vocos
2323
import F5TTS
2424

2525
let f5tts = try await F5TTS.fromPretrained(repoId: "lucasnewman/f5-tts-mlx")
26-
let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx") // if decoding to audio output
2726

28-
let inputAudio = MLXArray(...)
27+
let generatedAudio = try await f5tts.generate(text: "The quick brown fox jumped over the lazy dog.")
28+
```
29+
30+
The result is an MLXArray with 24kHz audio samples.
31+
32+
If you want to use your own reference audio sample, make sure it's a mono, 24kHz wav file of around 5-10 seconds:
33+
34+
```swift
35+
let generatedAudio = try await f5tts.generate(
36+
text: "The quick brown fox jumped over the lazy dog.",
37+
referenceAudioURL: ...,
38+
referenceAudioText: "This is the caption for the reference audio."
39+
)
40+
```
41+
42+
You can convert an audio file to the correct format with ffmpeg like this:
2943

30-
let (outputAudio, _) = f5tts.sample(
31-
cond: inputAudio,
32-
text: ["This is the caption for the reference audio and generation text."],
33-
duration: ...,
34-
vocoder: vocos.decode) { progress in
35-
print("Progress: \(Int(progress * 100))%")
36-
}
44+
```bash
45+
ffmpeg -i /path/to/audio.wav -ac 1 -ar 24000 -sample_fmt s16 -t 10 /path/to/output_audio.wav
3746
```
3847

3948
## Appreciation

Sources/F5TTS/CFM.swift renamed to Sources/F5TTS/F5TTS.swift

+175-56
Original file line numberDiff line numberDiff line change
@@ -3,59 +3,15 @@ import Hub
33
import MLX
44
import MLXNN
55
import MLXRandom
6+
import Vocos
67

7-
// utilities
8-
9-
func lensToMask(t: MLXArray, length: Int? = nil) -> MLXArray {
10-
let maxLength = length ?? t.max(keepDims: false).item(Int.self)
11-
let seq = MLXArray(0..<maxLength)
12-
let expandedSeq = seq.expandedDimensions(axis: 0)
13-
let expandedT = t.expandedDimensions(axis: 1)
14-
return MLX.less(expandedSeq, expandedT)
15-
}
16-
17-
func padToLength(_ t: MLXArray, length: Int, value: Float? = nil) -> MLXArray {
18-
let ndim = t.ndim
19-
20-
guard let seqLen = t.shape.last, length > seqLen else {
21-
return t[0..., .ellipsis]
22-
}
23-
24-
let paddingValue = MLXArray(value ?? 0.0)
25-
26-
let padded: MLXArray
27-
switch ndim {
28-
case 1:
29-
padded = MLX.padded(t, widths: [.init((0, length - seqLen))], value: paddingValue)
30-
case 2:
31-
padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen))], value: paddingValue)
32-
case 3:
33-
padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen)), .init((0, 0))], value: paddingValue)
34-
default:
35-
fatalError("Unsupported padding dims: \(ndim)")
36-
}
37-
38-
return padded[0..., .ellipsis]
39-
}
40-
41-
func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray {
42-
let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 0
43-
let t = MLX.stacked(t, axis: 0)
44-
return padToLength(t, length: maxLen, value: paddingValue)
45-
}
46-
47-
func listStrToIdx(_ text: [String], vocabCharMap: [String: Int], paddingValue: Int = -1) -> MLXArray {
48-
let listIdxTensors = text.map { str in str.map { char in vocabCharMap[String(char), default: 0] }}
49-
let mlxArrays = listIdxTensors.map { MLXArray($0) }
50-
let paddedText = padSequence(mlxArrays, paddingValue: Float(paddingValue))
51-
return paddedText.asType(.int32)
52-
}
53-
54-
// MARK: -
8+
// MARK: - F5TTS
559

5610
public class F5TTS: Module {
5711
enum F5TTSError: Error {
5812
case unableToLoadModel
13+
case unableToLoadReferenceAudio
14+
case unableToDetermineDuration
5915
}
6016

6117
public let melSpec: MelSpec
@@ -100,20 +56,20 @@ public class F5TTS: Module {
10056
return MLX.stacked(ys, axis: 0)
10157
}
10258

103-
public func sample(
59+
private func sample(
10460
cond: MLXArray,
10561
text: [String],
10662
duration: Any,
10763
lens: MLXArray? = nil,
10864
steps: Int = 32,
109-
cfgStrength: Float = 2.0,
110-
swayCoef: Float? = -1.0,
65+
cfgStrength: Double = 2.0,
66+
swayCoef: Double? = -1.0,
11167
seed: Int? = nil,
11268
maxDuration: Int = 4096,
11369
vocoder: ((MLXArray) -> MLXArray)? = nil,
11470
noRefAudio: Bool = false,
11571
editMask: MLXArray? = nil,
116-
progressHandler: ((Float) -> Void)? = nil
72+
progressHandler: ((Double) -> Void)? = nil
11773
) -> (MLXArray, MLXArray) {
11874
MLX.eval(self.parameters())
11975

@@ -183,7 +139,7 @@ public class F5TTS: Module {
183139
mask: mask
184140
)
185141

186-
progressHandler?(t)
142+
progressHandler?(Double(t))
187143

188144
return pred + (pred - nullPred) * cfgStrength
189145
}
@@ -218,13 +174,82 @@ public class F5TTS: Module {
218174

219175
return (out, trajectory)
220176
}
177+
178+
/// Generates speech for `text`, conditioned on a reference recording and its caption.
///
/// When `referenceAudioURL` is `nil`, the clip and caption returned by `F5TTS.referenceAudio()`
/// are used. When `duration` is `nil`, the length of the generated speech is estimated with
/// `F5TTS.estimatedDuration(refAudio:refText:text:speed:)`.
///
/// - Parameters:
///   - text: The text to synthesize.
///   - referenceAudioURL: Optional URL of the reference audio to load.
///   - referenceAudioText: Caption for the reference audio; treated as empty when omitted.
///   - duration: Duration of the generated speech in seconds, not counting the reference audio.
///   - cfg: Forwarded to `sample` as `cfgStrength`.
///   - sway: Forwarded to `sample` as `swayCoef`.
///   - speed: Forwarded to the duration estimate; has no effect when `duration` is supplied.
///   - seed: Forwarded to `sample`.
///   - progressHandler: Invoked with every progress value reported while sampling.
/// - Returns: The sampled audio with the leading reference-audio samples dropped.
/// - Throws: Any error raised while loading the Vocos model or the reference audio, or
///   `F5TTSError.unableToDetermineDuration` when the duration is not a finite number.
public func generate(
    text: String,
    referenceAudioURL: URL? = nil,
    referenceAudioText: String? = nil,
    duration: TimeInterval? = nil,
    cfg: Double = 2.0,
    sway: Double = -1.0,
    speed: Double = 1.0,
    seed: Int? = nil,
    progressHandler: ((Double) -> Void)? = nil
) async throws -> MLXArray {
    print("Loading Vocos model...")
    let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx")

    // load the reference audio + text

    let audio: MLXArray
    let referenceText: String

    if let referenceAudioURL {
        audio = try F5TTS.loadAudioArray(url: referenceAudioURL)
        referenceText = referenceAudioText ?? ""
    } else {
        let (bundledAudio, bundledCaption) = try F5TTS.referenceAudio()
        audio = bundledAudio
        referenceText = bundledCaption
    }

    let refAudioDuration = Double(audio.shape[0]) / Double(F5TTS.sampleRate)
    print("Using reference audio with duration: \(refAudioDuration)")

    // use a heuristic to determine the duration if not provided

    let generatedDuration = duration
        ?? F5TTS.estimatedDuration(refAudio: audio, refText: referenceText, text: text, speed: speed)

    // A NaN or infinite duration would trap in the Int conversion below, so reject it here.
    guard generatedDuration.isFinite else {
        throw F5TTSError.unableToDetermineDuration
    }
    print("Using generated duration: \(generatedDuration)")

    // generate the audio

    let normalizedAudio = F5TTS.normalizeAudio(audio: audio)

    let processedText = referenceText + " " + text
    let frameDuration = Int((refAudioDuration + generatedDuration) * F5TTS.framesPerSecond)
    print("Generating \(generatedDuration) seconds (\(frameDuration) total frames) of audio...")

    let (outputAudio, _) = self.sample(
        cond: normalizedAudio.expandedDimensions(axis: 0),
        text: [processedText],
        duration: frameDuration,
        steps: 32,
        cfgStrength: cfg,
        swayCoef: sway,
        seed: seed,
        vocoder: vocos.decode
    ) { progress in
        print("Generation progress: \(progress)")
        // Forward progress to the caller; previously the handler was accepted but never called.
        progressHandler?(progress)
    }

    // Drop the leading samples that correspond to the reference audio.
    return outputAudio[audio.shape[0]...]
}
221244
}
222245

223-
// MARK: -
246+
// MARK: - Pretrained Models
224247

225248
public extension F5TTS {
226-
static func fromPretrained(repoId: String) async throws -> F5TTS {
227-
let modelDirectoryURL = try await Hub.snapshot(from: repoId, matching: ["*.safetensors", "*.txt"])
249+
/// Builds an `F5TTS` model from a Hub repository.
///
/// Requests a Hub snapshot of `repoId` restricted to `*.safetensors` and `*.txt` files, then
/// constructs the model from the resulting directory via `fromPretrained(modelDirectoryURL:)`.
///
/// - Parameters:
///   - repoId: Identifier of the repository to snapshot.
///   - downloadProgress: Optional callback that receives each `Progress` reported by the snapshot.
/// - Returns: The model built from the snapshot directory.
/// - Throws: Any error raised by the snapshot request or by model construction.
static func fromPretrained(repoId: String, downloadProgress: ((Progress) -> Void)? = nil) async throws -> F5TTS {
    let snapshotDirectory = try await Hub.snapshot(
        from: repoId,
        matching: ["*.safetensors", "*.txt"]
    ) { snapshotProgress in
        downloadProgress?(snapshotProgress)
    }

    return try fromPretrained(modelDirectoryURL: snapshotDirectory)
}
230255

@@ -273,3 +298,97 @@ public extension F5TTS {
273298
return f5tts
274299
}
275300
}
301+
302+
// MARK: - Utilities
303+
304+
public extension F5TTS {
    /// Audio sample rate used to convert sample counts into seconds.
    static var sampleRate: Int = 24000

    /// Number of audio samples that make up one frame.
    static var hopLength: Int = 256

    /// Frames per second, derived from `sampleRate` and `hopLength` when first accessed.
    static var framesPerSecond: Double = .init(sampleRate) / Double(hopLength)

    /// Loads the audio file at `url` into an `MLXArray` via `AudioUtilities.loadAudioFile(url:)`.
    ///
    /// - Throws: Any error raised by `AudioUtilities.loadAudioFile(url:)`.
    static func loadAudioArray(url: URL) throws -> MLXArray {
        return try AudioUtilities.loadAudioFile(url: url)
    }

    /// Returns the bundled reference clip together with its caption.
    ///
    /// - Throws: `F5TTSError.unableToLoadReferenceAudio` when the resource is missing from the
    ///   module bundle, or any error raised while loading the audio.
    static func referenceAudio() throws -> (MLXArray, String) {
        guard let url = Bundle.module.url(forResource: "test_en_1_ref_short", withExtension: "wav") else {
            throw F5TTSError.unableToLoadReferenceAudio
        }

        return try (
            self.loadAudioArray(url: url),
            "Some call me nature, others call me mother nature."
        )
    }

    /// Scales `audio` up to `targetRMS` when its RMS level is below the target.
    ///
    /// Audio that already meets the target is returned unchanged. Silent audio (RMS of zero) and
    /// audio whose RMS is not a number are also returned unchanged, because scaling them would
    /// divide by zero and fill the array with NaN.
    ///
    /// - Parameters:
    ///   - audio: The samples to normalize.
    ///   - targetRMS: The minimum RMS level to reach.
    /// - Returns: The scaled audio, or `audio` itself when no scaling applies.
    static func normalizeAudio(audio: MLXArray, targetRMS: Double = 0.1) -> MLXArray {
        let rms = Double(audio.square().mean().sqrt().item(Float.self))
        guard rms > 0, rms < targetRMS else {
            return audio
        }
        return audio * targetRMS / rms
    }

    /// Estimates how long the speech for `text` should be, based on the pace of the reference clip.
    ///
    /// The reference clip's frames-per-byte ratio is applied to the UTF-8 length of `text`. Each
    /// pause punctuation mark found in a text adds three extra units to that text's length.
    /// Previously a constant was added to both lengths regardless of the punctuation present,
    /// which skewed the ratio.
    ///
    /// - Parameters:
    ///   - refAudio: Reference audio samples; its first dimension is read as the sample count.
    ///   - refText: Caption of the reference audio. An empty caption is treated as length 1.
    ///   - text: The text to be synthesized.
    ///   - speed: Speaking-rate divisor. Non-positive or NaN values fall back to `1.0`.
    /// - Returns: The estimated duration in seconds.
    static func estimatedDuration(refAudio: MLXArray, refText: String, text: String, speed: Double = 1.0) -> TimeInterval {
        let pausePunctuation = "。,、;:?!"

        // UTF-8 length plus a weight of 3 for every pause punctuation mark in the string.
        func weightedLength(of string: String) -> Int {
            let pauseCount = string.filter { pausePunctuation.contains($0) }.count
            return string.utf8.count + 3 * pauseCount
        }

        let refDurationInFrames = refAudio.shape[0] / self.hopLength
        // Clamp to 1 so an empty caption cannot cause a division by zero.
        let refTextLength = max(weightedLength(of: refText), 1)
        let genTextLength = weightedLength(of: text)
        // A zero, negative or NaN speed would make the Int conversion below trap.
        let effectiveSpeed = speed > 0 ? speed : 1.0

        let refAudioToTextRatio = Double(refDurationInFrames) / Double(refTextLength)
        let textLength = Double(genTextLength) / effectiveSpeed
        let estimatedDurationInFrames = Int(refAudioToTextRatio * textLength)

        let estimatedDuration = TimeInterval(estimatedDurationInFrames) / Self.framesPerSecond
        print("Using duration of \(estimatedDuration) seconds (\(estimatedDurationInFrames) frames) for generated speech.")

        return estimatedDuration
    }
}
348+
349+
// MLX utilities
350+
351+
/// Builds a mask that is true wherever a position index is smaller than the matching entry of `t`.
///
/// - Parameters:
///   - t: Array of lengths. NOTE(review): appears to be one length per batch item — confirm with callers.
///   - length: Number of positions to compare; defaults to the maximum value found in `t`.
/// - Returns: The result of comparing positions `0..<length` (expanded on axis 0) against `t`
///   (expanded on axis 1) with `MLX.less`.
func lensToMask(t: MLXArray, length: Int? = nil) -> MLXArray {
    let positionCount: Int
    if let length {
        positionCount = length
    } else {
        positionCount = t.max(keepDims: false).item(Int.self)
    }

    let positions = MLXArray(0..<positionCount).expandedDimensions(axis: 0)
    let lengths = t.expandedDimensions(axis: 1)
    return MLX.less(positions, lengths)
}
358+
359+
/// Pads `t` with `value` when `length` exceeds the size of its last dimension.
///
/// The amount of padding is `length` minus the size of the last dimension. It is appended to
/// axis 0 for 1-D input and to axis 1 for 2-D and 3-D input. Arrays that are already at least
/// `length` long are returned without padding and are not truncated.
///
/// - Parameters:
///   - t: The array to pad; must have 1, 2 or 3 dimensions when padding is needed.
///   - length: The requested length.
///   - value: The fill value; defaults to `0.0`.
/// - Returns: A slice covering the whole (possibly padded) array.
func padToLength(_ t: MLXArray, length: Int, value: Float? = nil) -> MLXArray {
    guard let currentLength = t.shape.last, currentLength < length else {
        return t[0..., .ellipsis]
    }

    let fill = MLXArray(value ?? 0.0)
    let extra = length - currentLength

    switch t.ndim {
    case 1:
        return MLX.padded(t, widths: [.init((0, extra))], value: fill)[0..., .ellipsis]
    case 2:
        return MLX.padded(t, widths: [.init((0, 0)), .init((0, extra))], value: fill)[0..., .ellipsis]
    case 3:
        return MLX.padded(t, widths: [.init((0, 0)), .init((0, extra)), .init((0, 0))], value: fill)[0..., .ellipsis]
    default:
        fatalError("Unsupported padding dims: \(t.ndim)")
    }
}
382+
383+
/// Pads every array in `t` to the longest last-dimension size, then stacks them along a new axis 0.
///
/// Each array is padded before stacking. Stacking first, as was done previously, cannot combine
/// sequences of different lengths and leaves nothing for the padding step to do.
///
/// - Parameters:
///   - t: The sequences to combine.
///   - paddingValue: The fill value used for the shorter sequences.
/// - Returns: The stacked, padded sequences.
func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray {
    let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 0
    let paddedSequences = t.map { padToLength($0, length: maxLen, value: paddingValue) }
    return MLX.stacked(paddedSequences, axis: 0)
}
388+
389+
/// Converts each string into vocabulary indices and combines them into one padded `int32` array.
///
/// Every character is looked up in `vocabCharMap`; characters missing from the map use index 0.
/// The per-string index arrays are combined with `padSequence`.
///
/// - Parameters:
///   - text: The strings to encode.
///   - vocabCharMap: Maps a single-character string to its vocabulary index.
///   - paddingValue: The fill value passed to `padSequence`.
/// - Returns: The padded index array converted to `int32`.
func listStrToIdx(_ text: [String], vocabCharMap: [String: Int], paddingValue: Int = -1) -> MLXArray {
    let encodedSequences: [MLXArray] = text.map { sentence in
        let indices = sentence.map { character in vocabCharMap[String(character), default: 0] }
        return MLXArray(indices)
    }

    return padSequence(encodedSequences, paddingValue: Float(paddingValue)).asType(.int32)
}

0 commit comments

Comments
 (0)