Skip to content

Commit 3ab56f5

Browse files
baudneoedurenye
authored andcommitted
Create process.py
Adds --cuda arg support
1 parent 9bb212f commit 3ab56f5

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

piper/process.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
import asyncio
4+
import json
5+
import logging
6+
import tempfile
7+
import time
8+
from dataclasses import dataclass
9+
from typing import Any, Dict, Optional
10+
11+
from .download import ensure_voice_exists, find_voice
12+
13+
_LOGGER = logging.getLogger(__name__)
14+
15+
16+
@dataclass
17+
class PiperProcess:
18+
"""Info for a running Piper process (one voice)."""
19+
20+
name: str
21+
proc: "asyncio.subprocess.Process"
22+
config: Dict[str, Any]
23+
wav_dir: tempfile.TemporaryDirectory
24+
last_used: int = 0
25+
26+
def get_speaker_id(self, speaker: str) -> Optional[int]:
27+
"""Get speaker by name or id."""
28+
return _get_speaker_id(self.config, speaker)
29+
30+
@property
31+
def is_multispeaker(self) -> bool:
32+
"""True if model has more than one speaker."""
33+
return _is_multispeaker(self.config)
34+
35+
36+
def _get_speaker_id(config: Dict[str, Any], speaker: str) -> Optional[int]:
37+
"""Get speaker by name or id."""
38+
speaker_id_map = config.get("speaker_id_map", {})
39+
speaker_id = speaker_id_map.get(speaker)
40+
if speaker_id is None:
41+
try:
42+
# Try to interpret as an id
43+
speaker_id = int(speaker)
44+
except ValueError:
45+
pass
46+
47+
return speaker_id
48+
49+
50+
def _is_multispeaker(config: Dict[str, Any]) -> bool:
51+
"""True if model has more than one speaker."""
52+
return config.get("num_speakers", 1) > 1
53+
54+
55+
# -----------------------------------------------------------------------------
56+
57+
58+
class PiperProcessManager:
59+
"""Manager of running Piper processes."""
60+
61+
def __init__(self, args: argparse.Namespace, voices_info: Dict[str, Any]):
62+
self.voices_info = voices_info
63+
self.args = args
64+
self.processes: Dict[str, PiperProcess] = {}
65+
self.processes_lock = asyncio.Lock()
66+
67+
async def get_process(self, voice_name: Optional[str] = None) -> PiperProcess:
68+
"""Get a running Piper process or start a new one if necessary."""
69+
voice_speaker: Optional[str] = None
70+
if voice_name is None:
71+
# Default voice
72+
voice_name = self.args.voice
73+
74+
if voice_name == self.args.voice:
75+
# Default speaker
76+
voice_speaker = self.args.speaker
77+
78+
assert voice_name is not None
79+
80+
# Resolve alias
81+
voice_info = self.voices_info.get(voice_name, {})
82+
voice_name = voice_info.get("key", voice_name)
83+
assert voice_name is not None
84+
85+
piper_proc = self.processes.get(voice_name)
86+
if (piper_proc is None) or (piper_proc.proc.returncode is not None):
87+
# Remove if stopped
88+
self.processes.pop(voice_name, None)
89+
90+
# Start new Piper process
91+
if self.args.max_piper_procs > 0:
92+
# Restrict number of running processes
93+
while len(self.processes) >= self.args.max_piper_procs:
94+
# Stop least recently used process
95+
lru_proc_name, lru_proc = sorted(
96+
self.processes.items(), key=lambda kv: kv[1].last_used
97+
)[0]
98+
_LOGGER.debug("Stopping process for: %s", lru_proc_name)
99+
self.processes.pop(lru_proc_name, None)
100+
if lru_proc.proc.returncode is None:
101+
try:
102+
lru_proc.proc.terminate()
103+
await lru_proc.proc.wait()
104+
except Exception:
105+
_LOGGER.exception("Unexpected error stopping piper process")
106+
107+
_LOGGER.debug(
108+
"Starting process for: %s (%s/%s)",
109+
voice_name,
110+
len(self.processes) + 1,
111+
self.args.max_piper_procs,
112+
)
113+
114+
ensure_voice_exists(
115+
voice_name,
116+
self.args.data_dir,
117+
self.args.download_dir,
118+
self.voices_info,
119+
)
120+
121+
onnx_path, config_path = find_voice(voice_name, self.args.data_dir)
122+
with open(config_path, "r", encoding="utf-8") as config_file:
123+
config = json.load(config_file)
124+
125+
wav_dir = tempfile.TemporaryDirectory()
126+
piper_args = [
127+
"--model",
128+
str(onnx_path),
129+
"--config",
130+
str(config_path),
131+
"--output_dir",
132+
str(wav_dir.name),
133+
"--json-input", # piper 1.1+
134+
]
135+
136+
if voice_speaker is not None:
137+
if _is_multispeaker(config):
138+
speaker_id = _get_speaker_id(config, voice_speaker)
139+
if speaker_id is not None:
140+
piper_args.extend(["--speaker", str(speaker_id)])
141+
142+
if self.args.noise_scale:
143+
piper_args.extend(["--noise-scale", str(self.args.noise_scale)])
144+
145+
if self.args.length_scale:
146+
piper_args.extend(["--length-scale", str(self.args.length_scale)])
147+
148+
if self.args.noise_w:
149+
piper_args.extend(["--noise-w", str(self.args.noise_w)])
150+
151+
if self.args.cuda:
152+
piper_args.extend(["--cuda"])
153+
154+
_LOGGER.debug(
155+
"Starting piper process: %s args=%s", self.args.piper, piper_args
156+
)
157+
piper_proc = PiperProcess(
158+
name=voice_name,
159+
proc=await asyncio.create_subprocess_exec(
160+
self.args.piper,
161+
*piper_args,
162+
stdin=asyncio.subprocess.PIPE,
163+
stdout=asyncio.subprocess.PIPE,
164+
stderr=asyncio.subprocess.DEVNULL,
165+
),
166+
config=config,
167+
wav_dir=wav_dir,
168+
)
169+
self.processes[voice_name] = piper_proc
170+
171+
# Update used
172+
piper_proc.last_used = time.monotonic_ns()
173+
174+
return piper_proc
175+

0 commit comments

Comments
 (0)