From 64d27d6d61b1a96f948c59c60f0648a701ffc375 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 18:05:39 +0200 Subject: [PATCH] feat(ocr): per-sender model registry and /train-sender endpoint engines/kraken.py: - Add _SenderModelRegistry with LRU eviction (max configurable via OCR_MAX_CACHED_MODELS env var), double-checked locking, invalidate(), and path whitelist (/app/models/ only) - Add _load_sender_model() helper for testability - extract_page_blocks() and extract_region_text() accept optional sender_model_path; route to sender registry when provided models.py: - OcrRequest gains senderModelPath: str | None = None field main.py: - /ocr and /ocr/stream pass request.senderModelPath to Kraken engine - New /train-sender endpoint: validates output_model_path, runs ketos train with base model as starting point, invalidates sender cache docker-compose.yml: - Add OCR_MAX_CACHED_MODELS: "5" to ocr-service environment test_sender_registry.py: - 4 tests: cache hit, LRU eviction, invalidate, path traversal guard Co-Authored-By: Claude Sonnet 4.6 --- docker-compose.yml | 1 + ocr-service/engines/kraken.py | 86 +++++++++++++++++++++++++-- ocr-service/main.py | 90 ++++++++++++++++++++++++++++- ocr-service/models.py | 1 + ocr-service/test_sender_registry.py | 65 +++++++++++++++++++++ 5 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 ocr-service/test_sender_registry.py diff --git a/docker-compose.yml b/docker-compose.yml index e9105e3a..e914cc91 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -97,6 +97,7 @@ services: DETECTOR_BATCH_SIZE: "8" OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency) OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page) + OCR_MAX_CACHED_MODELS: "5" # LRU cache size for per-sender Kraken models networks: - archive-net healthcheck: diff --git a/ocr-service/engines/kraken.py b/ocr-service/engines/kraken.py index 8deae3ef..d33567bc 100644 --- 
# Sender models may only be loaded from inside this directory (after resolving
# symlinks and ".." components).
_MODELS_DIR = os.path.realpath("/app/models")
# Maximum number of per-sender models kept in memory at once.
_MAX_CACHED_SENDER_MODELS = int(os.environ.get("OCR_MAX_CACHED_MODELS", "5"))


def _load_sender_model(path: str):
    """Load a Kraken model from disk. Extracted for testability."""
    from kraken.lib import models as kraken_models
    return kraken_models.load_any(path)


class _SenderModelRegistry:
    """Thread-safe LRU cache for per-sender Kraken models.

    Entries are keyed by the resolved (``os.path.realpath``) model path, so
    every spelling of the same file ("x/../a.mlmodel", symlinks, ...) shares
    one entry and ``invalidate()`` hits it regardless of how the path was
    written.

    Model loading happens outside the lock so a slow load does not block
    other callers; the cache is checked again after loading, so at most one
    entry per path is stored. ``invalidate()`` bumps a generation counter;
    a load that began before an invalidation is returned to its caller but
    not cached, so an outdated model cannot be re-inserted after the file
    on disk was replaced.
    """

    def __init__(self, max_size: int):
        self._max_size = max_size
        self._cache: collections.OrderedDict = collections.OrderedDict()
        self._lock = threading.Lock()
        # Incremented by invalidate(); compared around the unlocked load.
        self._generation = 0

    def get_model(self, model_path: str):
        """Return the cached model for model_path, loading it on a miss.

        Raises:
            ValueError: if the resolved path is not a location strictly
                inside the models directory (the directory itself is
                rejected as well).
        """
        resolved = os.path.realpath(model_path)
        if not resolved.startswith(_MODELS_DIR + os.sep):
            raise ValueError(f"Sender model path not allowed: {model_path}")

        with self._lock:
            if resolved in self._cache:
                self._cache.move_to_end(resolved)
                return self._cache[resolved]
            generation_at_start = self._generation

        # Load exactly the path that was validated, not the raw input.
        new_model = _load_sender_model(resolved)

        with self._lock:
            if resolved in self._cache:
                # Another caller finished loading first; reuse its instance.
                self._cache.move_to_end(resolved)
                return self._cache[resolved]
            if generation_at_start == self._generation:
                self._cache[resolved] = new_model
                while len(self._cache) > self._max_size:
                    self._cache.popitem(last=False)
            # else: an invalidate() ran during the load; the model may be
            # outdated, so serve it for this call only and do not cache it.

        return new_model

    def invalidate(self, model_path: str) -> None:
        """Remove model from cache so the next request reloads from disk."""
        resolved = os.path.realpath(model_path)
        with self._lock:
            self._generation += 1
            self._cache.pop(resolved, None)

    def size(self) -> int:
        """Return the number of models currently cached."""
        with self._lock:
            return len(self._cache)

    def _contains(self, model_path: str) -> bool:
        """Return True if model_path (any spelling) is cached. Test helper."""
        resolved = os.path.realpath(model_path)
        with self._lock:
            return resolved in self._cache


