refactor(ocr): extract _record_training for shared metric bookkeeping
The /train, /train-sender, and /segtrain endpoints each duplicated the same eight-line try/except + counter + gauge block around the asyncio.to_thread call. Lift it into _record_training(runner, kind), which accepts a sync- or async-returning callable for flexibility. Each endpoint now ends with a single return line. Behaviour preserved — status codes, error propagation, and metric labels stay identical. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
@@ -14,6 +15,7 @@ import time
|
||||
import zipfile
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from typing import Awaitable, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
@@ -50,6 +52,30 @@ ALLOWED_PDF_HOSTS = set(
|
||||
_SPELL_CHECK_SCRIPT_TYPES = {"HANDWRITING_KURRENT", "HANDWRITING_LATIN"}
|
||||
|
||||
|
||||
async def _record_training(
|
||||
runner: Callable[[], Awaitable[dict] | dict],
|
||||
kind: str,
|
||||
) -> dict:
|
||||
"""Run a training callable and record outcome + accuracy metrics.
|
||||
|
||||
Wraps the per-endpoint try/except + outcome counter + accuracy gauge
|
||||
block that used to be repeated at /train, /train-sender, and /segtrain.
|
||||
The runner returns a dict with at least an `accuracy` key; if its value
|
||||
is None, the gauge is left at its default.
|
||||
"""
|
||||
try:
|
||||
result = runner()
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
except Exception:
|
||||
metrics.ocr_training_runs_total.labels(kind=kind, outcome="error").inc()
|
||||
raise
|
||||
metrics.ocr_training_runs_total.labels(kind=kind, outcome="success").inc()
|
||||
if result.get("accuracy") is not None:
|
||||
metrics.ocr_model_accuracy.labels(kind=kind).set(result["accuracy"])
|
||||
return result
|
||||
|
||||
|
||||
def _observe_block_words(words: list[dict], threshold: float) -> None:
|
||||
"""Record per-block word counts and below-threshold word counts.
|
||||
|
||||
@@ -504,15 +530,7 @@ async def train_model(
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
try:
|
||||
result = await asyncio.to_thread(_run_training)
|
||||
except Exception:
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
|
||||
raise
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
|
||||
if result.get("accuracy") is not None:
|
||||
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
|
||||
return result
|
||||
return await _record_training(lambda: asyncio.to_thread(_run_training), kind="recognition")
|
||||
|
||||
|
||||
@app.post("/train-sender")
|
||||
@@ -591,15 +609,9 @@ async def train_sender_model(
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
try:
|
||||
result = await asyncio.to_thread(_run_sender_training)
|
||||
except Exception:
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="error").inc()
|
||||
raise
|
||||
metrics.ocr_training_runs_total.labels(kind="recognition", outcome="success").inc()
|
||||
if result.get("accuracy") is not None:
|
||||
metrics.ocr_model_accuracy.labels(kind="recognition").set(result["accuracy"])
|
||||
return result
|
||||
return await _record_training(
|
||||
lambda: asyncio.to_thread(_run_sender_training), kind="recognition"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/segtrain")
|
||||
@@ -708,15 +720,7 @@ async def segtrain_model(
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
try:
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
except Exception:
|
||||
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="error").inc()
|
||||
raise
|
||||
metrics.ocr_training_runs_total.labels(kind="segmentation", outcome="success").inc()
|
||||
if result.get("accuracy") is not None:
|
||||
metrics.ocr_model_accuracy.labels(kind="segmentation").set(result["accuracy"])
|
||||
return result
|
||||
return await _record_training(lambda: asyncio.to_thread(_run_segtrain), kind="segmentation")
|
||||
|
||||
|
||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
||||
|
||||
Reference in New Issue
Block a user