feat(ocr): per-sender model registry and /train-sender endpoint
engines/kraken.py:
- Add _SenderModelRegistry with LRU eviction (maximum size configurable via the OCR_MAX_CACHED_MODELS env var), double-checked locking, invalidate(), and a path whitelist (/app/models/ only)
- Add _load_sender_model() helper for testability
- extract_page_blocks() and extract_region_text() accept an optional sender_model_path and route to the sender registry when it is provided

models.py:
- OcrRequest gains a senderModelPath: str | None = None field

main.py:
- /ocr and /ocr/stream pass request.senderModelPath to the Kraken engine
- New /train-sender endpoint: validates output_model_path, runs ketos train with the base model as the starting point, and invalidates the sender cache

docker-compose.yml:
- Add OCR_MAX_CACHED_MODELS: "5" to the ocr-service environment

test_sender_registry.py:
- 4 tests: cache hit, LRU eviction, invalidate, path traversal guard

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,13 +1,78 @@
|
||||
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Base recognition model; None means "not loaded" (see is_available()).
_model = None
# Location of the base model file, overridable through KRAKEN_MODEL_PATH.
_model_path = os.getenv("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")

# Per-sender model paths must resolve inside this directory.
_MODELS_DIR = os.path.realpath("/app/models")
# Capacity of the per-sender model cache, overridable through OCR_MAX_CACHED_MODELS.
_MAX_CACHED_SENDER_MODELS = int(os.getenv("OCR_MAX_CACHED_MODELS", "5"))
|
||||
|
||||
|
||||
def _load_sender_model(path: str):
    """Read the Kraken model stored at *path* and return it.

    Kept as a separate module-level function so it can be replaced in tests.
    The kraken import is done inside the function rather than at module level.
    """
    from kraken.lib.models import load_any

    return load_any(path)
|
||||
|
||||
|
||||
class _SenderModelRegistry:
|
||||
"""Thread-safe LRU cache for per-sender Kraken models.
|
||||
|
||||
Uses double-checked locking: model loading happens outside the lock to
|
||||
avoid blocking concurrent OCR requests. At most one entry per path is
|
||||
stored even under concurrent load.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int):
|
||||
self._max_size = max_size
|
||||
self._cache: collections.OrderedDict = collections.OrderedDict()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_model(self, model_path: str):
|
||||
"""Return the cached model or load it. Validates path is within /app/models/."""
|
||||
resolved = os.path.realpath(model_path)
|
||||
if not resolved.startswith(_MODELS_DIR + os.sep) and resolved != _MODELS_DIR:
|
||||
raise ValueError(f"Sender model path not allowed: {model_path}")
|
||||
|
||||
with self._lock:
|
||||
if model_path in self._cache:
|
||||
self._cache.move_to_end(model_path)
|
||||
return self._cache[model_path]
|
||||
|
||||
new_model = _load_sender_model(model_path)
|
||||
|
||||
with self._lock:
|
||||
if model_path in self._cache:
|
||||
self._cache.move_to_end(model_path)
|
||||
return self._cache[model_path]
|
||||
self._cache[model_path] = new_model
|
||||
self._cache.move_to_end(model_path)
|
||||
while len(self._cache) > self._max_size:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
return new_model
|
||||
|
||||
def invalidate(self, model_path: str) -> None:
|
||||
"""Remove model from cache so the next request reloads from disk."""
|
||||
with self._lock:
|
||||
self._cache.pop(model_path, None)
|
||||
|
||||
def size(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
def _contains(self, model_path: str) -> bool:
|
||||
with self._lock:
|
||||
return model_path in self._cache
|
||||
|
||||
|
||||
# Module-level registry instance used by extract_page_blocks() and
# extract_region_text() whenever a sender_model_path is supplied.
_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS)
|
||||
|
||||
|
||||
def load_models():
|
||||
"""Load the Kraken model at startup. Skips if model file is not present."""
|
||||
@@ -29,10 +94,12 @@ def is_available() -> bool:
|
||||
return _model is not None
|
||||
|
||||
|
||||
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
|
||||
def extract_page_blocks(image, page_idx: int, language: str = "de",
|
||||
sender_model_path: str | None = None) -> list[dict]:
|
||||
"""Run Kraken segmentation + recognition on a single PIL image.
|
||||
|
||||
Returns block dicts for that page. Coordinates are normalized to [0, 1].
|
||||
When sender_model_path is provided, the per-sender fine-tuned model is used.
|
||||
"""
|
||||
from kraken import blla, rpred
|
||||
from confidence import words_from_characters
|
||||
@@ -40,11 +107,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
if _model is None:
|
||||
raise RuntimeError("Kraken model is not loaded")
|
||||
|
||||
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
|
||||
|
||||
page_w, page_h = image.size
|
||||
blocks = []
|
||||
|
||||
baseline_seg = blla.segment(image)
|
||||
pred_it = rpred.rpred(_model, image, baseline_seg)
|
||||
pred_it = rpred.rpred(active_model, image, baseline_seg)
|
||||
|
||||
for record in pred_it:
|
||||
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
|
||||
@@ -79,13 +148,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
return blocks
|
||||
|
||||
|
||||
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
def extract_region_text(image, x: float, y: float, w: float, h: float,
|
||||
sender_model_path: str | None = None) -> str:
|
||||
"""Crop image to a normalized region and run Kraken recognition on the crop.
|
||||
|
||||
Used for guided OCR — skips full-page layout detection entirely.
|
||||
A single synthetic baseline spanning the full crop width is used so that
|
||||
blla.segment() (which crashes on small crops) is never called.
|
||||
Coordinates are normalized to [0, 1].
|
||||
When sender_model_path is provided, the per-sender fine-tuned model is used.
|
||||
"""
|
||||
from kraken import rpred
|
||||
from kraken.containers import Segmentation, BaselineLine
|
||||
@@ -93,6 +164,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
if _model is None:
|
||||
raise RuntimeError("Kraken model is not loaded")
|
||||
|
||||
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
|
||||
|
||||
pw, ph = image.size
|
||||
x1 = max(0, int(x * pw))
|
||||
y1 = max(0, int(y * ph))
|
||||
@@ -123,11 +196,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
regions={},
|
||||
line_orders=[],
|
||||
)
|
||||
pred_it = rpred.rpred(_model, crop, synthetic_seg)
|
||||
pred_it = rpred.rpred(active_model, crop, synthetic_seg)
|
||||
return " ".join(r.prediction for r in pred_it)
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
def extract_blocks(images: list, language: str = "de",
|
||||
sender_model_path: str | None = None) -> list[dict]:
|
||||
"""Run Kraken segmentation + recognition on a list of PIL images.
|
||||
|
||||
Returns block dicts with pageNumber, x, y, width, height, polygon, text.
|
||||
@@ -137,7 +211,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
all_blocks = []
|
||||
|
||||
for page_idx, image in enumerate(images, start=1):
|
||||
all_blocks.extend(extract_page_blocks(image, page_idx, language))
|
||||
all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path))
|
||||
|
||||
return all_blocks
|
||||
|
||||
|
||||
Reference in New Issue
Block a user