diff --git a/ocr-service/main.py b/ocr-service/main.py index 754ed75b..0f656aaa 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -259,8 +259,10 @@ async def run_ocr_stream(request: OcrRequest): def _check_training_token(x_training_token: str | None) -> None: - """Validate training token if TRAINING_TOKEN env var is set.""" - if TRAINING_TOKEN and x_training_token != TRAINING_TOKEN: + """Validate training token — fails closed when TRAINING_TOKEN is not configured.""" + if not TRAINING_TOKEN: + raise HTTPException(status_code=503, detail="Training not configured on this node") + if x_training_token != TRAINING_TOKEN: raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token") diff --git a/ocr-service/test_training_auth.py b/ocr-service/test_training_auth.py new file mode 100644 index 00000000..e8ad23ad --- /dev/null +++ b/ocr-service/test_training_auth.py @@ -0,0 +1,69 @@ +"""Tests for /train and /segtrain endpoint authentication.""" + +import io +import zipfile +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +from main import app + + +def _minimal_zip() -> bytes: + """Return a ZIP with one .xml file so endpoint validation passes.""" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("page_01.xml", "") + return buf.getvalue() + + +# ─── Missing TRAINING_TOKEN → fail closed ───────────────────────────────────── + + +@pytest.mark.asyncio +async def test_train_returns_503_when_training_token_not_configured(): + """POST /train must return 503 when TRAINING_TOKEN env var is empty. + + An empty token means the service was started without training configured. + Allowing requests through would grant unauthenticated access to the + training endpoint, contradicting the principle of failing closed. + """ + with patch("main.TRAINING_TOKEN", ""), \ + patch("main._models_ready", True): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/train", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + ) + + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_segtrain_returns_503_when_training_token_not_configured(): + """POST /segtrain must return 503 when TRAINING_TOKEN env var is empty.""" + with patch("main.TRAINING_TOKEN", ""), \ + patch("main._models_ready", True): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/segtrain", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + ) + + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_train_returns_403_when_token_wrong(): + """POST /train must return 403 when TRAINING_TOKEN is set but header is wrong.""" + with patch("main.TRAINING_TOKEN", "secret-token"), \ + patch("main._models_ready", True): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/train", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + headers={"X-Training-Token": "wrong-token"}, + ) + + assert response.status_code == 403