Skip to content

Commit 08627f1

Browse files
CopilotPatrick-DE
andcommitted
Include changes from 5 community pull requests (jamiepine#78, jamiepine#88, jamiepine#91, jamiepine#95, jamiepine#97)
Co-authored-by: Patrick-DE <14962702+Patrick-DE@users.noreply.github.com>
1 parent 98f83b1 commit 08627f1

9 files changed

Lines changed: 214 additions & 19 deletions

File tree

app/src/App.tsx

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ function App() {
8282
console.log('Dev mode: Skipping auto-start of server (run it separately)');
8383
setServerReady(true); // Mark as ready so UI doesn't show loading screen
8484
// Mark that server was not started by app (so we don't try to stop it on close)
85-
// @ts-expect-error - adding property to window
8685
window.__voiceboxServerStartedByApp = false;
8786
return;
8887
}
@@ -103,13 +102,11 @@ function App() {
103102
useServerStore.getState().setServerUrl(serverUrl);
104103
setServerReady(true);
105104
// Mark that we started the server (so we know to stop it on close)
106-
// @ts-expect-error - adding property to window
107105
window.__voiceboxServerStartedByApp = true;
108106
})
109107
.catch((error) => {
110108
console.error('Failed to auto-start server:', error);
111109
serverStartingRef.current = false;
112-
// @ts-expect-error - adding property to window
113110
window.__voiceboxServerStartedByApp = false;
114111
});
115112

app/src/components/Generation/FloatingGenerateBox.tsx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import {
1313
} from '@/components/ui/select';
1414
import { Textarea } from '@/components/ui/textarea';
1515
import { useToast } from '@/components/ui/use-toast';
16-
import { LANGUAGE_OPTIONS } from '@/lib/constants/languages';
16+
import { LANGUAGE_OPTIONS, type LanguageCode } from '@/lib/constants/languages';
1717
import { useGenerationForm } from '@/lib/hooks/useGenerationForm';
1818
import { useProfile, useProfiles } from '@/lib/hooks/useProfiles';
1919
import { useAddStoryItem, useStory } from '@/lib/hooks/useStories';
@@ -112,6 +112,13 @@ export function FloatingGenerateBox({
112112
}
113113
}, [selectedProfileId, profiles, setSelectedProfileId]);
114114

