From df952861c4605dc5ec2721c6b7b9ffe43276fec6 Mon Sep 17 00:00:00 2001 From: Marcel Date: Thu, 21 May 2026 16:58:40 +0200 Subject: [PATCH] refactor(ocr): extract _record_training for shared metric bookkeeping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The /train, /train-sender, and /segtrain endpoints each duplicated the same eight-line try/except + counter + gauge block around the asyncio.to_thread call. Lift it into _record_training(runner, kind), which accepts a sync- or async-returning callable for flexibility. Each endpoint now ends with a single return line. Behaviour preserved — status codes, error propagation, and metric labels stay identical. Co-Authored-By: Claude Sonnet 4.6 --- ocr-service/main.py | 58 ++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/ocr-service/main.py b/ocr-service/main.py index 51cf737b..489fd98e 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -2,6 +2,7 @@ import asyncio import glob +import inspect import io import json import logging @@ -14,6 +15,7 @@ import time import zipfile from contextlib import asynccontextmanager from datetime import datetime, timezone +from typing import Awaitable, Callable from urllib.parse import urlparse import httpx @@ -50,6 +52,30 @@ ALLOWED_PDF_HOSTS = set( _SPELL_CHECK_SCRIPT_TYPES = {"HANDWRITING_KURRENT", "HANDWRITING_LATIN"} +async def _record_training( + runner: Callable[[], Awaitable[dict] | dict], + kind: str, +) -> dict: + """Run a training callable and record outcome + accuracy metrics. + + Wraps the per-endpoint try/except + outcome counter + accuracy gauge + block that used to be repeated at /train, /train-sender, and /segtrain. + The runner returns a dict; if its `accuracy` value is missing or None, + the accuracy gauge is not updated and keeps its previous value. 
+ """ + try: + result = runner() + if inspect.isawaitable(result): + result = await result + except Exception: + metrics.ocr_training_runs_total.labels(kind=kind, outcome="error").inc() + raise + metrics.ocr_training_runs_total.labels(kind=kind, outcome="success").inc() + if result.get("accuracy") is not None: + metrics.ocr_model_accuracy.labels(kind=kind).set(result["accuracy"]) + return result + + def _observe_block_words(words: list[dict], threshold: float) -> None: """Record per-block word counts and below-threshold word counts. @@ -504,15 +530,7 @@ async def train_model( return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} - try: - result = await asyncio.to_thread(_run_training) - except Exception: - metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc() - raise - metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc() - if result.get("accuracy") is not None: - metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"]) - return result + return await _record_training(lambda: asyncio.to_thread(_run_training), kind="recognition") @app.post("/train-sender") @@ -591,15 +609,9 @@ async def train_sender_model( return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} - try: - result = await asyncio.to_thread(_run_sender_training) - except Exception: - metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc() - raise - metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc() - if result.get("accuracy") is not None: - metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"]) - return result + return await _record_training( + lambda: asyncio.to_thread(_run_sender_training), kind="recognition" + ) @app.post("/segtrain") @@ -708,15 +720,7 @@ async def segtrain_model( return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} - try: - result = await 
asyncio.to_thread(_run_segtrain) - except Exception: - metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="error").inc() - raise - metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="success").inc() - if result.get("accuracy") is not None: - metrics.ocr_model_accuracy.labels(kind="segmentation").set(result["accuracy"]) - return result + return await _record_training(lambda: asyncio.to_thread(_run_segtrain), kind="segmentation") async def _download_and_convert_pdf(url: str) -> list[Image.Image]: