feat(ocr): per-sender model registry and /train-sender endpoint
engines/kraken.py: - Add _SenderModelRegistry with LRU eviction (max configurable via OCR_MAX_CACHED_MODELS env var), double-checked locking, invalidate(), and path whitelist (/app/models/ only) - Add _load_sender_model() helper for testability - extract_page_blocks() and extract_region_text() accept optional sender_model_path; route to sender registry when provided models.py: - OcrRequest gains senderModelPath: str | None = None field main.py: - /ocr and /ocr/stream pass request.senderModelPath to Kraken engine - New /train-sender endpoint: validates output_model_path, runs ketos train with base model as starting point, invalidates sender cache docker-compose.yml: - Add OCR_MAX_CACHED_MODELS: "5" to ocr-service environment test_sender_registry.py: - 4 tests: cache hit, LRU eviction, invalidate, path traversal guard Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -97,6 +97,7 @@ services:
|
|||||||
DETECTOR_BATCH_SIZE: "8"
|
DETECTOR_BATCH_SIZE: "8"
|
||||||
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
|
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
|
||||||
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
|
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
|
||||||
|
OCR_MAX_CACHED_MODELS: "5" # LRU cache size for per-sender Kraken models
|
||||||
networks:
|
networks:
|
||||||
- archive-net
|
- archive-net
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|||||||
@@ -1,13 +1,78 @@
|
|||||||
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
|
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
|
||||||
|
|
||||||
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_model = None
|
_model = None
|
||||||
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
||||||
|
|
||||||
|
_MODELS_DIR = os.path.realpath("/app/models")
|
||||||
|
_MAX_CACHED_SENDER_MODELS = int(os.environ.get("OCR_MAX_CACHED_MODELS", "5"))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_sender_model(path: str):
    """Read a Kraken recognition model from *path*.

    Kept as a standalone module-level helper so unit tests can patch it out.
    """
    # Deferred import: kraken is only pulled in once a sender model is needed.
    from kraken.lib import models as kraken_models

    model = kraken_models.load_any(path)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
class _SenderModelRegistry:
|
||||||
|
"""Thread-safe LRU cache for per-sender Kraken models.
|
||||||
|
|
||||||
|
Uses double-checked locking: model loading happens outside the lock to
|
||||||
|
avoid blocking concurrent OCR requests. At most one entry per path is
|
||||||
|
stored even under concurrent load.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int):
|
||||||
|
self._max_size = max_size
|
||||||
|
self._cache: collections.OrderedDict = collections.OrderedDict()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def get_model(self, model_path: str):
|
||||||
|
"""Return the cached model or load it. Validates path is within /app/models/."""
|
||||||
|
resolved = os.path.realpath(model_path)
|
||||||
|
if not resolved.startswith(_MODELS_DIR + os.sep) and resolved != _MODELS_DIR:
|
||||||
|
raise ValueError(f"Sender model path not allowed: {model_path}")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if model_path in self._cache:
|
||||||
|
self._cache.move_to_end(model_path)
|
||||||
|
return self._cache[model_path]
|
||||||
|
|
||||||
|
new_model = _load_sender_model(model_path)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if model_path in self._cache:
|
||||||
|
self._cache.move_to_end(model_path)
|
||||||
|
return self._cache[model_path]
|
||||||
|
self._cache[model_path] = new_model
|
||||||
|
self._cache.move_to_end(model_path)
|
||||||
|
while len(self._cache) > self._max_size:
|
||||||
|
self._cache.popitem(last=False)
|
||||||
|
|
||||||
|
return new_model
|
||||||
|
|
||||||
|
def invalidate(self, model_path: str) -> None:
|
||||||
|
"""Remove model from cache so the next request reloads from disk."""
|
||||||
|
with self._lock:
|
||||||
|
self._cache.pop(model_path, None)
|
||||||
|
|
||||||
|
def size(self) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return len(self._cache)
|
||||||
|
|
||||||
|
def _contains(self, model_path: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
return model_path in self._cache
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide singleton shared by all OCR requests; size comes from the
# OCR_MAX_CACHED_MODELS env var (read at import time).
_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS)
|
||||||
|
|
||||||
|
|
||||||
def load_models():
|
def load_models():
|
||||||
"""Load the Kraken model at startup. Skips if model file is not present."""
|
"""Load the Kraken model at startup. Skips if model file is not present."""
|
||||||
@@ -29,10 +94,12 @@ def is_available() -> bool:
|
|||||||
return _model is not None
|
return _model is not None
|
||||||
|
|
||||||
|
|
||||||
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
|
def extract_page_blocks(image, page_idx: int, language: str = "de",
|
||||||
|
sender_model_path: str | None = None) -> list[dict]:
|
||||||
"""Run Kraken segmentation + recognition on a single PIL image.
|
"""Run Kraken segmentation + recognition on a single PIL image.
|
||||||
|
|
||||||
Returns block dicts for that page. Coordinates are normalized to [0, 1].
|
Returns block dicts for that page. Coordinates are normalized to [0, 1].
|
||||||
|
When sender_model_path is provided, the per-sender fine-tuned model is used.
|
||||||
"""
|
"""
|
||||||
from kraken import blla, rpred
|
from kraken import blla, rpred
|
||||||
from confidence import words_from_characters
|
from confidence import words_from_characters
|
||||||
@@ -40,11 +107,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
|||||||
if _model is None:
|
if _model is None:
|
||||||
raise RuntimeError("Kraken model is not loaded")
|
raise RuntimeError("Kraken model is not loaded")
|
||||||
|
|
||||||
|
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
|
||||||
|
|
||||||
page_w, page_h = image.size
|
page_w, page_h = image.size
|
||||||
blocks = []
|
blocks = []
|
||||||
|
|
||||||
baseline_seg = blla.segment(image)
|
baseline_seg = blla.segment(image)
|
||||||
pred_it = rpred.rpred(_model, image, baseline_seg)
|
pred_it = rpred.rpred(active_model, image, baseline_seg)
|
||||||
|
|
||||||
for record in pred_it:
|
for record in pred_it:
|
||||||
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
|
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
|
||||||
@@ -79,13 +148,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
|||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
def extract_region_text(image, x: float, y: float, w: float, h: float,
|
||||||
|
sender_model_path: str | None = None) -> str:
|
||||||
"""Crop image to a normalized region and run Kraken recognition on the crop.
|
"""Crop image to a normalized region and run Kraken recognition on the crop.
|
||||||
|
|
||||||
Used for guided OCR — skips full-page layout detection entirely.
|
Used for guided OCR — skips full-page layout detection entirely.
|
||||||
A single synthetic baseline spanning the full crop width is used so that
|
A single synthetic baseline spanning the full crop width is used so that
|
||||||
blla.segment() (which crashes on small crops) is never called.
|
blla.segment() (which crashes on small crops) is never called.
|
||||||
Coordinates are normalized to [0, 1].
|
Coordinates are normalized to [0, 1].
|
||||||
|
When sender_model_path is provided, the per-sender fine-tuned model is used.
|
||||||
"""
|
"""
|
||||||
from kraken import rpred
|
from kraken import rpred
|
||||||
from kraken.containers import Segmentation, BaselineLine
|
from kraken.containers import Segmentation, BaselineLine
|
||||||
@@ -93,6 +164,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
|||||||
if _model is None:
|
if _model is None:
|
||||||
raise RuntimeError("Kraken model is not loaded")
|
raise RuntimeError("Kraken model is not loaded")
|
||||||
|
|
||||||
|
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
|
||||||
|
|
||||||
pw, ph = image.size
|
pw, ph = image.size
|
||||||
x1 = max(0, int(x * pw))
|
x1 = max(0, int(x * pw))
|
||||||
y1 = max(0, int(y * ph))
|
y1 = max(0, int(y * ph))
|
||||||
@@ -123,11 +196,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
|||||||
regions={},
|
regions={},
|
||||||
line_orders=[],
|
line_orders=[],
|
||||||
)
|
)
|
||||||
pred_it = rpred.rpred(_model, crop, synthetic_seg)
|
pred_it = rpred.rpred(active_model, crop, synthetic_seg)
|
||||||
return " ".join(r.prediction for r in pred_it)
|
return " ".join(r.prediction for r in pred_it)
|
||||||
|
|
||||||
|
|
||||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
def extract_blocks(images: list, language: str = "de",
|
||||||
|
sender_model_path: str | None = None) -> list[dict]:
|
||||||
"""Run Kraken segmentation + recognition on a list of PIL images.
|
"""Run Kraken segmentation + recognition on a list of PIL images.
|
||||||
|
|
||||||
Returns block dicts with pageNumber, x, y, width, height, polygon, text.
|
Returns block dicts with pageNumber, x, y, width, height, polygon, text.
|
||||||
@@ -137,7 +211,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
|||||||
all_blocks = []
|
all_blocks = []
|
||||||
|
|
||||||
for page_idx, image in enumerate(images, start=1):
|
for page_idx, image in enumerate(images, start=1):
|
||||||
all_blocks.extend(extract_page_blocks(image, page_idx, language))
|
all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path))
|
||||||
|
|
||||||
return all_blocks
|
return all_blocks
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pypdfium2 as pdfium
|
import pypdfium2 as pdfium
|
||||||
from fastapi import FastAPI, Header, HTTPException, UploadFile
|
from fastapi import FastAPI, Form, Header, HTTPException, UploadFile
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest):
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
detail="Kraken model not available — cannot process Kurrent script",
|
detail="Kraken model not available — cannot process Kurrent script",
|
||||||
)
|
)
|
||||||
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
|
blocks = await asyncio.to_thread(
|
||||||
|
kraken_engine.extract_blocks, images, request.language, request.senderModelPath)
|
||||||
else:
|
else:
|
||||||
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
|
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
|
||||||
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
|
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
|
||||||
@@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||||
image = await asyncio.to_thread(preprocess_page, image)
|
image = await asyncio.to_thread(preprocess_page, image)
|
||||||
blocks = []
|
blocks = []
|
||||||
|
sender_path = request.senderModelPath if use_kraken else None
|
||||||
for region in page_regions:
|
for region in page_regions:
|
||||||
text = await asyncio.to_thread(
|
text = await asyncio.to_thread(
|
||||||
engine.extract_region_text, image,
|
engine.extract_region_text, image,
|
||||||
region.x, region.y, region.width, region.height,
|
region.x, region.y, region.width, region.height,
|
||||||
|
sender_path,
|
||||||
)
|
)
|
||||||
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
|
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
|
||||||
text = correct_text(text)
|
text = correct_text(text)
|
||||||
@@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
try:
|
try:
|
||||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||||
image = await asyncio.to_thread(preprocess_page, image)
|
image = await asyncio.to_thread(preprocess_page, image)
|
||||||
|
sender_path = request.senderModelPath if use_kraken else None
|
||||||
blocks = await asyncio.to_thread(
|
blocks = await asyncio.to_thread(
|
||||||
engine.extract_page_blocks, image, page_idx, request.language
|
engine.extract_page_blocks, image, page_idx, request.language, sender_path
|
||||||
)
|
)
|
||||||
|
|
||||||
for block in blocks:
|
for block in blocks:
|
||||||
@@ -443,6 +447,86 @@ async def train_model(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/train-sender")
async def train_sender_model(
    file: UploadFile,
    output_model_path: str = Form(...),
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune a per-sender Kurrent model and save to output_model_path.

    output_model_path must be within /app/models/. The per-sender model cache
    is invalidated after training so the next OCR request picks up the new model.
    """
    # Auth gate first: rejects the request before any file bytes are touched.
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Whitelist the output path via realpath so "../" and symlink tricks
    # cannot write outside the models directory.
    models_dir = os.path.realpath("/app/models")
    resolved_output = os.path.realpath(output_model_path)
    if not resolved_output.startswith(models_dir + os.sep):
        raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")

    zip_bytes = await file.read()

    # Timestamp-based run id is attached to every log line via LoggerAdapter.
    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting sender training run %s → %s", training_run_id, output_model_path)

    def _run_sender_training() -> dict:
        # Blocking helper; executed off the event loop via asyncio.to_thread.
        with tempfile.TemporaryDirectory() as tmp_dir:
            with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                # NOTE(review): _validate_zip_entry is defined elsewhere —
                # presumably a zip-slip guard; confirm it rejects absolute
                # paths and "../" entries before extractall runs.
                for entry in zf.namelist():
                    _validate_zip_entry(entry, tmp_dir)
                zf.extractall(tmp_dir)

            # Ground truth is expected as PAGE-XML files at the archive root
            # (matches the "-f page" flag passed to ketos below).
            ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")

            log.info("Sender training on %d ground-truth pairs", len(ground_truth))
            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)

            # ketos invoked as an argument list (no shell) so user-supplied
            # file names cannot be interpreted by a shell.
            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "train",
                "-f", "page",
                "--weights-format", "coreml",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
                "-B", "1",
            ]
            # Start from the base Kurrent model when present so the sender
            # model is a fine-tune rather than a from-scratch train.
            if os.path.exists(KRAKEN_MODEL_PATH):
                cmd += ["-i", KRAKEN_MODEL_PATH]
            cmd += ground_truth

            log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.returncode != 0:
                # Only the stderr tail is surfaced to keep the error bounded.
                raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            # CER is derived as 1 - accuracy when a best checkpoint was found.
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None

            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Sender training produced no model file")

            os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
            shutil.copy2(best_model, resolved_output)
            log.info("Saved sender model to %s", resolved_output)

            # Evict any stale cached copy so the next OCR request reloads the
            # freshly trained model.
            # NOTE(review): invalidation uses the raw client-supplied path —
            # confirm it matches the key form the registry caches under.
            kraken_engine._sender_registry.invalidate(output_model_path)

            return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    result = await asyncio.to_thread(_run_sender_training)
    return result
