fix(ocr): fail closed when TRAINING_TOKEN is not configured
_check_training_token previously skipped auth when TRAINING_TOKEN was
empty, allowing unauthenticated requests to reach /train and /segtrain.
Now returns 503 ("Training not configured on this node") when the token
is absent, so missing configuration fails closed rather than open.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -259,8 +259,10 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
|
|
||||||
|
|
||||||
def _check_training_token(x_training_token: str | None) -> None:
    """Validate the training token, failing closed when none is configured.

    Args:
        x_training_token: Value supplied by the caller in the
            X-Training-Token header, or None when the header is absent.

    Raises:
        HTTPException: 503 when TRAINING_TOKEN is empty/unset, so a node
            started without training configured never accepts requests;
            403 when the supplied token is missing or does not match.
    """
    # Local import keeps this block self-contained regardless of the
    # module's top-level import list.
    import secrets

    if not TRAINING_TOKEN:
        raise HTTPException(status_code=503, detail="Training not configured on this node")

    # compare_digest runs in constant time, so response latency does not
    # reveal how many leading characters of a guessed token were correct.
    # Both sides are encoded to bytes because compare_digest rejects
    # non-ASCII str; a missing header is treated as an empty token, which
    # can never equal the (non-empty) configured one.
    # NOTE(review): assumes TRAINING_TOKEN is a str — confirm at its definition.
    supplied = (x_training_token or "").encode("utf-8")
    expected = TRAINING_TOKEN.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
69
ocr-service/test_training_auth.py
Normal file
69
ocr-service/test_training_auth.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Tests for /train and /segtrain endpoint authentication."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import zipfile
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_zip() -> bytes:
    """Return the bytes of a ZIP archive holding a single .xml member.

    The archive contains one entry, ``page_01.xml``, whose content is the
    string ``<PcGts/>``. It is meant as the smallest upload that lets the
    endpoint's own validation pass so the auth check is what gets tested.
    """
    buf = io.BytesIO()
    # The context manager finalises the archive (central directory) before
    # the buffer is read back.
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr("page_01.xml", "<PcGts/>")
    return buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Missing TRAINING_TOKEN → fail closed ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_train_returns_503_when_training_token_not_configured():
    """POST /train must return 503 when TRAINING_TOKEN env var is empty.

    An empty token means the service was started without training configured.
    Allowing requests through would grant unauthenticated access to the
    training endpoint, contradicting the principle of failing closed.
    """
    # TRAINING_TOKEN is patched to "" to simulate the unconfigured node;
    # _models_ready is forced True for the duration of the request.
    with patch("main.TRAINING_TOKEN", ""), \
            patch("main._models_ready", True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/train",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
            )

    assert response.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_segtrain_returns_503_when_training_token_not_configured():
    """POST /segtrain must return 503 when TRAINING_TOKEN env var is empty."""
    # Same setup as the /train case: empty token, _models_ready forced True,
    # and no X-Training-Token header on the request.
    with patch("main.TRAINING_TOKEN", ""), \
            patch("main._models_ready", True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/segtrain",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
            )

    assert response.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_train_returns_403_when_token_wrong():
    """POST /train must return 403 when TRAINING_TOKEN is set but header is wrong."""
    # A token is configured here, so the request must get past the 503
    # "not configured" branch and be rejected on the mismatch instead.
    with patch("main.TRAINING_TOKEN", "secret-token"), \
            patch("main._models_ready", True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/train",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
                headers={"X-Training-Token": "wrong-token"},
            )

    assert response.status_code == 403
|
||||||
Reference in New Issue
Block a user