Force CPU-only training (--device cpu), cap OpenMP/BLAS thread pool at 2 (--threads 2), and reduce epochs from 50 to 10 (-N 10). 50 epochs on a laptop OOM-killed the container. 10 epochs is sufficient for incremental fine-tuning runs; more data is added over time and training re-run. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
522 lines
19 KiB
Python
522 lines
19 KiB
Python
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
|
|
|
|
import asyncio
|
|
import glob
|
|
import io
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import tempfile
|
|
import zipfile
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timezone
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
import pypdfium2 as pdfium
|
|
from fastapi import FastAPI, Header, HTTPException, UploadFile
|
|
from fastapi.responses import StreamingResponse
|
|
from PIL import Image
|
|
|
|
from confidence import apply_confidence_markers, get_threshold
|
|
from engines import kraken as kraken_engine
|
|
from engines import surya as surya_engine
|
|
from models import OcrBlock, OcrRequest
|
|
|
|
# Shared secret expected in the X-Training-Token header. When empty, the
# token check used by the training endpoints is skipped.
TRAINING_TOKEN = os.getenv("TRAINING_TOKEN", "")

# Path of the Kraken recognition model that /train backs up and overwrites.
KRAKEN_MODEL_PATH = os.getenv(
    "KRAKEN_MODEL_PATH",
    "/app/models/german_kurrent.mlmodel",
)

# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set to True by the lifespan hook once startup model loading has returned;
# the endpoints answer 503 while it is still False.
_models_ready = False
|
|
|
|
# Hosts the service may download PDFs from (SSRF allow-list), read from the
# comma-separated ALLOWED_PDF_HOSTS environment variable.
# - Entries are trimmed and lower-cased, because urlparse() reports hostnames
#   in lower case and a mixed-case entry would otherwise never match.
# - Blank entries (empty variable, stray or trailing comma) are dropped so a
#   URL without any hostname can never be matched by an empty allow-list item.
ALLOWED_PDF_HOSTS = {
    host.strip().lower()
    for host in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
    if host.strip()
}
|
|
|
|
|
|
def _validate_url(url: str) -> None:
    """Reject PDF URLs that do not point at an allow-listed host (SSRF protection).

    Args:
        url: The URL the service has been asked to download.

    Raises:
        HTTPException: 400 when the URL cannot be parsed, when its scheme is
            not http/https, when it has no hostname, or when the hostname is
            not in ``ALLOWED_PDF_HOSTS``.
    """
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname or ""
    except ValueError:
        # urlparse raises ValueError for malformed netlocs (e.g. a broken
        # IPv6 literal); report that as a client error, not a 500.
        raise HTTPException(status_code=400, detail="PDF URL is malformed") from None

    # Only web URLs are valid download sources; refuse file://, ftp:// and
    # scheme-less input before looking at the host.
    if parsed.scheme not in ("http", "https"):
        raise HTTPException(
            status_code=400,
            detail=f"PDF URL scheme not allowed: {parsed.scheme}",
        )

    # An empty hostname is rejected explicitly so that a blank entry in the
    # allow-list can never authorize a host-less URL.
    if not hostname or hostname not in ALLOWED_PDF_HOSTS:
        raise HTTPException(status_code=400, detail=f"PDF host not allowed: {hostname}")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load Kraken on startup, log on shutdown.

    Only the Kraken model is loaded here; as the startup log message states,
    Surya is loaded lazily on the first OCR request. ``_models_ready`` is set
    once loading has returned, which is the flag the endpoints check before
    serving.
    """
    global _models_ready

    startup_note = "Loading Kraken model at startup (Surya loads lazily on first OCR request)..."
    logger.info(startup_note)
    kraken_engine.load_models()
    _models_ready = True
    logger.info("Startup complete — ready to accept requests")

    # The application serves requests while this generator is suspended.
    yield

    logger.info("Shutting down OCR service")
|
|
|
|
|
|
# The FastAPI application object; startup and shutdown work is delegated to
# the `lifespan` context manager passed in here.
app = FastAPI(title="Familienarchiv OCR Service", lifespan=lifespan)
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report service readiness.

    Responds with 503 until startup model loading has completed. Once ready,
    the payload reports ``surya`` as a hard-coded ``True`` and ``kraken`` as
    whatever ``kraken_engine.is_available()`` returns.
    """
    if _models_ready:
        kraken_ok = kraken_engine.is_available()
        return {"status": "ok", "surya": True, "kraken": kraken_ok}
    raise HTTPException(status_code=503, detail="Models not loaded yet")
