feat(ocr): record training runs in ocr_training_runs_total per kind and outcome

Wraps the await asyncio.to_thread(_run_*) calls in /train, /train-sender,
and /segtrain with try/except. Recognition training (/train, /train-sender)
shares kind="recognition"; /segtrain uses kind="segmentation". The
ocr_model_accuracy gauge is set per kind on success.

Refs #652 (AC6, decision #2)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-05-21 16:12:26 +02:00
parent 2e3744d9ef
commit 6c2b9af10b
2 changed files with 105 additions and 3 deletions

View File

@@ -478,7 +478,14 @@ async def train_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_training)
try:
result = await asyncio.to_thread(_run_training)
except Exception:
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
return result
@@ -558,7 +565,14 @@ async def train_sender_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_sender_training)
try:
result = await asyncio.to_thread(_run_sender_training)
except Exception:
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
return result
@@ -668,7 +682,14 @@ async def segtrain_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_segtrain)
try:
result = await asyncio.to_thread(_run_segtrain)
except Exception:
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="segmentation").set(result["accuracy"])
return result

View File

@@ -4,6 +4,8 @@ Each test that asserts on a counter/gauge value uses a fresh CollectorRegistry
(see decision #3 on issue #652) to keep the metrics isolated between tests.
"""
import io
import zipfile
from unittest.mock import AsyncMock, patch
import pytest
@@ -15,6 +17,18 @@ from main import app
from metrics import build_metrics
def _minimal_zip() -> bytes:
"""Return a ZIP containing one fake .xml so endpoint validation passes."""
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
zf.writestr("page_01.xml", "<PcGts/>")
return buf.getvalue()
def _fake_training_result(accuracy: float = 0.91) -> dict:
return {"loss": None, "accuracy": accuracy, "cer": round(1 - accuracy, 4), "epochs": 5}
@pytest.fixture
def fresh_metrics(monkeypatch):
"""Replace the module-level `main.metrics` with one bound to a fresh registry."""
@@ -307,3 +321,70 @@ async def test_ocr_processing_seconds_histogram_observed_per_page_in_stream(fres
)
assert count == 2.0
assert sum_seconds >= 0.0
@pytest.mark.asyncio
async def test_ocr_training_runs_total_incremented_with_recognition_success_label(fresh_metrics):
"""/train success increments ocr_training_runs_total{kind=recognition, outcome=success}."""
async def fake_to_thread(func, *args, **kwargs):
return _fake_training_result()
with patch("main.TRAINING_TOKEN", "secret-token"), \
patch("main._models_ready", True), \
patch("main.asyncio.to_thread", side_effect=fake_to_thread):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/train",
files={"file": ("training.zip", _minimal_zip(), "application/zip")},
headers={"X-Training-Token": "secret-token"},
)
assert response.status_code == 200
assert fresh_metrics.ocr_training_runs_total.labels(
kind="recognition", outcome="success"
)._value.get() == 1.0
@pytest.mark.asyncio
async def test_ocr_training_runs_total_incremented_with_recognition_error_label(fresh_metrics):
"""When /train's inner runner raises, the error counter bumps and the exception propagates."""
async def fake_to_thread(func, *args, **kwargs):
raise RuntimeError("ketos train failed (exit 1): synthetic")
with patch("main.TRAINING_TOKEN", "secret-token"), \
patch("main._models_ready", True), \
patch("main.asyncio.to_thread", side_effect=fake_to_thread):
transport = ASGITransport(app=app, raise_app_exceptions=False)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.post(
"/train",
files={"file": ("training.zip", _minimal_zip(), "application/zip")},
headers={"X-Training-Token": "secret-token"},
)
assert response.status_code == 500
assert fresh_metrics.ocr_training_runs_total.labels(
kind="recognition", outcome="error"
)._value.get() == 1.0
@pytest.mark.asyncio
async def test_ocr_training_runs_total_incremented_with_segmentation_success_label(fresh_metrics):
"""/segtrain success increments ocr_training_runs_total{kind=segmentation, outcome=success}."""
async def fake_to_thread(func, *args, **kwargs):
return _fake_training_result(accuracy=0.83)
with patch("main.TRAINING_TOKEN", "secret-token"), \
patch("main._models_ready", True), \
patch("main.asyncio.to_thread", side_effect=fake_to_thread):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/segtrain",
files={"file": ("training.zip", _minimal_zip(), "application/zip")},
headers={"X-Training-Token": "secret-token"},
)
assert response.status_code == 200
assert fresh_metrics.ocr_training_runs_total.labels(
kind="segmentation", outcome="success"
)._value.get() == 1.0