115+
// Sync generation form language with selected profile's language
116+
useEffect(() => {
117+
if (selectedProfile?.language) {
118+
form.setValue('language', selectedProfile.language as LanguageCode);
119+
}
120+
}, [selectedProfile, form]);
121+
115122
// Auto-resize textarea based on content (only when expanded)
116123
useEffect(() => {
117124
if (!isExpanded) {

app/src/components/StoriesTab/StoryTrackEditor.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ export function StoryTrackEditor({ storyId, items }: StoryTrackEditorProps) {
313313
}
314314
}, [isResizing, handleResizeMove, handleResizeEnd]);
315315

316-
const handleTimelineClick = (e: React.MouseEvent<HTMLDivElement>) => {
316+
const handleTimelineClick = (e: React.MouseEvent<HTMLElement>) => {
317317
if (!tracksRef.current || draggingItem || trimmingItem) return;
318318
const rect = tracksRef.current.getBoundingClientRect();
319319
const x = e.clientX - rect.left + tracksRef.current.scrollLeft;

app/src/components/VoiceProfiles/AudioSampleRecording.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ export function AudioSampleRecording({
5858
// Request microphone access when component mounts
5959
useEffect(() => {
6060
if (!showWaveform) return;
61+
if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) return;
6162

6263
let stream: MediaStream | null = null;
6364

app/src/hooks/useAutoUpdater.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ export function useAutoUpdater(options: boolean | UseAutoUpdaterOptions = false)
7373
}
7474
// Empty dependency array - only run once on mount
7575
// eslint-disable-next-line react-hooks/exhaustive-deps
76-
}, [platform.metadata.isTauricheckOnMountcheckForUpdates]);
76+
}, [platform.metadata.isTauri, checkOnMount, checkForUpdates]);
7777

7878
// Show toast when update is available
7979
useEffect(() => {

backend/backends/mlx_backend.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
1515
from ..utils.tasks import get_task_manager
1616

17+
LANGUAGE_CODE_TO_NAME = {
18+
"zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean",
19+
"de": "german", "fr": "french", "ru": "russian", "pt": "portuguese",
20+
"es": "spanish", "it": "italian",
21+
}
22+
1723

1824
class MLXTTSBackend:
1925
"""MLX-based TTS backend using mlx-audio."""
@@ -316,7 +322,8 @@ def _generate_sync():
316322
# MLX generate() returns a generator yielding GenerationResult objects
317323
audio_chunks = []
318324
sample_rate = 24000
319-
325+
lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")
326+
320327
# Set seed if provided (MLX uses numpy random)
321328
if seed is not None:
322329
import mlx.core as mx
@@ -344,23 +351,23 @@ def _generate_sync():
344351
sig = inspect.signature(self.model.generate)
345352
if "ref_audio" in sig.parameters:
346353
# Generate with voice cloning
347-
for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text):
354+
for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
348355
audio_chunks.append(np.array(result.audio))
349356
sample_rate = result.sample_rate
350357
else:
351358
# Fallback: generate without voice cloning
352-
for result in self.model.generate(text):
359+
for result in self.model.generate(text, lang_code=lang):
353360
audio_chunks.append(np.array(result.audio))
354361
sample_rate = result.sample_rate
355362
else:
356363
# No voice prompt, generate normally
357-
for result in self.model.generate(text):
364+
for result in self.model.generate(text, lang_code=lang):
358365
audio_chunks.append(np.array(result.audio))
359366
sample_rate = result.sample_rate
360367
except Exception as e:
361368
# If voice cloning fails, try without it
362369
print(f"Warning: Voice cloning failed, generating without voice prompt: {e}")
363-
for result in self.model.generate(text):
370+
for result in self.model.generate(text, lang_code=lang):
364371
audio_chunks.append(np.array(result.audio))
365372
sample_rate = result.sample_rate
366373

backend/backends/pytorch_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
1616
from ..utils.tasks import get_task_manager
1717

18+
LANGUAGE_CODE_TO_NAME = {
19+
"zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean",
20+
"de": "german", "fr": "french", "ru": "russian", "pt": "portuguese",
21+
"es": "spanish", "it": "italian",
22+
}
23+
1824

1925
class PyTorchTTSBackend:
2026
"""PyTorch-based TTS backend using Qwen3-TTS."""
@@ -335,6 +341,7 @@ def _generate_sync():
335341
wavs, sample_rate = self.model.generate_voice_clone(
336342
text=text,
337343
voice_clone_prompt=voice_prompt,
344+
language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
338345
instruct=instruct,
339346
)
340347
return wavs[0], sample_rate

backend/main.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,23 @@
3636
version=__version__,
3737
)
3838

39-
# CORS middleware
39+
# CORS middleware - restrict to known local origins by default.
40+
# Set VOICEBOX_CORS_ORIGINS env var to a comma-separated list of origins
41+
# to allow additional origins (e.g. for remote server mode).
42+
_default_origins = [
43+
"http://localhost:5173", # Vite dev server
44+
"http://127.0.0.1:5173",
45+
"http://localhost:17493",
46+
"http://127.0.0.1:17493",
47+
"tauri://localhost", # Tauri webview (macOS)
48+
"https://tauri.localhost", # Tauri webview (Windows/Linux)
49+
]
50+
_env_origins = os.environ.get("VOICEBOX_CORS_ORIGINS", "")
51+
_cors_origins = _default_origins + [o.strip() for o in _env_origins.split(",") if o.strip()]
52+
4053
app.add_middleware(
4154
CORSMiddleware,
42-
allow_origins=["*"], # Configure appropriately for production
55+
allow_origins=_cors_origins,
4356
allow_credentials=True,
4457
allow_methods=["*"],
4558
allow_headers=["*"],
@@ -542,12 +555,6 @@ async def generate_speech(
542555
if not profile:
543556
raise HTTPException(status_code=404, detail="Profile not found")
544557

545-
# Create voice prompt from profile
546-
voice_prompt = await profiles.create_voice_prompt_for_profile(
547-
data.profile_id,
548-
db,
549-
)
550-
551558
# Generate audio
552559
tts_model = tts.get_tts_model()
553560
# Load the requested model size if different from current (async to not block)
@@ -582,7 +589,15 @@ async def download_model_background():
582589
}
583590
)
584591

592+
# Load the requested model BEFORE creating voice prompt,
593+
# so create_voice_prompt uses the correct model size
585594
await tts_model.load_model_async(model_size)
595+
596+
# Create voice prompt from profile
597+
voice_prompt = await profiles.create_voice_prompt_for_profile(
598+
data.profile_id,
599+
db,
600+
)
586601
audio, sample_rate = await tts_model.generate(
587602
data.text,
588603
voice_prompt,

backend/tests/test_cors.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""
2+
Tests for CORS origin restrictions.
3+
4+
Validates that the CORS middleware only allows known local origins
5+
and respects the VOICEBOX_CORS_ORIGINS environment variable.
6+
7+
Uses a minimal FastAPI app that mirrors the exact CORS configuration
8+
from backend/main.py, so tests run without heavy ML dependencies.
9+
10+
Usage:
11+
pip install httpx pytest fastapi starlette
12+
python -m pytest backend/tests/test_cors.py -v
13+
"""
14+
15+
import os
16+
import pytest
17+
from fastapi import FastAPI
18+
from fastapi.middleware.cors import CORSMiddleware
19+
from starlette.testclient import TestClient
20+
21+
22+
def _build_app(env_origins: str = "") -> FastAPI:
23+
"""
24+
Build a minimal FastAPI app with the same CORS logic as backend/main.py.
25+
26+
This mirrors the exact code in main.py so the test validates the real
27+
configuration without needing torch/numpy/transformers installed.
28+
"""
29+
app = FastAPI()
30+
31+
_default_origins = [
32+
"http://localhost:5173",
33+
"http://127.0.0.1:5173",
34+
"http://localhost:17493",
35+
"http://127.0.0.1:17493",
36+
"tauri://localhost",
37+
"https://tauri.localhost",
38+
]
39+
_cors_origins = _default_origins + [o.strip() for o in env_origins.split(",") if o.strip()]
40+
41+
app.add_middleware(
42+
CORSMiddleware,
43+
allow_origins=_cors_origins,
44+
allow_credentials=True,
45+
allow_methods=["*"],
46+
allow_headers=["*"],
47+
)
48+
49+
@app.get("/health")
50+
async def health():
51+
return {"status": "ok"}
52+
53+
return app
54+
55+
56+
@pytest.fixture()
57+
def client():
58+
return TestClient(_build_app())
59+
60+
61+
@pytest.fixture()
62+
def client_with_custom_origins():
63+
return TestClient(_build_app("https://custom.example.com,https://other.example.com"))
64+
65+
66+
def _get_with_origin(client: TestClient, origin: str) -> dict:
67+
"""Send a GET with Origin header, return response headers."""
68+
response = client.get("/health", headers={"Origin": origin})
69+
return dict(response.headers)
70+
71+
72+
def _preflight(client: TestClient, origin: str) -> dict:
73+
"""Send CORS preflight OPTIONS request, return response headers."""
74+
response = client.options(
75+
"/health",
76+
headers={
77+
"Origin": origin,
78+
"Access-Control-Request-Method": "GET",
79+
},
80+
)
81+
return dict(response.headers)
82+
83+
84+
class TestCORSDefaultOrigins:
85+
"""CORS should allow known local origins and block everything else."""
86+
87+
@pytest.mark.parametrize("origin", [
88+
"http://localhost:5173",
89+
"http://127.0.0.1:5173",
90+
"http://localhost:17493",
91+
"http://127.0.0.1:17493",
92+
"tauri://localhost",
93+
"https://tauri.localhost",
94+
])
95+
def test_allowed_origins(self, client, origin):
96+
headers = _get_with_origin(client, origin)
97+
assert headers.get("access-control-allow-origin") == origin
98+
99+
@pytest.mark.parametrize("origin", [
100+
"http://evil.com",
101+
"http://localhost:9999",
102+
"https://attacker.example.com",
103+
"null",
104+
])
105+
def test_blocked_origins(self, client, origin):
106+
headers = _get_with_origin(client, origin)
107+
assert "access-control-allow-origin" not in headers
108+
109+
def test_preflight_allowed(self, client):
110+
headers = _preflight(client, "http://localhost:5173")
111+
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
112+
113+
def test_preflight_blocked(self, client):
114+
headers = _preflight(client, "http://evil.com")
115+
assert "access-control-allow-origin" not in headers
116+
117+
def test_credentials_header_present(self, client):
118+
headers = _get_with_origin(client, "http://localhost:5173")
119+
assert headers.get("access-control-allow-credentials") == "true"
120+
121+
122+
class TestCORSCustomOrigins:
123+
"""VOICEBOX_CORS_ORIGINS env var should extend the allowlist."""
124+
125+
def test_custom_origin_allowed(self, client_with_custom_origins):
126+
headers = _get_with_origin(client_with_custom_origins, "https://custom.example.com")
127+
assert headers.get("access-control-allow-origin") == "https://custom.example.com"
128+
129+
def test_other_custom_origin_allowed(self, client_with_custom_origins):
130+
headers = _get_with_origin(client_with_custom_origins, "https://other.example.com")
131+
assert headers.get("access-control-allow-origin") == "https://other.example.com"
132+
133+
def test_default_origins_still_work(self, client_with_custom_origins):
134+
headers = _get_with_origin(client_with_custom_origins, "http://localhost:5173")
135+
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
136+
137+
def test_unlisted_origin_still_blocked(self, client_with_custom_origins):
138+
headers = _get_with_origin(client_with_custom_origins, "http://evil.com")
139+
assert "access-control-allow-origin" not in headers
140+
141+
142+
class TestCORSEnvVarParsing:
143+
"""Edge cases for VOICEBOX_CORS_ORIGINS parsing."""
144+
145+
def test_empty_env_var(self):
146+
app = _build_app("")
147+
client = TestClient(app)
148+
headers = _get_with_origin(client, "http://evil.com")
149+
assert "access-control-allow-origin" not in headers
150+
151+
def test_whitespace_trimmed(self):
152+
app = _build_app(" https://spaced.example.com ")
153+
client = TestClient(app)
154+
headers = _get_with_origin(client, "https://spaced.example.com")
155+
assert headers.get("access-control-allow-origin") == "https://spaced.example.com"
156+
157+
def test_trailing_comma_ignored(self):
158+
app = _build_app("https://one.example.com,")
159+
client = TestClient(app)
160+
headers = _get_with_origin(client, "https://one.example.com")
161+
assert headers.get("access-control-allow-origin") == "https://one.example.com"

0 commit comments

Comments
 (0)