|
|
|
|
|
|
@app.post("/ocr", response_model=list[OcrBlock])
async def run_ocr(request: OcrRequest):
    """Run OCR over a whole PDF and return every recognized block at once.

    The PDF is downloaded from ``request.pdfUrl`` and rasterized page by page.
    ``HANDWRITING_KURRENT`` documents go to Kraken; every other script type
    goes to Surya. Recognition runs in a worker thread so the event loop stays
    free for /health.

    Raises:
        HTTPException: 503 while models are still loading; 400 when Kurrent
            script is requested but the Kraken model is unavailable.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    images = await _download_and_convert_pdf(request.pdfUrl)
    script_type = request.scriptType.upper()

    if script_type == "HANDWRITING_KURRENT":
        if not kraken_engine.is_available():
            raise HTTPException(
                status_code=400,
                detail="Kraken model not available — cannot process Kurrent script",
            )
        engine = kraken_engine
    else:
        # TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
        engine = surya_engine

    blocks = await asyncio.to_thread(engine.extract_blocks, images, request.language)

    # Word-level data is folded into the block text as confidence markers and
    # then removed before the blocks are turned into response models.
    threshold = get_threshold(script_type)
    for block in blocks:
        words = block.get("words")
        if words:
            block["text"] = apply_confidence_markers(words, threshold)
        block.pop("words", None)

    return [OcrBlock(**fields) for fields in blocks]
|
|
|
|
|
|
@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
    """Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.

    The stream opens with a ``start`` event (``totalPages``), then carries one
    ``page`` event per processed page or one ``error`` event per failed page,
    and closes with a ``done`` event (``totalBlocks``, ``skippedPages``).

    When ``request.regions`` is provided, runs in guided mode: each region is
    cropped and recognized individually, skipping full-page layout detection.
    The response blocks include the ``annotationId`` from the region, and
    pages without any region produce a ``page`` event with no blocks.

    Raises:
        HTTPException: 503 while models are still loading; 400 when Kurrent
            script is requested but the Kraken model is unavailable.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    images = await _download_and_convert_pdf(request.pdfUrl)
    script_type = request.scriptType.upper()
    threshold = get_threshold(script_type)

    use_kraken = script_type == "HANDWRITING_KURRENT"
    if use_kraken and not kraken_engine.is_available():
        raise HTTPException(
            status_code=400,
            detail="Kraken model not available — cannot process Kurrent script",
        )

    engine = kraken_engine if use_kraken else surya_engine

    def ndjson(event: dict) -> str:
        """Serialize one stream event as a single NDJSON line."""
        return json.dumps(event) + "\n"

    def release_page(page_idx: int) -> None:
        """Drop the list's reference to a finished page.

        ``del image`` on the loop variable alone frees nothing, because the
        ``images`` list still references every rendered page. Clearing the
        slot lets each page image be reclaimed as soon as it is done.
        """
        images[page_idx - 1] = None

    async def generate_guided(regions_by_page: dict[int, list]):
        """Yield NDJSON events for guided mode (user-drawn regions only)."""
        total_pages = len(images)
        yield ndjson({"type": "start", "totalPages": total_pages})

        total_blocks = 0
        skipped_pages = 0

        for page_idx, image in enumerate(images, start=1):
            try:
                blocks = []
                # Pages without regions fall through with an empty block list.
                for region in regions_by_page.get(page_idx, []):
                    text = await asyncio.to_thread(
                        engine.extract_region_text, image,
                        region.x, region.y, region.width, region.height,
                    )
                    blocks.append({
                        "pageNumber": page_idx,
                        "x": region.x,
                        "y": region.y,
                        "width": region.width,
                        "height": region.height,
                        "polygon": None,
                        "text": text,
                        "annotationId": region.annotationId,
                    })

                total_blocks += len(blocks)
                yield ndjson({
                    "type": "page",
                    "pageNumber": page_idx,
                    "blocks": blocks,
                })
            except Exception:
                # One bad page must not abort the whole stream.
                logger.exception("Guided OCR failed on page %d", page_idx)
                skipped_pages += 1
                yield ndjson({
                    "type": "error",
                    "pageNumber": page_idx,
                    "message": f"Guided OCR processing failed on page {page_idx}",
                })
            finally:
                del image
                release_page(page_idx)

        yield ndjson({
            "type": "done",
            "totalBlocks": total_blocks,
            "skippedPages": skipped_pages,
        })

    async def generate():
        """Yield NDJSON events for full-page mode (engine-detected blocks)."""
        total_pages = len(images)
        yield ndjson({"type": "start", "totalPages": total_pages})

        total_blocks = 0
        skipped_pages = 0

        for page_idx, image in enumerate(images, start=1):
            try:
                blocks = await asyncio.to_thread(
                    engine.extract_page_blocks, image, page_idx, request.language
                )

                # Fold word confidences into the text, then drop the raw words.
                for block in blocks:
                    if block.get("words"):
                        block["text"] = apply_confidence_markers(block["words"], threshold)
                    block.pop("words", None)

                total_blocks += len(blocks)
                yield ndjson({
                    "type": "page",
                    "pageNumber": page_idx,
                    "blocks": blocks,
                })
            except Exception:
                # One bad page must not abort the whole stream.
                logger.exception("OCR failed on page %d", page_idx)
                skipped_pages += 1
                yield ndjson({
                    "type": "error",
                    "pageNumber": page_idx,
                    "message": f"OCR processing failed on page {page_idx}",
                })
            finally:
                del image
                release_page(page_idx)

        yield ndjson({
            "type": "done",
            "totalBlocks": total_blocks,
            "skippedPages": skipped_pages,
        })

    if request.regions:
        # Guided mode: recognize only the user-drawn annotation regions
        regions_by_page: dict[int, list] = {}
        for region in request.regions:
            regions_by_page.setdefault(region.pageNumber, []).append(region)
        events = generate_guided(regions_by_page)
    else:
        events = generate()

    return StreamingResponse(
        events,
        media_type="application/x-ndjson",
        # Presumably here so intermediaries neither buffer nor cache the
        # stream — confirm against the deployment's reverse proxy.
        headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
    )
