From 6c2b9af10bc61b6e5583ea131422ead6ec8831f9 Mon Sep 17 00:00:00 2001 From: Marcel Date: Thu, 21 May 2026 16:12:26 +0200 Subject: [PATCH] feat(ocr): record training runs in ocr_training_runs_total per kind and outcome Wraps the await asyncio.to_thread(_run_*) calls in /train, /train-sender, and /segtrain with try/except. Recognition training (/train, /train-sender) shares kind="recognition"; /segtrain uses kind="segmentation". The ocr_model_accuracy gauge is set per kind on success. Refs #652 (AC6, decision #2) Co-Authored-By: Claude Sonnet 4.6 --- ocr-service/main.py | 27 +++++++++++-- ocr-service/test_metrics.py | 81 +++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/ocr-service/main.py b/ocr-service/main.py index c9f5b783..1460a213 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -478,7 +478,14 @@ async def train_model( return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} - result = await asyncio.to_thread(_run_training) + try: + result = await asyncio.to_thread(_run_training) + except Exception: + metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc() + raise + metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc() + if result.get("accuracy") is not None: + metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"]) return result @@ -558,7 +565,14 @@ async def train_sender_model( return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} - result = await asyncio.to_thread(_run_sender_training) + try: + result = await asyncio.to_thread(_run_sender_training) + except Exception: + metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc() + raise + metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc() + if result.get("accuracy") is not None: + metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"]) return result @@ -668,7 +682,14 @@ async def segtrain_model( return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} - result = await asyncio.to_thread(_run_segtrain) + try: + result = await asyncio.to_thread(_run_segtrain) + except Exception: + metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="error").inc() + raise + metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="success").inc() + if result.get("accuracy") is not None: + metrics.ocr_model_accuracy.labels(kind="segmentation").set(result["accuracy"]) return result diff --git a/ocr-service/test_metrics.py b/ocr-service/test_metrics.py index 5701bbe3..55108fb8 100644 --- a/ocr-service/test_metrics.py +++ b/ocr-service/test_metrics.py @@ -4,6 +4,8 @@ Each test that asserts on a counter/gauge value uses a fresh CollectorRegistry (see decision #3 on issue #652) to keep the metrics isolated between tests. """ +import io +import zipfile from unittest.mock import AsyncMock, patch import pytest @@ -15,6 +17,18 @@ from main import app from metrics import build_metrics +def _minimal_zip() -> bytes: + """Return a ZIP containing one fake .xml so endpoint validation passes.""" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("page_01.xml", "") + return buf.getvalue() + + +def _fake_training_result(accuracy: float = 0.91) -> dict: + return {"loss": None, "accuracy": accuracy, "cer": round(1 - accuracy, 4), "epochs": 5} + + @pytest.fixture def fresh_metrics(monkeypatch): """Replace the module-level `main.metrics` with one bound to a fresh registry.""" @@ -307,3 +321,70 @@ async def test_ocr_processing_seconds_histogram_observed_per_page_in_stream(fres ) assert count == 2.0 assert sum_seconds >= 0.0 + + +@pytest.mark.asyncio +async def test_ocr_training_runs_total_incremented_with_recognition_success_label(fresh_metrics): + """/train success increments ocr_training_runs_total{kind=recognition, outcome=success}.""" + async def fake_to_thread(func, *args, **kwargs): + return _fake_training_result() + + with patch("main.TRAINING_TOKEN", "secret-token"), \ + patch("main._models_ready", True), \ + patch("main.asyncio.to_thread", side_effect=fake_to_thread): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/train", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + headers={"X-Training-Token": "secret-token"}, + ) + + assert response.status_code == 200 + assert fresh_metrics.ocr_training_runs_total.labels( + kind="recognition", outcome="success" + )._value.get() == 1.0 + + +@pytest.mark.asyncio +async def test_ocr_training_runs_total_incremented_with_recognition_error_label(fresh_metrics): + """When /train's inner runner raises, the error counter bumps and the exception propagates.""" + async def fake_to_thread(func, *args, **kwargs): + raise RuntimeError("ketos train failed (exit 1): synthetic") + + with patch("main.TRAINING_TOKEN", "secret-token"), \ + patch("main._models_ready", True), \ + patch("main.asyncio.to_thread", side_effect=fake_to_thread): + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/train", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + headers={"X-Training-Token": "secret-token"}, + ) + + assert response.status_code == 500 + assert fresh_metrics.ocr_training_runs_total.labels( + kind="recognition", outcome="error" + )._value.get() == 1.0 + + +@pytest.mark.asyncio +async def test_ocr_training_runs_total_incremented_with_segmentation_success_label(fresh_metrics): + """/segtrain success increments ocr_training_runs_total{kind=segmentation, outcome=success}.""" + async def fake_to_thread(func, *args, **kwargs): + return _fake_training_result(accuracy=0.83) + + with patch("main.TRAINING_TOKEN", "secret-token"), \ + patch("main._models_ready", True), \ + patch("main.asyncio.to_thread", side_effect=fake_to_thread): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/segtrain", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + headers={"X-Training-Token": "secret-token"}, + ) + + assert response.status_code == 200 + assert fresh_metrics.ocr_training_runs_total.labels( + kind="segmentation", outcome="success" + )._value.get() == 1.0