|
||||||
|
|
||||||
|
|
||||||
@app.post("/segtrain")
|
@app.post("/segtrain")
|
||||||
async def segtrain_model(
|
async def segtrain_model(
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class OcrRequest(BaseModel):
|
|||||||
scriptType: str = "UNKNOWN"
|
scriptType: str = "UNKNOWN"
|
||||||
language: str = "de"
|
language: str = "de"
|
||||||
regions: list[OcrRegion] | None = None
|
regions: list[OcrRegion] | None = None
|
||||||
|
senderModelPath: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class OcrBlock(BaseModel):
|
class OcrBlock(BaseModel):
|
||||||
|
|||||||
65
ocr-service/test_sender_registry.py
Normal file
65
ocr-service/test_sender_registry.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""Tests for the per-sender model LRU registry in engines/kraken.py."""
|
||||||
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_registry(max_size=5):
    # Import inside the helper so each test sees the current module attribute
    # (patch targets on engines.kraken resolve correctly).
    from engines.kraken import _SenderModelRegistry
    return _SenderModelRegistry(max_size=max_size)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_cache_hit_returns_same_object():
    """Second get_model call with the same path must return the cached object."""
    registry = _make_registry()
    sentinel = MagicMock(name="model_a")
    path = "/app/models/sender_a.mlmodel"
    with patch("engines.kraken._load_sender_model", return_value=sentinel) as loader:
        first = registry.get_model(path)
        second = registry.get_model(path)

    assert first is second
    # Two lookups, but the loader must have run exactly once.
    loader.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_lru_eviction_removes_least_recently_used():
    """When the cache exceeds max_size, the oldest-accessed entry is evicted."""
    registry = _make_registry(max_size=2)

    with patch("engines.kraken._load_sender_model",
               side_effect=lambda path: MagicMock(name=path)):
        for sender in ("a", "b", "c"):  # loading 'c' must push 'a' out
            registry.get_model(f"/app/models/sender_{sender}.mlmodel")

    assert registry.size() == 2
    # 'a' was touched least recently, so it is the one that went.
    assert not registry._contains("/app/models/sender_a.mlmodel")
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalidate_removes_entry_from_cache():
    """invalidate() must evict the entry so the next get reloads from disk."""
    registry = _make_registry()
    path = "/app/models/sender_x.mlmodel"
    with patch("engines.kraken._load_sender_model",
               return_value=MagicMock(name="model_x")):
        registry.get_model(path)
        assert registry.size() == 1

    registry.invalidate(path)
    assert registry.size() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_outside_models_dir_raises():
    """get_model must reject paths outside /app/models/ (path traversal guard)."""
    reg = _make_registry()
    # No loader patch needed: validation fails before any load is attempted.
    with pytest.raises(ValueError, match="not allowed"):
        reg.get_model("/etc/passwd")
|
||||||
Reference in New Issue
Block a user