|
|
|
|
|
|
def _check_training_token(x_training_token: str | None) -> None:
|
|
"""Validate training token if TRAINING_TOKEN env var is set."""
|
|
if TRAINING_TOKEN and x_training_token != TRAINING_TOKEN:
|
|
raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
|
|
|
|
|
|
def _validate_zip_entry(name: str, extract_dir: str) -> None:
    """Reject ZIP Slip attacks: path traversal and absolute paths.

    Args:
        name: Entry name as stored in the archive.
        extract_dir: Directory the archive is about to be extracted into.

    Raises:
        HTTPException: 400 when the entry is absolute, starts with ``..``,
            or would resolve to a location outside ``extract_dir``.
    """
    if os.path.isabs(name) or name.startswith(".."):
        raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")

    root = os.path.realpath(extract_dir)
    resolved = os.path.realpath(os.path.join(root, name))

    # Compare whole path components. A plain startswith() on the strings
    # would accept a sibling directory that merely shares the prefix
    # (extracting into /tmp/abc must not allow writes to /tmp/abcd/...).
    try:
        contained = os.path.commonpath([root, resolved]) == root
    except ValueError:
        # commonpath refuses paths it cannot relate (e.g. different drives);
        # such an entry is certainly not inside the extraction directory.
        contained = False

    if not contained:
        raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
|
|
|
|
|
|
def _rotate_backups(model_path: str, keep: int = 3) -> None:
    """Keep only the last ``keep`` timestamped backups of the model.

    Backups are the files matching ``<model_path>.*.bak``. They are ordered by
    file name, so the lexicographically greatest names are treated as the
    newest and retained.

    Args:
        model_path: Path of the live model whose backups are pruned.
        keep: Number of backups to retain; zero or a negative value removes
            every backup.
    """
    # Escape the model path so glob metacharacters in it (*, ?, [) are matched
    # literally; only the timestamp part is a wildcard.
    pattern = glob.escape(model_path) + ".*.bak"
    backups = sorted(glob.glob(pattern))

    # Count the surplus explicitly: the slice backups[:-keep] is empty for
    # keep == 0, which would silently prune nothing.
    surplus = max(len(backups) - max(keep, 0), 0)
    for old in backups[:surplus]:
        try:
            os.remove(old)
        except OSError:
            # Best effort: a backup that cannot be deleted must not fail training.
            logger.warning("Could not remove old backup: %s", old)
|
|
|
|
|
|
def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
|
|
"""Parse checkpoint filenames to find the best validation metric and epoch count.
|
|
|
|
Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
|
|
Returns (best_accuracy, epoch_count).
|
|
"""
|
|
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
|
best_acc: float | None = None
|
|
max_epoch = 0
|
|
for fname in os.listdir(checkpoint_dir):
|
|
m = pattern.match(fname)
|
|
if m:
|
|
epoch = int(m.group(1))
|
|
acc = float(m.group(2))
|
|
max_epoch = max(max_epoch, epoch)
|
|
if best_acc is None or acc > best_acc:
|
|
best_acc = acc
|
|
return best_acc, max_epoch
|
|
|
|
|
|
def _find_best_model(checkpoint_dir: str) -> str | None:
|
|
"""Return the checkpoint file with the highest validation metric, or any model file."""
|
|
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
|
best_acc: float | None = None
|
|
best_path: str | None = None
|
|
for fname in os.listdir(checkpoint_dir):
|
|
m = pattern.match(fname)
|
|
if m:
|
|
acc = float(m.group(2))
|
|
if best_acc is None or acc > best_acc:
|
|
best_acc = acc
|
|
best_path = os.path.join(checkpoint_dir, fname)
|
|
if best_path:
|
|
return best_path
|
|
# Fallback: any .mlmodel file in the directory
|
|
for fname in os.listdir(checkpoint_dir):
|
|
if fname.endswith(".mlmodel"):
|
|
return os.path.join(checkpoint_dir, fname)
|
|
return None
|
|
|
|
|
|
@app.post("/train")
async def train_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the Kurrent recognition model with uploaded training data.

    Accepts a ZIP archive containing .png/.gt.txt training pairs exported
    by the Java backend. Training mutates in-process model state — not safe
    if the service is replicated.

    The archive is extracted into a temporary directory, ``ketos train`` is
    run on it, the current model is backed up (three backups are kept), the
    best checkpoint is copied over ``KRAKEN_MODEL_PATH`` and the Kraken engine
    is reloaded. The blocking work runs in a worker thread.

    Returns:
        A dict with ``loss`` (always ``None``), ``accuracy``, ``cer``
        (``1 - accuracy`` rounded to four places, or ``None``) and ``epochs``.

    Raises:
        HTTPException: 403 for a bad training token; 503 while models are
            still loading; 400 when the upload is not a valid ZIP archive or
            contains an unsafe entry; 422 when it holds no ``*.gt.txt`` files.
        RuntimeError: when ``ketos train`` exits non-zero or produces no
            model file.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting training run %s", training_run_id)

    def _run_training() -> dict:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Extract ZIP with safety checks. A corrupt or non-ZIP upload is a
            # client error, so report it as 400 instead of an unhandled 500.
            try:
                with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                    for entry in zf.namelist():
                        _validate_zip_entry(entry, tmp_dir)
                    zf.extractall(tmp_dir)
            except zipfile.BadZipFile as exc:
                raise HTTPException(
                    status_code=400,
                    detail="Uploaded file is not a valid ZIP archive",
                ) from exc

            log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)

            # NOTE(review): the *.gt.txt paths themselves are handed to ketos
            # as training inputs — confirm this is what `-f path` expects.
            ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")

            log.info("Training on %d ground-truth pairs", len(ground_truth))
            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)

            # CPU-only, two threads and 10 epochs: per the change description,
            # 50 epochs on a laptop OOM-killed the container and 10 are
            # considered sufficient for incremental fine-tuning runs.
            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "train",
                "-f", "path",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
                "-B", "1",
            ]
            # Continue from the current model when one exists.
            if os.path.exists(KRAKEN_MODEL_PATH):
                cmd += ["-i", KRAKEN_MODEL_PATH]
            cmd += ground_truth

            log.info("Running: %s", " ".join(cmd[:6]) + " ...")
            proc = subprocess.run(cmd, capture_output=True, text=True)
            # Only the tail of the output is logged to keep log lines bounded.
            log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
            if proc.stderr:
                log.info("ketos train stderr: %s", proc.stderr[-2000:])
            if proc.returncode != 0:
                raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None
            log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)

            # Find the best model file produced by training
            # NOTE(review): this may be a .ckpt file, which is then copied to
            # the model path unchanged — confirm the engine can load it.
            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Training produced no model file")

            # Backup existing model and replace
            if os.path.exists(KRAKEN_MODEL_PATH):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
                shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
                log.info("Backed up model to %s", backup_path)
                _rotate_backups(KRAKEN_MODEL_PATH, keep=3)

            shutil.copy2(best_model, KRAKEN_MODEL_PATH)
            log.info("Replaced model at %s", KRAKEN_MODEL_PATH)

            # Reload model in-process
            kraken_engine.load_models()
            log.info("Reloaded Kraken model in-process")

            return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    # Training blocks for a long time; keep it off the event loop.
    return await asyncio.to_thread(_run_training)
|
|
|
|
|
|
@app.post("/segtrain")
async def segtrain_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the blla segmentation model with uploaded PAGE XML training data.

    Accepts a ZIP archive containing .png/.xml (PAGE XML) training pairs exported
    by the Java backend. Training mutates in-process model state — not safe
    if the service is replicated.

    The archive is extracted into a temporary directory, ``ketos segtrain`` is
    run on it, the current model is backed up (three backups are kept) and the
    best checkpoint is copied over the path given by ``BLLA_MODEL_PATH``
    (default ``/app/models/blla.mlmodel``). The blocking work runs in a worker
    thread.

    Returns:
        A dict with ``loss`` and ``cer`` (both always ``None``), ``accuracy``
        and ``epochs``.

    Raises:
        HTTPException: 403 for a bad training token; 503 while models are
            still loading; 400 when the upload is not a valid ZIP archive or
            contains an unsafe entry; 422 when it holds no ``*.xml`` files.
        RuntimeError: when ``ketos segtrain`` exits non-zero or produces no
            model file.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting segmentation training run %s", training_run_id)

    blla_model_path = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")

    def _run_segtrain() -> dict:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Extract with ZIP Slip checks. A corrupt or non-ZIP upload is a
            # client error, so report it as 400 instead of an unhandled 500.
            try:
                with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                    for entry in zf.namelist():
                        _validate_zip_entry(entry, tmp_dir)
                    zf.extractall(tmp_dir)
            except zipfile.BadZipFile as exc:
                raise HTTPException(
                    status_code=400,
                    detail="Uploaded file is not a valid ZIP archive",
                ) from exc

            log.info("Extracted %d ZIP entries for segmentation training", len(os.listdir(tmp_dir)))

            xml_files = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not xml_files:
                raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")

            log.info("Training on %d PAGE XML files", len(xml_files))
            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)

            # CPU-only, two threads and 10 epochs: per the change description,
            # 50 epochs on a laptop OOM-killed the container.
            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "segtrain",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
            ]
            # Continue from the current segmentation model when one exists.
            if os.path.exists(blla_model_path):
                cmd += ["-i", blla_model_path]
            cmd += xml_files

            log.info("Running: %s", " ".join(cmd[:5]) + " ...")
            proc = subprocess.run(cmd, capture_output=True, text=True)
            # Only the tail of the output is logged to keep log lines bounded.
            log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
            if proc.stderr:
                log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
            if proc.returncode != 0:
                raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)

            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Segmentation training produced no model file")

            # Back up the current model before overwriting it.
            if os.path.exists(blla_model_path):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{blla_model_path}.{timestamp}.bak"
                shutil.copy2(blla_model_path, backup_path)
                _rotate_backups(blla_model_path, keep=3)

            # NOTE(review): unlike /train, nothing is reloaded in-process
            # after the replacement — confirm how the new blla model is
            # picked up.
            shutil.copy2(best_model, blla_model_path)
            log.info("Replaced blla model at %s", blla_model_path)

            return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}

    # Training blocks for a long time; keep it off the event loop.
    return await asyncio.to_thread(_run_segtrain)
|
|
|
|
|
|
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
    """Fetch a PDF over HTTP and rasterize every page into a PIL image.

    The URL is checked by ``_validate_url`` first. The download uses a
    300-second timeout and does not follow redirects; a non-success response
    raises via ``raise_for_status()``.

    Args:
        url: Presigned URL of the PDF to download.

    Returns:
        One PIL image per page, in page order.
    """
    _validate_url(url)

    download_timeout = httpx.Timeout(300.0)
    async with httpx.AsyncClient(timeout=download_timeout, follow_redirects=False) as client:
        response = await client.get(url)
        response.raise_for_status()

    document = pdfium.PdfDocument(io.BytesIO(response.content))

    # Render at 200 DPI — balances OCR quality vs memory usage
    # (Surya 0.17 models use ~5GB idle; 300 DPI causes OOM on multi-page docs)
    render_scale = 200 / 72

    rendered_pages = []
    page_count = len(document)
    for index in range(page_count):
        page = document[index]
        bitmap = page.render(scale=render_scale)
        rendered_pages.append(bitmap.to_pil())

    return rendered_pages
|