_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS)
""" from kraken import blla, rpred from confidence import words_from_characters @@ -40,11 +107,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict if _model is None: raise RuntimeError("Kraken model is not loaded") + active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model + page_w, page_h = image.size blocks = [] baseline_seg = blla.segment(image) - pred_it = rpred.rpred(_model, image, baseline_seg) + pred_it = rpred.rpred(active_model, image, baseline_seg) for record in pred_it: polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else [] @@ -79,13 +148,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict return blocks -def extract_region_text(image, x: float, y: float, w: float, h: float) -> str: +def extract_region_text(image, x: float, y: float, w: float, h: float, + sender_model_path: str | None = None) -> str: """Crop image to a normalized region and run Kraken recognition on the crop. Used for guided OCR — skips full-page layout detection entirely. A single synthetic baseline spanning the full crop width is used so that blla.segment() (which crashes on small crops) is never called. Coordinates are normalized to [0, 1]. + When sender_model_path is provided, the per-sender fine-tuned model is used. 
""" from kraken import rpred from kraken.containers import Segmentation, BaselineLine @@ -93,6 +164,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str: if _model is None: raise RuntimeError("Kraken model is not loaded") + active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model + pw, ph = image.size x1 = max(0, int(x * pw)) y1 = max(0, int(y * ph)) @@ -123,11 +196,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str: regions={}, line_orders=[], ) - pred_it = rpred.rpred(_model, crop, synthetic_seg) + pred_it = rpred.rpred(active_model, crop, synthetic_seg) return " ".join(r.prediction for r in pred_it) -def extract_blocks(images: list, language: str = "de") -> list[dict]: +def extract_blocks(images: list, language: str = "de", + sender_model_path: str | None = None) -> list[dict]: """Run Kraken segmentation + recognition on a list of PIL images. Returns block dicts with pageNumber, x, y, width, height, polygon, text. 
@@ -137,7 +211,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]: all_blocks = [] for page_idx, image in enumerate(images, start=1): - all_blocks.extend(extract_page_blocks(image, page_idx, language)) + all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path)) return all_blocks diff --git a/ocr-service/main.py b/ocr-service/main.py index 77c5c33f..bc541c78 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -17,7 +17,7 @@ from urllib.parse import urlparse import httpx import pypdfium2 as pdfium -from fastapi import FastAPI, Header, HTTPException, UploadFile +from fastapi import FastAPI, Form, Header, HTTPException, UploadFile from fastapi.responses import StreamingResponse from PIL import Image @@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest): status_code=400, detail="Kraken model not available — cannot process Kurrent script", ) - blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language) + blocks = await asyncio.to_thread( + kraken_engine.extract_blocks, images, request.language, request.senderModelPath) else: # TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language) @@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest): yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n" image = await asyncio.to_thread(preprocess_page, image) blocks = [] + sender_path = request.senderModelPath if use_kraken else None for region in page_regions: text = await asyncio.to_thread( engine.extract_region_text, image, region.x, region.y, region.width, region.height, + sender_path, ) if script_type in _SPELL_CHECK_SCRIPT_TYPES: text = correct_text(text) @@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest): try: yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n" image = await asyncio.to_thread(preprocess_page, image) + sender_path 
@app.post("/train-sender")
async def train_sender_model(
    file: UploadFile,
    output_model_path: str = Form(...),
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune a per-sender Kurrent model and save it to output_model_path.

    Args:
        file: Uploaded ZIP archive. Ground truth is taken from the ``*.xml``
            files found at the top level of the extracted archive.
        output_model_path: Destination file for the trained model. Must
            resolve to a file path inside /app/models/.
        x_training_token: Training token header, checked by
            ``_check_training_token``.

    Returns:
        Dict with keys ``loss`` (always None), ``accuracy``, ``cer``
        (``1 - accuracy`` rounded to 4 places, or None) and ``epochs``.

    Raises:
        HTTPException: 503 if models are not loaded yet; 400 if
            output_model_path is outside /app/models/ or is a directory;
            422 if the ZIP contains no ground-truth XML files.
        RuntimeError: if ``ketos train`` exits non-zero or produces no model.

    The per-sender model cache is invalidated after the new file is in place
    so the next OCR request picks up the new model.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    models_dir = os.path.realpath("/app/models")
    resolved_output = os.path.realpath(output_model_path)
    if not resolved_output.startswith(models_dir + os.sep):
        raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")
    # Fail before training: copying onto a directory would silently store the
    # model under the checkpoint's own file name instead of the requested path.
    if os.path.isdir(resolved_output):
        raise HTTPException(status_code=400, detail="output_model_path must be a file path, not a directory")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting sender training run %s → %s", training_run_id, output_model_path)

    def _run_sender_training() -> dict:
        """Blocking part of the run: unpack, train, publish, invalidate."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                # Validate every entry before anything is written to disk.
                for entry in zf.namelist():
                    _validate_zip_entry(entry, tmp_dir)
                zf.extractall(tmp_dir)

            ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")

            log.info("Sender training on %d ground-truth pairs", len(ground_truth))
            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)

            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "train",
                "-f", "page",
                "--weights-format", "coreml",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
                "-B", "1",
            ]
            # Start from the base model when it exists on disk.
            if os.path.exists(KRAKEN_MODEL_PATH):
                cmd += ["-i", KRAKEN_MODEL_PATH]
            cmd += ground_truth

            log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.returncode != 0:
                raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None

            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Sender training produced no model file")

            os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
            # Stage next to the destination and rename into place, so a
            # concurrent OCR request never loads a partially written model.
            staging_path = f"{resolved_output}.{training_run_id}.tmp"
            try:
                shutil.copy2(best_model, staging_path)
                os.replace(staging_path, resolved_output)
            finally:
                if os.path.exists(staging_path):
                    os.remove(staging_path)
            log.info("Saved sender model to %s", resolved_output)

        # Drop the cached model under both the caller's spelling and the
        # resolved path, so whichever one was used as cache key is cleared.
        for cache_key in {output_model_path, resolved_output}:
            kraken_engine._sender_registry.invalidate(cache_key)

        return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    result = await asyncio.to_thread(_run_sender_training)
    return result
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_registry(max_size=5):
    """Build a fresh registry; imported lazily so collection never loads Kraken."""
    from engines.kraken import _SenderModelRegistry
    return _SenderModelRegistry(max_size=max_size)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

def test_cache_hit_returns_same_object():
    """Second get_model call with the same path must return the cached object."""
    registry = _make_registry()
    mock_model = MagicMock(name="model_a")
    with patch("engines.kraken._load_sender_model", return_value=mock_model) as loader:
        m1 = registry.get_model("/app/models/sender_a.mlmodel")
        m2 = registry.get_model("/app/models/sender_a.mlmodel")

    assert m1 is m2
    loader.assert_called_once()  # only loaded once despite two gets


def test_lru_eviction_removes_least_recently_used():
    """When the cache exceeds max_size, the oldest-accessed entry is evicted."""
    registry = _make_registry(max_size=2)

    def _side_effect(path):
        return MagicMock(name=path)

    with patch("engines.kraken._load_sender_model", side_effect=_side_effect):
        registry.get_model("/app/models/sender_a.mlmodel")
        registry.get_model("/app/models/sender_b.mlmodel")
        registry.get_model("/app/models/sender_c.mlmodel")  # should evict 'a'

    assert registry.size() == 2
    # 'a' was the least-recently-used and should be gone; the others stay
    assert not registry._contains("/app/models/sender_a.mlmodel")
    assert registry._contains("/app/models/sender_b.mlmodel")
    assert registry._contains("/app/models/sender_c.mlmodel")


def test_cache_hit_refreshes_recency():
    """A cache hit must protect the entry from being the next one evicted."""
    registry = _make_registry(max_size=2)

    def _side_effect(path):
        return MagicMock(name=path)

    with patch("engines.kraken._load_sender_model", side_effect=_side_effect):
        registry.get_model("/app/models/sender_a.mlmodel")
        registry.get_model("/app/models/sender_b.mlmodel")
        registry.get_model("/app/models/sender_a.mlmodel")  # 'a' becomes most recent
        registry.get_model("/app/models/sender_c.mlmodel")  # should evict 'b'

    assert registry._contains("/app/models/sender_a.mlmodel")
    assert not registry._contains("/app/models/sender_b.mlmodel")


def test_invalidate_removes_entry_from_cache():
    """invalidate() must evict the entry so the next get reloads from disk."""
    registry = _make_registry()
    mock_model = MagicMock(name="model_x")
    with patch("engines.kraken._load_sender_model", return_value=mock_model) as loader:
        registry.get_model("/app/models/sender_x.mlmodel")
        assert registry.size() == 1

        registry.invalidate("/app/models/sender_x.mlmodel")
        assert registry.size() == 0

        # The promise in the docstring: the next get goes back to the loader.
        registry.get_model("/app/models/sender_x.mlmodel")
        assert loader.call_count == 2


def test_path_outside_models_dir_raises():
    """get_model must reject paths outside /app/models/ (path traversal guard)."""
    registry = _make_registry()
    with patch("engines.kraken._load_sender_model") as loader:
        with pytest.raises(ValueError, match="not allowed"):
            registry.get_model("/etc/passwd")
        # ".." escape and a sibling directory sharing the prefix are rejected too
        with pytest.raises(ValueError, match="not allowed"):
            registry.get_model("/app/models/../secrets/model.mlmodel")
        with pytest.raises(ValueError, match="not allowed"):
            registry.get_model("/app/models_evil/model.mlmodel")

    loader.assert_not_called()  # rejected paths must never reach the loader