refactor(ocr): extract _record_training for shared metric bookkeeping

The /train, /train-sender, and /segtrain endpoints each duplicated the
same eight-line try/except + counter + gauge block around the
asyncio.to_thread call. Lift it into _record_training(runner, kind),
which accepts a sync- or async-returning callable for flexibility.
Each endpoint now ends with a single return line. Behaviour preserved —
status codes, error propagation, and metric labels stay identical.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-05-21 16:58:40 +02:00
parent 22a5ee816a
commit df952861c4

View File

@@ -2,6 +2,7 @@
import asyncio
import glob
import inspect
import io
import json
import logging
@@ -14,6 +15,7 @@ import time
import zipfile
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Awaitable, Callable
from urllib.parse import urlparse
import httpx
@@ -50,6 +52,30 @@ ALLOWED_PDF_HOSTS = set(
_SPELL_CHECK_SCRIPT_TYPES = {"HANDWRITING_KURRENT", "HANDWRITING_LATIN"}
async def _record_training(
    runner: Callable[[], Awaitable[dict] | dict],
    kind: str,
) -> dict:
    """Execute a training callable and book-keep its outcome metrics.

    ``runner`` is invoked with no arguments. If what it returns is
    awaitable, that awaitable is awaited to obtain the final result;
    otherwise the returned value is used as-is.

    Args:
        runner: Zero-argument callable producing the training result dict,
            either directly or via an awaitable.
        kind: Value used for the ``kind`` label on both metrics.

    Returns:
        The dict produced by ``runner``, unmodified.

    Raises:
        Exception: Anything raised by ``runner`` (or by awaiting its result)
            is re-raised after the ``outcome="error"`` counter is bumped.
    """
    try:
        training_result = runner()
        # Callers may hand in either a plain function or one that returns a
        # coroutine/future (e.g. a lambda around asyncio.to_thread).
        if inspect.isawaitable(training_result):
            training_result = await training_result
    except Exception:
        metrics.ocr_training_runs_total.labels(kind=kind, outcome="error").inc()
        raise
    else:
        metrics.ocr_training_runs_total.labels(kind=kind, outcome="success").inc()

    # The accuracy gauge is only updated when the run reported a value;
    # a missing or None accuracy means the gauge is not touched.
    accuracy = training_result.get("accuracy")
    if accuracy is not None:
        metrics.ocr_model_accuracy.labels(kind=kind).set(accuracy)

    return training_result
def _observe_block_words(words: list[dict], threshold: float) -> None:
"""Record per-block word counts and below-threshold word counts.
@@ -504,15 +530,7 @@ async def train_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
try:
result = await asyncio.to_thread(_run_training)
except Exception:
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
return result
return await _record_training(lambda: asyncio.to_thread(_run_training), kind="recognition")
@app.post("/train-sender")
@@ -591,15 +609,9 @@ async def train_sender_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
try:
result = await asyncio.to_thread(_run_sender_training)
except Exception:
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
return result
return await _record_training(
lambda: asyncio.to_thread(_run_sender_training), kind="recognition"
)
@app.post("/segtrain")
@@ -708,15 +720,7 @@ async def segtrain_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
try:
result = await asyncio.to_thread(_run_segtrain)
except Exception:
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="segmentation").set(result["accuracy"])
return result
return await _record_training(lambda: asyncio.to_thread(_run_segtrain), kind="segmentation")
async def _download_and_convert_pdf(url: str) -> list[Image.Image]: