feat(ocr): per-sender model registry and /train-sender endpoint
engines/kraken.py:
- Add _SenderModelRegistry with LRU eviction (max configurable via the OCR_MAX_CACHED_MODELS env var), double-checked locking, invalidate(), and a path whitelist (/app/models/ only)
- Add _load_sender_model() helper for testability
- extract_page_blocks() and extract_region_text() accept an optional sender_model_path and route to the sender registry when it is provided

models.py:
- OcrRequest gains a senderModelPath: str | None = None field

main.py:
- /ocr and /ocr/stream pass request.senderModelPath to the Kraken engine
- New /train-sender endpoint: validates output_model_path, runs ketos train with the base model as the starting point, and invalidates the sender cache

docker-compose.yml:
- Add OCR_MAX_CACHED_MODELS: "5" to the ocr-service environment

test_sender_registry.py:
- 4 tests: cache hit, LRU eviction, invalidate, path traversal guard

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import pypdfium2 as pdfium
|
||||
from fastapi import FastAPI, Header, HTTPException, UploadFile
|
||||
from fastapi import FastAPI, Form, Header, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from PIL import Image
|
||||
|
||||
@@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest):
|
||||
status_code=400,
|
||||
detail="Kraken model not available — cannot process Kurrent script",
|
||||
)
|
||||
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
|
||||
blocks = await asyncio.to_thread(
|
||||
kraken_engine.extract_blocks, images, request.language, request.senderModelPath)
|
||||
else:
|
||||
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
|
||||
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
|
||||
@@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||
image = await asyncio.to_thread(preprocess_page, image)
|
||||
blocks = []
|
||||
sender_path = request.senderModelPath if use_kraken else None
|
||||
for region in page_regions:
|
||||
text = await asyncio.to_thread(
|
||||
engine.extract_region_text, image,
|
||||
region.x, region.y, region.width, region.height,
|
||||
sender_path,
|
||||
)
|
||||
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
|
||||
text = correct_text(text)
|
||||
@@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
try:
|
||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||
image = await asyncio.to_thread(preprocess_page, image)
|
||||
sender_path = request.senderModelPath if use_kraken else None
|
||||
blocks = await asyncio.to_thread(
|
||||
engine.extract_page_blocks, image, page_idx, request.language
|
||||
engine.extract_page_blocks, image, page_idx, request.language, sender_path
|
||||
)
|
||||
|
||||
for block in blocks:
|
||||
@@ -443,6 +447,86 @@ async def train_model(
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/train-sender")
async def train_sender_model(
    file: UploadFile,
    output_model_path: str = Form(...),
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune a per-sender Kurrent model and save it to output_model_path.

    The uploaded ZIP is unpacked into a temporary directory, ``ketos train``
    is run on the ``*.xml`` files found at the top level of the archive
    (starting from the base model at ``KRAKEN_MODEL_PATH`` when that file
    exists), and the best resulting checkpoint is published to
    ``output_model_path``. The per-sender model cache is invalidated
    afterwards so the next OCR request picks up the new model.

    Args:
        file: ZIP archive containing the ground-truth files.
        output_model_path: Destination of the trained model. Its resolved
            (symlink-free) form must lie within /app/models/.
        x_training_token: Value of the ``X-Training-Token`` header, passed to
            ``_check_training_token`` before anything else happens.

    Returns:
        A dict with the keys ``loss`` (always ``None``), ``accuracy``,
        ``cer`` (``1 - accuracy`` rounded to 4 places, or ``None`` when no
        accuracy was parsed) and ``epochs``.

    Raises:
        HTTPException: 503 if the models are not loaded yet; 400 if
            ``output_model_path`` resolves outside /app/models/ or the upload
            is not a valid ZIP archive; 422 if the archive contains no
            ground-truth ``*.xml`` files.
        RuntimeError: If ``ketos train`` exits non-zero or produces no model.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Resolve symlinks and ".." segments before the containment check so a
    # crafted path cannot escape the models directory.
    models_dir = os.path.realpath("/app/models")
    resolved_output = os.path.realpath(output_model_path)
    if not resolved_output.startswith(models_dir + os.sep):
        raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting sender training run %s → %s", training_run_id, output_model_path)

    def _run_sender_training() -> dict:
        """Blocking part of the run; executed in a worker thread."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            try:
                with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                    # Validate every entry before extracting anything.
                    for entry in zf.namelist():
                        _validate_zip_entry(entry, tmp_dir)
                    zf.extractall(tmp_dir)
            except zipfile.BadZipFile as exc:
                # A corrupt upload is a client error, not a server fault.
                raise HTTPException(
                    status_code=400, detail="Uploaded file is not a valid ZIP archive"
                ) from exc

            ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")

            log.info("Sender training on %d ground-truth pairs", len(ground_truth))
            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)

            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "train",
                "-f", "page",
                "--weights-format", "coreml",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
                "-B", "1",
            ]
            # Fine-tune from the base model when it is present.
            if os.path.exists(KRAKEN_MODEL_PATH):
                cmd += ["-i", KRAKEN_MODEL_PATH]
            cmd += ground_truth

            log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.returncode != 0:
                raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None

            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Sender training produced no model file")

            os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
            # Stage the copy next to the destination and publish it with an
            # atomic rename, so a concurrent OCR request never loads a
            # partially written model file.
            staging_path = f"{resolved_output}.{training_run_id}.tmp"
            try:
                shutil.copy2(best_model, staging_path)
                os.replace(staging_path, resolved_output)
            except OSError:
                if os.path.exists(staging_path):
                    os.remove(staging_path)
                raise
            log.info("Saved sender model to %s", resolved_output)

            # NOTE(review): the registry's key normalization is not visible
            # from here. Invalidate both the caller-supplied and the resolved
            # spelling so a stale entry cannot survive under either key.
            for cache_key in {output_model_path, resolved_output}:
                kraken_engine._sender_registry.invalidate(cache_key)

            return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    result = await asyncio.to_thread(_run_sender_training)
    return result
|
||||
|
||||
|
||||
@app.post("/segtrain")
|
||||
async def segtrain_model(
|
||||
file: UploadFile,
|
||||
|
||||
Reference in New Issue
Block a user