feat(ocr): per-sender model registry and /train-sender endpoint

engines/kraken.py:
- Add _SenderModelRegistry with LRU eviction (max configurable via
  OCR_MAX_CACHED_MODELS env var), double-checked locking, invalidate(),
  and path whitelist (/app/models/ only)
- Add _load_sender_model() helper for testability
- extract_page_blocks() and extract_region_text() accept optional
  sender_model_path; route to sender registry when provided

models.py:
- OcrRequest gains senderModelPath: str | None = None field

main.py:
- /ocr and /ocr/stream pass request.senderModelPath to Kraken engine
- New /train-sender endpoint: validates output_model_path, runs ketos
  train with base model as starting point, invalidates sender cache

docker-compose.yml:
- Add OCR_MAX_CACHED_MODELS: "5" to ocr-service environment

test_sender_registry.py:
- 4 tests: cache hit, LRU eviction, invalidate, path traversal guard

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 18:05:39 +02:00
committed by marcel
parent 548ad0fa68
commit a146a2ec3c
5 changed files with 234 additions and 9 deletions

View File

@@ -17,7 +17,7 @@ from urllib.parse import urlparse
import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, Header, HTTPException, UploadFile
from fastapi import FastAPI, Form, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
@@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest):
status_code=400,
detail="Kraken model not available — cannot process Kurrent script",
)
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
blocks = await asyncio.to_thread(
kraken_engine.extract_blocks, images, request.language, request.senderModelPath)
else:
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
@@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest):
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
blocks = []
sender_path = request.senderModelPath if use_kraken else None
for region in page_regions:
text = await asyncio.to_thread(
engine.extract_region_text, image,
region.x, region.y, region.width, region.height,
sender_path,
)
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
text = correct_text(text)
@@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest):
try:
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
sender_path = request.senderModelPath if use_kraken else None
blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language
engine.extract_page_blocks, image, page_idx, request.language, sender_path
)
for block in blocks:
@@ -443,6 +447,86 @@ async def train_model(
return result
@app.post("/train-sender")
async def train_sender_model(
file: UploadFile,
output_model_path: str = Form(...),
x_training_token: str | None = Header(default=None),
):
"""Fine-tune a per-sender Kurrent model and save to output_model_path.
output_model_path must be within /app/models/. The per-sender model cache
is invalidated after training so the next OCR request picks up the new model.
"""
_check_training_token(x_training_token)
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
models_dir = os.path.realpath("/app/models")
resolved_output = os.path.realpath(output_model_path)
if not resolved_output.startswith(models_dir + os.sep):
raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")
zip_bytes = await file.read()
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
log.info("Starting sender training run %s%s", training_run_id, output_model_path)
def _run_sender_training() -> dict:
with tempfile.TemporaryDirectory() as tmp_dir:
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
for entry in zf.namelist():
_validate_zip_entry(entry, tmp_dir)
zf.extractall(tmp_dir)
ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
if not ground_truth:
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
log.info("Sender training on %d ground-truth pairs", len(ground_truth))
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
cmd = [
"ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
"train",
"-f", "page",
"--weights-format", "coreml",
"-o", checkpoint_dir,
"-q", "fixed",
"-N", "10",
"-B", "1",
]
if os.path.exists(KRAKEN_MODEL_PATH):
cmd += ["-i", KRAKEN_MODEL_PATH]
cmd += ground_truth
log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
proc = subprocess.run(cmd, capture_output=True, text=True)
if proc.returncode != 0:
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
raise RuntimeError("Sender training produced no model file")
os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
shutil.copy2(best_model, resolved_output)
log.info("Saved sender model to %s", resolved_output)
kraken_engine._sender_registry.invalidate(output_model_path)
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_sender_training)
return result
@app.post("/segtrain")
async def segtrain_model(
file: UploadFile,