feat(ocr): record training runs in ocr_training_runs_total per kind and outcome
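Wraps the `await asyncio.to_thread(_run_*)` calls in /train, /train-sender, and /segtrain with try/except: when the runner raises, the counter is incremented with outcome="error" and the exception is re-raised; otherwise it is incremented with outcome="success". Recognition training (/train, /train-sender) shares kind="recognition"; /segtrain uses kind="segmentation". On success, the ocr_model_accuracy gauge is set per kind whenever the result carries a non-None accuracy.

Refs #652 (AC6, decision #2)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>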
This commit is contained in:
@@ -478,7 +478,14 @@ async def train_model(
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_training)
|
||||
try:
|
||||
result = await asyncio.to_thread(_run_training)
|
||||
except Exception:
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
|
||||
raise
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
|
||||
if result.get("accuracy") is not None:
|
||||
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
|
||||
return result
|
||||
|
||||
|
||||
@@ -558,7 +565,14 @@ async def train_sender_model(
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_sender_training)
|
||||
try:
|
||||
result = await asyncio.to_thread(_run_sender_training)
|
||||
except Exception:
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
|
||||
raise
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
|
||||
if result.get("accuracy") is not None:
|
||||
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
|
||||
return result
|
||||
|
||||
|
||||
@@ -668,7 +682,14 @@ async def segtrain_model(
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
try:
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
except Exception:
|
||||
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="error").inc()
|
||||
raise
|
||||
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="success").inc()
|
||||
if result.get("accuracy") is not None:
|
||||
metrics.ocr_model_accuracy.labels(kind="segmentation").set(result["accuracy"])
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ Each test that asserts on a counter/gauge value uses a fresh CollectorRegistry
|
||||
(see decision #3 on issue #652) to keep the metrics isolated between tests.
|
||||
"""
|
||||
|
||||
import io
|
||||
import zipfile
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -15,6 +17,18 @@ from main import app
|
||||
from metrics import build_metrics
|
||||
|
||||
|
||||
def _minimal_zip() -> bytes:
|
||||
"""Return a ZIP containing one fake .xml so endpoint validation passes."""
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr("page_01.xml", "<PcGts/>")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def _fake_training_result(accuracy: float = 0.91) -> dict:
|
||||
return {"loss": None, "accuracy": accuracy, "cer": round(1 - accuracy, 4), "epochs": 5}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_metrics(monkeypatch):
|
||||
"""Replace the module-level `main.metrics` with one bound to a fresh registry."""
|
||||
@@ -307,3 +321,70 @@ async def test_ocr_processing_seconds_histogram_observed_per_page_in_stream(fres
|
||||
)
|
||||
assert count == 2.0
|
||||
assert sum_seconds >= 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_training_runs_total_incremented_with_recognition_success_label(fresh_metrics):
    """/train success increments ocr_training_runs_total{kind=recognition, outcome=success}.

    Also checks that the ocr_model_accuracy gauge for kind=recognition is set
    to the accuracy reported by the (faked) training run.
    """
    # Stand-in for asyncio.to_thread: skips the real training runner and
    # hands back a canned result dict immediately.
    async def fake_to_thread(func, *args, **kwargs):
        return _fake_training_result()

    with patch("main.TRAINING_TOKEN", "secret-token"), \
            patch("main._models_ready", True), \
            patch("main.asyncio.to_thread", side_effect=fake_to_thread):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/train",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
                headers={"X-Training-Token": "secret-token"},
            )

    assert response.status_code == 200
    assert fresh_metrics.ocr_training_runs_total.labels(
        kind="recognition", outcome="success"
    )._value.get() == 1.0
    # The success path also publishes the reported accuracy (0.91 is the
    # default of _fake_training_result) on the per-kind gauge.
    assert fresh_metrics.ocr_model_accuracy.labels(
        kind="recognition"
    )._value.get() == pytest.approx(0.91)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_training_runs_total_incremented_with_recognition_error_label(fresh_metrics):
    """When /train's inner runner raises, the error counter bumps and the exception propagates.

    Also checks that the success counter is left untouched on the error path.
    """
    # Stand-in for asyncio.to_thread that fails the way a broken training
    # run would, without starting any real work.
    async def fake_to_thread(func, *args, **kwargs):
        raise RuntimeError("ketos train failed (exit 1): synthetic")

    with patch("main.TRAINING_TOKEN", "secret-token"), \
            patch("main._models_ready", True), \
            patch("main.asyncio.to_thread", side_effect=fake_to_thread):
        # raise_app_exceptions=False turns the re-raised error into a 500
        # response instead of letting it escape into the test itself.
        transport = ASGITransport(app=app, raise_app_exceptions=False)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/train",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
                headers={"X-Training-Token": "secret-token"},
            )

    assert response.status_code == 500
    assert fresh_metrics.ocr_training_runs_total.labels(
        kind="recognition", outcome="error"
    )._value.get() == 1.0
    # A failed run must not be counted as a success as well.
    assert fresh_metrics.ocr_training_runs_total.labels(
        kind="recognition", outcome="success"
    )._value.get() == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_training_runs_total_incremented_with_segmentation_success_label(fresh_metrics):
    """/segtrain success increments ocr_training_runs_total{kind=segmentation, outcome=success}.

    Also checks that the ocr_model_accuracy gauge for kind=segmentation is set
    to the accuracy reported by the (faked) training run.
    """
    # Stand-in for asyncio.to_thread: skips the real segmentation training
    # and hands back a canned result dict immediately.
    async def fake_to_thread(func, *args, **kwargs):
        return _fake_training_result(accuracy=0.83)

    with patch("main.TRAINING_TOKEN", "secret-token"), \
            patch("main._models_ready", True), \
            patch("main.asyncio.to_thread", side_effect=fake_to_thread):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post(
                "/segtrain",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
                headers={"X-Training-Token": "secret-token"},
            )

    assert response.status_code == 200
    assert fresh_metrics.ocr_training_runs_total.labels(
        kind="segmentation", outcome="success"
    )._value.get() == 1.0
    # The success path also publishes the reported accuracy on the
    # segmentation gauge.
    assert fresh_metrics.ocr_model_accuracy.labels(
        kind="segmentation"
    )._value.get() == pytest.approx(0.83)
|
||||
|
||||
Reference in New Issue
Block a user