refactor(ocr): extract _record_training for shared metric bookkeeping

The /train, /train-sender, and /segtrain endpoints each duplicated the
same eight-line try/except + counter + gauge block around the
asyncio.to_thread call. Lift it into _record_training(runner, kind),
which accepts a callable returning either the result dict or an awaitable of it.
Each endpoint now ends with a single return line. Behaviour preserved —
status codes, error propagation, and metric labels stay identical.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-05-21 16:58:40 +02:00
parent 22a5ee816a
commit df952861c4

View File

@@ -2,6 +2,7 @@
import asyncio import asyncio
import glob import glob
import inspect
import io import io
import json import json
import logging import logging
@@ -14,6 +15,7 @@ import time
import zipfile import zipfile
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Awaitable, Callable
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@@ -50,6 +52,30 @@ ALLOWED_PDF_HOSTS = set(
_SPELL_CHECK_SCRIPT_TYPES = {"HANDWRITING_KURRENT", "HANDWRITING_LATIN"} _SPELL_CHECK_SCRIPT_TYPES = {"HANDWRITING_KURRENT", "HANDWRITING_LATIN"}
async def _record_training(
    runner: Callable[[], Awaitable[dict] | dict],
    kind: str,
) -> dict:
    """Invoke *runner* and book-keep the training metrics for *kind*.

    Shared by the /train, /train-sender, and /segtrain endpoints so each
    of them only supplies its runner and the metric label.

    Args:
        runner: Zero-argument callable. It may hand back the result dict
            directly or an awaitable resolving to it; an awaitable is
            awaited here.
        kind: Value used for the ``kind`` label on both metrics.

    Returns:
        The runner's result, unchanged.

    Raises:
        Exception: Whatever the runner raised, re-raised after the
            ``outcome="error"`` counter has been incremented.
    """
    try:
        outcome = runner()
        # Support both plain and awaitable return values from the runner.
        if inspect.isawaitable(outcome):
            outcome = await outcome
    except Exception:
        metrics.ocr_training_runs_total.labels(kind=kind, outcome="error").inc()
        raise
    else:
        metrics.ocr_training_runs_total.labels(kind=kind, outcome="success").inc()

    # Only touch the gauge when the run actually reported an accuracy.
    accuracy = outcome.get("accuracy")
    if accuracy is not None:
        metrics.ocr_model_accuracy.labels(kind=kind).set(accuracy)
    return outcome
def _observe_block_words(words: list[dict], threshold: float) -> None: def _observe_block_words(words: list[dict], threshold: float) -> None:
"""Record per-block word counts and below-threshold word counts. """Record per-block word counts and below-threshold word counts.
@@ -504,15 +530,7 @@ async def train_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
try: return await _record_training(lambda: asyncio.to_thread(_run_training), kind="recognition")
result = await asyncio.to_thread(_run_training)
except Exception:
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
return result
@app.post("/train-sender") @app.post("/train-sender")
@@ -591,15 +609,9 @@ async def train_sender_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
try: return await _record_training(
result = await asyncio.to_thread(_run_sender_training) lambda: asyncio.to_thread(_run_sender_training), kind="recognition"
except Exception: )
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
return result
@app.post("/segtrain") @app.post("/segtrain")
@@ -708,15 +720,7 @@ async def segtrain_model(
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
try: return await _record_training(lambda: asyncio.to_thread(_run_segtrain), kind="segmentation")
result = await asyncio.to_thread(_run_segtrain)
except Exception:
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="error").inc()
raise
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="success").inc()
if result.get("accuracy") is not None:
metrics.ocr_model_accuracy.labels(kind="segmentation").set(result["accuracy"])
return result
async def _download_and_convert_pdf(url: str) -> list[Image.Image]: async def _download_and_convert_pdf(url: str) -> list[Image.Image]: