feat(ocr): per-sender model registry and /train-sender endpoint

engines/kraken.py:
- Add _SenderModelRegistry with LRU eviction (max configurable via
  OCR_MAX_CACHED_MODELS env var), double-checked locking, invalidate(),
  and path whitelist (/app/models/ only)
- Add _load_sender_model() helper for testability
- extract_page_blocks() and extract_region_text() accept optional
  sender_model_path; route to sender registry when provided

models.py:
- OcrRequest gains senderModelPath: str | None = None field

main.py:
- /ocr and /ocr/stream pass request.senderModelPath to Kraken engine
- New /train-sender endpoint: validates output_model_path, runs ketos
  train with base model as starting point, invalidates sender cache

docker-compose.yml:
- Add OCR_MAX_CACHED_MODELS: "5" to ocr-service environment

test_sender_registry.py:
- 4 tests: cache hit, LRU eviction, invalidate, path traversal guard

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 18:05:39 +02:00
parent 7a342a07cf
commit 64d27d6d61
5 changed files with 234 additions and 9 deletions

View File

@@ -97,6 +97,7 @@ services:
DETECTOR_BATCH_SIZE: "8"
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
OCR_MAX_CACHED_MODELS: "5" # LRU cache size for per-sender Kraken models
networks:
- archive-net
healthcheck:

View File

@@ -1,13 +1,78 @@
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
import collections
import logging
import os
import threading
logger = logging.getLogger(__name__)
_model = None
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
_MODELS_DIR = os.path.realpath("/app/models")
_MAX_CACHED_SENDER_MODELS = int(os.environ.get("OCR_MAX_CACHED_MODELS", "5"))
def _load_sender_model(path: str):
    """Load a Kraken recognition model from *path*.

    Kept as a tiny module-level function so tests can patch it out
    (see test_sender_registry.py) without touching the registry itself.
    """
    # Imported lazily: kraken is heavy and only needed when a sender
    # model is actually requested.
    from kraken.lib import models as _kraken_models
    return _kraken_models.load_any(path)
class _SenderModelRegistry:
    """Thread-safe LRU cache for per-sender Kraken models.

    Uses double-checked locking: the (slow) model load happens outside the
    lock so concurrent OCR requests that hit the cache are never blocked.
    If two threads race to load the same path, one loaded copy wins and the
    other is discarded, so at most one entry per path is stored.

    Cache keys are os.path.realpath()-normalized, so get_model() and
    invalidate() agree even when callers spell the same file differently
    (symlinks, "./" prefixes, redundant separators). Previously the raw
    string was used as the key, which could let a stale model survive an
    invalidate() issued with a different spelling of the same path.
    """

    def __init__(self, max_size: int):
        # max_size: number of models retained; least-recently-used entries
        # are evicted once the cache grows beyond this bound.
        self._max_size = max_size
        self._cache: collections.OrderedDict = collections.OrderedDict()
        self._lock = threading.Lock()

    def get_model(self, model_path: str):
        """Return the cached model for *model_path*, loading it on a miss.

        Raises ValueError when the resolved path escapes /app/models/
        (path-traversal guard: symlinks and ".." are resolved first).
        """
        resolved = os.path.realpath(model_path)
        if not resolved.startswith(_MODELS_DIR + os.sep) and resolved != _MODELS_DIR:
            raise ValueError(f"Sender model path not allowed: {model_path}")
        # Fast path: cache hit under the lock, refresh LRU position.
        with self._lock:
            if resolved in self._cache:
                self._cache.move_to_end(resolved)
                return self._cache[resolved]
        # Slow path: load OUTSIDE the lock so unrelated requests proceed.
        new_model = _load_sender_model(resolved)
        with self._lock:
            if resolved in self._cache:
                # Another thread finished loading first; keep its copy.
                self._cache.move_to_end(resolved)
                return self._cache[resolved]
            self._cache[resolved] = new_model
            self._cache.move_to_end(resolved)
            while len(self._cache) > self._max_size:
                # Evict from the front: least-recently-used entry.
                self._cache.popitem(last=False)
        return new_model

    def invalidate(self, model_path: str) -> None:
        """Remove the model from the cache so the next request reloads it."""
        with self._lock:
            self._cache.pop(os.path.realpath(model_path), None)

    def size(self) -> int:
        """Number of models currently cached (test helper)."""
        with self._lock:
            return len(self._cache)

    def _contains(self, model_path: str) -> bool:
        # Test helper: membership check that does not disturb LRU order.
        with self._lock:
            return os.path.realpath(model_path) in self._cache
# Process-wide singleton shared by extract_page_blocks()/extract_region_text();
# capacity comes from the OCR_MAX_CACHED_MODELS env var (default 5).
_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS)
def load_models():
"""Load the Kraken model at startup. Skips if model file is not present."""
@@ -29,10 +94,12 @@ def is_available() -> bool:
return _model is not None
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
def extract_page_blocks(image, page_idx: int, language: str = "de",
sender_model_path: str | None = None) -> list[dict]:
"""Run Kraken segmentation + recognition on a single PIL image.
Returns block dicts for that page. Coordinates are normalized to [0, 1].
When sender_model_path is provided, the per-sender fine-tuned model is used.
"""
from kraken import blla, rpred
from confidence import words_from_characters
@@ -40,11 +107,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
if _model is None:
raise RuntimeError("Kraken model is not loaded")
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
page_w, page_h = image.size
blocks = []
baseline_seg = blla.segment(image)
pred_it = rpred.rpred(_model, image, baseline_seg)
pred_it = rpred.rpred(active_model, image, baseline_seg)
for record in pred_it:
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
@@ -79,13 +148,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
return blocks
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
def extract_region_text(image, x: float, y: float, w: float, h: float,
sender_model_path: str | None = None) -> str:
"""Crop image to a normalized region and run Kraken recognition on the crop.
Used for guided OCR — skips full-page layout detection entirely.
A single synthetic baseline spanning the full crop width is used so that
blla.segment() (which crashes on small crops) is never called.
Coordinates are normalized to [0, 1].
When sender_model_path is provided, the per-sender fine-tuned model is used.
"""
from kraken import rpred
from kraken.containers import Segmentation, BaselineLine
@@ -93,6 +164,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
if _model is None:
raise RuntimeError("Kraken model is not loaded")
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
pw, ph = image.size
x1 = max(0, int(x * pw))
y1 = max(0, int(y * ph))
@@ -123,11 +196,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
regions={},
line_orders=[],
)
pred_it = rpred.rpred(_model, crop, synthetic_seg)
pred_it = rpred.rpred(active_model, crop, synthetic_seg)
return " ".join(r.prediction for r in pred_it)
def extract_blocks(images: list, language: str = "de") -> list[dict]:
def extract_blocks(images: list, language: str = "de",
sender_model_path: str | None = None) -> list[dict]:
"""Run Kraken segmentation + recognition on a list of PIL images.
Returns block dicts with pageNumber, x, y, width, height, polygon, text.
@@ -137,7 +211,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]:
all_blocks = []
for page_idx, image in enumerate(images, start=1):
all_blocks.extend(extract_page_blocks(image, page_idx, language))
all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path))
return all_blocks

View File

@@ -17,7 +17,7 @@ from urllib.parse import urlparse
import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, Header, HTTPException, UploadFile
from fastapi import FastAPI, Form, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
@@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest):
status_code=400,
detail="Kraken model not available — cannot process Kurrent script",
)
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
blocks = await asyncio.to_thread(
kraken_engine.extract_blocks, images, request.language, request.senderModelPath)
else:
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
@@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest):
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
blocks = []
sender_path = request.senderModelPath if use_kraken else None
for region in page_regions:
text = await asyncio.to_thread(
engine.extract_region_text, image,
region.x, region.y, region.width, region.height,
sender_path,
)
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
text = correct_text(text)
@@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest):
try:
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
sender_path = request.senderModelPath if use_kraken else None
blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language
engine.extract_page_blocks, image, page_idx, request.language, sender_path
)
for block in blocks:
@@ -443,6 +447,86 @@ async def train_model(
return result
@app.post("/train-sender")
async def train_sender_model(
    file: UploadFile,
    output_model_path: str = Form(...),
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune a per-sender Kurrent model and save it to output_model_path.

    The upload must be a ZIP of ground-truth files (page XML + images).
    output_model_path must resolve to a location within /app/models/. The
    per-sender model cache is invalidated after training so the next OCR
    request picks up the new model.

    Raises:
        HTTPException 400: invalid token or path outside /app/models/.
        HTTPException 422: upload is not a ZIP or contains no ground truth.
        HTTPException 503: base models not loaded yet.
    """
    _check_training_token(x_training_token)
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Path-traversal guard: resolve symlinks/".." before comparing prefixes.
    models_dir = os.path.realpath("/app/models")
    resolved_output = os.path.realpath(output_model_path)
    if not resolved_output.startswith(models_dir + os.sep):
        raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")

    zip_bytes = await file.read()
    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    # Separator added: the old "%s%s" format ran the run id and the output
    # path together into one unreadable token in the logs.
    log.info("Starting sender training run %s -> %s", training_run_id, output_model_path)

    def _run_sender_training() -> dict:
        # Executed in a worker thread via asyncio.to_thread — keep synchronous.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # A malformed upload used to escape as an unhandled BadZipFile
            # (HTTP 500); surface it as a client error instead.
            try:
                archive = zipfile.ZipFile(io.BytesIO(zip_bytes))
            except zipfile.BadZipFile:
                raise HTTPException(status_code=422, detail="Uploaded file is not a valid ZIP archive")
            with archive as zf:
                for entry in zf.namelist():
                    _validate_zip_entry(entry, tmp_dir)  # reject absolute / ".." members
                zf.extractall(tmp_dir)

            ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
            log.info("Sender training on %d ground-truth pairs", len(ground_truth))

            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)
            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "train",
                "-f", "page",
                "--weights-format", "coreml",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
                "-B", "1",
            ]
            # Fine-tune from the shared base model when it is present.
            if os.path.exists(KRAKEN_MODEL_PATH):
                cmd += ["-i", KRAKEN_MODEL_PATH]
            cmd += ground_truth
            log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
            # argv list (shell=False): ground-truth paths never hit a shell.
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.returncode != 0:
                raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None
            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Sender training produced no model file")
            os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
            shutil.copy2(best_model, resolved_output)
            log.info("Saved sender model to %s", resolved_output)
            # Evict any stale cached model so the next OCR request reloads it.
            kraken_engine._sender_registry.invalidate(output_model_path)
            return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    result = await asyncio.to_thread(_run_sender_training)
    return result
@app.post("/segtrain")
async def segtrain_model(
file: UploadFile,

View File

@@ -19,6 +19,7 @@ class OcrRequest(BaseModel):
scriptType: str = "UNKNOWN"
language: str = "de"
regions: list[OcrRegion] | None = None
senderModelPath: str | None = None
class OcrBlock(BaseModel):

View File

@@ -0,0 +1,65 @@
"""Tests for the per-sender model LRU registry in engines/kraken.py."""
from unittest.mock import MagicMock, call, patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_registry(max_size=5):
    """Build a fresh _SenderModelRegistry with the given capacity."""
    # Imported inside the helper so collection-time import errors in
    # engines.kraken don't break unrelated test modules.
    from engines.kraken import _SenderModelRegistry as Registry
    return Registry(max_size=max_size)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_cache_hit_returns_same_object():
    """A repeated get_model for one path hits the cache, not the loader."""
    registry = _make_registry()
    cached = MagicMock(name="model_a")
    path = "/app/models/sender_a.mlmodel"
    with patch("engines.kraken._load_sender_model", return_value=cached) as loader:
        first = registry.get_model(path)
        second = registry.get_model(path)
        assert first is second
        # Two gets, exactly one disk load.
        loader.assert_called_once()
def test_lru_eviction_removes_least_recently_used():
    """Exceeding max_size drops the entry that was touched longest ago."""
    registry = _make_registry(max_size=2)
    paths = [
        "/app/models/sender_a.mlmodel",
        "/app/models/sender_b.mlmodel",
        "/app/models/sender_c.mlmodel",
    ]
    with patch("engines.kraken._load_sender_model",
               side_effect=lambda p: MagicMock(name=p)):
        # Loading the third model pushes the cache past its capacity.
        for path in paths:
            registry.get_model(path)
    assert registry.size() == 2
    # 'a' was the least-recently-used entry and must be gone.
    assert not registry._contains(paths[0])
def test_invalidate_removes_entry_from_cache():
    """After invalidate(), the entry is gone and the next get reloads."""
    registry = _make_registry()
    path = "/app/models/sender_x.mlmodel"
    with patch("engines.kraken._load_sender_model",
               return_value=MagicMock(name="model_x")):
        registry.get_model(path)
    assert registry.size() == 1
    registry.invalidate(path)
    assert registry.size() == 0
def test_path_outside_models_dir_raises():
    """get_model rejects paths outside /app/models/ (path traversal guard)."""
    # The guard fires before any load attempt, so no patching is needed.
    with pytest.raises(ValueError, match="not allowed"):
        _make_registry().get_model("/etc/passwd")