fix(ocr): fail closed when TRAINING_TOKEN is not configured

_check_training_token previously skipped auth when TRAINING_TOKEN was
empty, allowing unauthenticated requests to reach /train and /segtrain.
Now returns 503 ("Training not configured on this node") when the token
is absent, so missing configuration fails closed rather than open.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-14 10:02:13 +02:00
parent dc283ba271
commit 30a17c97e8
2 changed files with 73 additions and 2 deletions

View File

@@ -259,8 +259,10 @@ async def run_ocr_stream(request: OcrRequest):
def _check_training_token(x_training_token: str | None) -> None:
    """Validate the training token — fails closed when TRAINING_TOKEN is not configured.

    Args:
        x_training_token: Token supplied by the caller (the error message
            refers to it as the ``X-Training-Token`` header), or ``None``
            when the caller sent none.

    Raises:
        HTTPException: 503 when ``TRAINING_TOKEN`` is empty, so a node
            without a configured token rejects the request instead of
            skipping authentication; 403 when the supplied token is
            missing or does not match ``TRAINING_TOKEN``.
    """
    import hmac

    if not TRAINING_TOKEN:
        raise HTTPException(status_code=503, detail="Training not configured on this node")
    # Compare in constant time so response timing does not reveal how many
    # leading characters of a guessed token were correct. Both sides are
    # encoded to bytes because hmac.compare_digest rejects non-ASCII str.
    if x_training_token is None or not hmac.compare_digest(
        x_training_token.encode("utf-8"), TRAINING_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")

View File

@@ -0,0 +1,69 @@
"""Tests for /train and /segtrain endpoint authentication."""
import io
import zipfile
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from main import app
def _minimal_zip() -> bytes:
    """Build an in-memory ZIP archive holding a single ``page_01.xml`` entry.

    The entry's content is the string ``<PcGts/>``; the tests upload the
    returned bytes as the training archive.
    """
    stream = io.BytesIO()
    with zipfile.ZipFile(stream, mode="w") as archive:
        archive.writestr("page_01.xml", "<PcGts/>")
    payload = stream.getvalue()
    return payload
# ─── Missing TRAINING_TOKEN → fail closed ─────────────────────────────────────
@pytest.mark.asyncio
async def test_train_returns_503_when_training_token_not_configured():
    """POST /train must return 503 when TRAINING_TOKEN env var is empty.

    An empty token means the service was started without training configured.
    Allowing requests through would grant unauthenticated access to the
    training endpoint, contradicting the principle of failing closed.

    The response detail is asserted as well as the status code, so the test
    only passes when the 503 comes from the token check rather than from
    some other 503 source.
    """
    with patch("main.TRAINING_TOKEN", ""), patch("main._models_ready", True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/train",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
            )
    assert response.status_code == 503
    # Pin the specific failure: a bare 503 does not prove the token guard fired.
    assert response.json()["detail"] == "Training not configured on this node"
@pytest.mark.asyncio
async def test_segtrain_returns_503_when_training_token_not_configured():
    """POST /segtrain must return 503 when TRAINING_TOKEN env var is empty.

    The response detail is asserted as well as the status code, so the test
    only passes when the 503 comes from the token check rather than from
    some other 503 source.
    """
    with patch("main.TRAINING_TOKEN", ""), patch("main._models_ready", True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/segtrain",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
            )
    assert response.status_code == 503
    # Pin the specific failure: a bare 503 does not prove the token guard fired.
    assert response.json()["detail"] == "Training not configured on this node"
@pytest.mark.asyncio
async def test_train_returns_403_when_token_wrong():
    """A configured TRAINING_TOKEN plus a mismatching header must yield 403 on POST /train."""
    upload = {"file": ("training.zip", _minimal_zip(), "application/zip")}
    bad_header = {"X-Training-Token": "wrong-token"}
    transport = ASGITransport(app=app)
    # Configure a real token so the request reaches the comparison step.
    with patch("main.TRAINING_TOKEN", "secret-token"), patch("main._models_ready", True):
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/train", files=upload, headers=bad_header)
    assert response.status_code == 403