After each training run, the Character Error Rate (CER = 1 - accuracy), loss, accuracy, and epoch count are now stored on the OcrTrainingRun record and shown in the training history table. Also adds the missing POST /api/ocr/segtrain endpoint and the triggerSegTraining service method so the segmentation training card can actually trigger training. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
452 lines
16 KiB
Python
452 lines
16 KiB
Python
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
|
|
|
|
import asyncio
import glob
import hmac
import io
import json
import logging
import os
import shutil
import tempfile
import zipfile
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from urllib.parse import urlparse

import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image

from confidence import apply_confidence_markers, get_threshold
from engines import kraken as kraken_engine
from engines import surya as surya_engine
from models import OcrBlock, OcrRequest
|
|
|
|
# Shared secret expected in the X-Training-Token header of the training
# endpoints; an empty value means the token check is skipped.
TRAINING_TOKEN = os.getenv("TRAINING_TOKEN", "")

# Path of the Kraken recognition model that training reads from and replaces.
KRAKEN_MODEL_PATH = os.getenv("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set to True by the lifespan handler once the startup model load has finished;
# endpoints answer 503 while it is still False.
_models_ready = False
|
|
|
|
# Hosts the service is allowed to download PDFs from (SSRF allow-list), read
# from a comma-separated environment variable.
# Entries are trimmed and lower-cased because urlparse() reports hostnames in
# lower case. Blank entries (e.g. from a trailing comma) are dropped so that an
# empty hostname can never end up on the allow-list.
ALLOWED_PDF_HOSTS = {
    host.strip().lower()
    for host in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
    if host.strip()
}
|
|
|
|
|
|
def _validate_url(url: str) -> None:
    """Validate that the PDF URL points to an allowed host (SSRF protection).

    Args:
        url: The URL the PDF is about to be downloaded from.

    Raises:
        HTTPException: 400 when the URL cannot be parsed, is not http(s),
            has no hostname, or names a host outside ``ALLOWED_PDF_HOSTS``.
    """
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname or ""
    except ValueError as exc:
        # urlparse rejects some malformed netlocs (e.g. broken IPv6 literals)
        raise HTTPException(status_code=400, detail="PDF URL is malformed") from exc

    if parsed.scheme not in ("http", "https"):
        raise HTTPException(
            status_code=400, detail=f"PDF URL scheme not allowed: {parsed.scheme}"
        )

    # A missing hostname is rejected explicitly so that a blank allow-list
    # entry can never make a host-less URL pass.
    if not hostname or hostname not in ALLOWED_PDF_HOSTS:
        raise HTTPException(status_code=400, detail=f"PDF host not allowed: {hostname}")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load lightweight models at startup. Surya loads lazily on first request.

    Code before the ``yield`` runs at application startup, code after it at
    shutdown. ``_models_ready`` is only set after ``kraken_engine.load_models()``
    has returned; the endpoints answer 503 until then.
    """
    global _models_ready

    logger.info("Loading Kraken model at startup (Surya loads lazily on first OCR request)...")
    kraken_engine.load_models()
    # Flip the readiness flag only after the load call has returned
    _models_ready = True
    logger.info("Startup complete — ready to accept requests")

    yield

    logger.info("Shutting down OCR service")
|
|
|
|
|
|
# ASGI application object; `lifespan` runs the startup model load and the shutdown log.
app = FastAPI(title="Familienarchiv OCR Service", lifespan=lifespan)
|
|
|
|
|
|
@app.get("/health")
def health():
    """Readiness probe.

    Answers 503 until the startup model load has completed, then returns a
    status payload that reports whether the Kraken engine is available.
    """
    if _models_ready:
        return {"status": "ok", "surya": True, "kraken": kraken_engine.is_available()}
    raise HTTPException(status_code=503, detail="Models not loaded yet")
|
|
|
|
|
|
@app.post("/ocr", response_model=list[OcrBlock])
async def run_ocr(request: OcrRequest):
    """OCR a whole PDF and return all recognized blocks in one response.

    The PDF is fetched from ``request.pdfUrl`` and rendered to page images,
    then handed to the engine selected by ``request.scriptType``. The engine
    call runs in a worker thread so the event loop stays free for /health.

    Raises:
        HTTPException: 503 while models are still loading; 400 when Kurrent
            script is requested but the Kraken model is not available.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    images = await _download_and_convert_pdf(request.pdfUrl)
    script_type = request.scriptType.upper()

    # TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya; only Kurrent uses Kraken
    engine = surya_engine
    if script_type == "HANDWRITING_KURRENT":
        if not kraken_engine.is_available():
            raise HTTPException(
                status_code=400,
                detail="Kraken model not available — cannot process Kurrent script",
            )
        engine = kraken_engine

    blocks = await asyncio.to_thread(engine.extract_blocks, images, request.language)

    threshold = get_threshold(script_type)
    for block in blocks:
        # "words" is an intermediate field: it feeds the confidence markers
        # and is always removed before the block is returned.
        words = block.pop("words", None)
        if words:
            block["text"] = apply_confidence_markers(words, threshold)

    return [OcrBlock(**b) for b in blocks]
|
|
|
|
|
|
def _ndjson_line(event: dict) -> str:
    """Serialize one stream event as a single NDJSON line."""
    return json.dumps(event) + "\n"


def _ndjson_response(lines) -> StreamingResponse:
    """Wrap an async generator of NDJSON lines in a streaming response."""
    return StreamingResponse(
        lines,
        media_type="application/x-ndjson",
        # Ask proxies and clients not to buffer or cache, so each page event
        # is delivered as soon as it is produced.
        headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
    )


@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
    """Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.

    The stream consists of a ``start`` event (with ``totalPages``), then one
    ``page`` or ``error`` event per page, then a final ``done`` event with
    ``totalBlocks`` and ``skippedPages``. A failure on one page is logged and
    reported as an ``error`` event; the remaining pages are still processed.

    When request.regions is provided, runs in guided mode: each region is cropped
    and recognized individually, skipping full-page layout detection. The response
    blocks include the annotationId from the region.

    Raises:
        HTTPException: 503 while models are still loading; 400 when Kurrent
            script is requested but the Kraken model is not available.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    images = await _download_and_convert_pdf(request.pdfUrl)
    script_type = request.scriptType.upper()
    threshold = get_threshold(script_type)

    use_kraken = script_type == "HANDWRITING_KURRENT"
    if use_kraken and not kraken_engine.is_available():
        raise HTTPException(
            status_code=400,
            detail="Kraken model not available — cannot process Kurrent script",
        )

    engine = kraken_engine if use_kraken else surya_engine
    total_pages = len(images)

    def take_page(page_idx: int):
        """Return page `page_idx` (1-based) and clear its slot in `images`.

        Clearing the slot means the list no longer keeps the page alive, so
        deleting the caller's local reference actually releases it instead of
        every rendered page staying in memory for the whole stream.
        """
        page_image = images[page_idx - 1]
        images[page_idx - 1] = None
        return page_image

    if request.regions:
        # Guided mode: recognize only the user-drawn annotation regions
        regions_by_page: dict[int, list] = {}
        for region in request.regions:
            regions_by_page.setdefault(region.pageNumber, []).append(region)

        async def generate_guided():
            yield _ndjson_line({"type": "start", "totalPages": total_pages})

            total_blocks = 0
            skipped_pages = 0

            for page_idx in range(1, total_pages + 1):
                page_regions = regions_by_page.get(page_idx, [])
                if not page_regions:
                    # Nothing to recognize on this page: release it and report it as empty
                    images[page_idx - 1] = None
                    yield _ndjson_line({
                        "type": "page",
                        "pageNumber": page_idx,
                        "blocks": [],
                    })
                    continue

                image = take_page(page_idx)
                try:
                    blocks = []
                    for region in page_regions:
                        text = await asyncio.to_thread(
                            engine.extract_region_text, image,
                            region.x, region.y, region.width, region.height,
                        )
                        blocks.append({
                            "pageNumber": page_idx,
                            "x": region.x,
                            "y": region.y,
                            "width": region.width,
                            "height": region.height,
                            "polygon": None,
                            "text": text,
                            "annotationId": region.annotationId,
                        })

                    total_blocks += len(blocks)
                    yield _ndjson_line({
                        "type": "page",
                        "pageNumber": page_idx,
                        "blocks": blocks,
                    })

                except Exception:
                    logger.exception("Guided OCR failed on page %d", page_idx)
                    skipped_pages += 1
                    yield _ndjson_line({
                        "type": "error",
                        "pageNumber": page_idx,
                        "message": f"Guided OCR processing failed on page {page_idx}",
                    })

                finally:
                    del image

            yield _ndjson_line({
                "type": "done",
                "totalBlocks": total_blocks,
                "skippedPages": skipped_pages,
            })

        return _ndjson_response(generate_guided())

    async def generate():
        yield _ndjson_line({"type": "start", "totalPages": total_pages})

        total_blocks = 0
        skipped_pages = 0

        for page_idx in range(1, total_pages + 1):
            image = take_page(page_idx)
            try:
                blocks = await asyncio.to_thread(
                    engine.extract_page_blocks, image, page_idx, request.language
                )

                for block in blocks:
                    if block.get("words"):
                        block["text"] = apply_confidence_markers(block["words"], threshold)
                    # "words" is an intermediate field and is never sent to the client
                    block.pop("words", None)

                total_blocks += len(blocks)
                yield _ndjson_line({
                    "type": "page",
                    "pageNumber": page_idx,
                    "blocks": blocks,
                })

            except Exception:
                logger.exception("OCR failed on page %d", page_idx)
                skipped_pages += 1
                yield _ndjson_line({
                    "type": "error",
                    "pageNumber": page_idx,
                    "message": f"OCR processing failed on page {page_idx}",
                })

            finally:
                del image

        yield _ndjson_line({
            "type": "done",
            "totalBlocks": total_blocks,
            "skippedPages": skipped_pages,
        })

    return _ndjson_response(generate())
|
|
|
|
|
|
def _check_training_token(x_training_token: str | None) -> None:
    """Validate training token if TRAINING_TOKEN env var is set.

    When ``TRAINING_TOKEN`` is empty the check is skipped entirely.

    Args:
        x_training_token: Value of the X-Training-Token request header, or
            None when the header was not sent.

    Raises:
        HTTPException: 403 when a token is configured and the supplied value
            is missing or different.
    """
    if not TRAINING_TOKEN:
        return

    # Compare as bytes (compare_digest only accepts ASCII-only str) and in
    # constant time, so response timing does not reveal how much of the
    # supplied token matched.
    supplied = (x_training_token or "").encode("utf-8")
    if not hmac.compare_digest(supplied, TRAINING_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
|
|
|
|
|
|
def _validate_zip_entry(name: str, extract_dir: str) -> None:
    """Reject ZIP Slip attacks: path traversal and absolute paths.

    Args:
        name: Entry name as stored in the archive.
        extract_dir: Directory the archive is going to be extracted into.

    Raises:
        HTTPException: 400 when the entry is absolute, starts with "..", or
            would resolve to a location outside ``extract_dir``.
    """
    if os.path.isabs(name) or name.startswith(".."):
        raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")

    root = os.path.realpath(extract_dir)
    resolved = os.path.realpath(os.path.join(root, name))
    try:
        # Compare whole path components: a plain string-prefix test would
        # accept a sibling directory such as "<root>-other/file".
        inside = os.path.commonpath([root, resolved]) == root
    except ValueError:
        # commonpath refuses paths that cannot share a prefix (e.g. different drives)
        inside = False
    if not inside:
        raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
|
|
|
|
|
|
def _rotate_backups(model_path: str, keep: int = 3) -> None:
    """Keep only the last `keep` timestamped backups of the model.

    Backups are files named ``<model_path>.<timestamp>.bak``. They are ordered
    by file name, which matches chronological order for the zero-padded
    ``%Y%m%dT%H%M%SZ`` timestamps used when the backups are created.

    Args:
        model_path: Path of the live model file the backups belong to.
        keep: Number of newest backups to retain; zero or less removes all.
    """
    # Escape the model path so glob metacharacters in it are matched literally
    pattern = glob.escape(model_path) + ".*.bak"
    backups = sorted(glob.glob(pattern))

    # Count explicitly instead of slicing with [:-keep], which is empty for keep=0
    excess = max(len(backups) - max(keep, 0), 0)
    for old in backups[:excess]:
        try:
            os.remove(old)
        except OSError:
            # Best effort: a backup that cannot be deleted must not fail the training run
            logger.warning("Could not remove old backup: %s", old)
|
|
|
|
|
|
@app.post("/train")
async def train_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the Kurrent recognition model with uploaded training data.

    Accepts a ZIP archive containing .png/.gt.txt training pairs exported
    by the Java backend. Training mutates in-process model state — not safe
    if the service is replicated.

    On success the existing model file is backed up (the three newest backups
    are kept), replaced by the fine-tuned model, and reloaded in-process.

    Returns:
        A dict with the keys ``loss``, ``accuracy``, ``cer`` and ``epochs``.

    Raises:
        HTTPException: 403 for a bad training token; 503 while models are
            still loading; 400 for an invalid ZIP or an unsafe ZIP entry;
            422 when the archive contains no ``*.gt.txt`` files.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting training run %s", training_run_id)

    def _run_training() -> dict:
        """Blocking part of the run; executed in a worker thread."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Extract ZIP with safety checks; every entry is validated before anything is written
            try:
                with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                    for entry in zf.namelist():
                        _validate_zip_entry(entry, tmp_dir)
                    zf.extractall(tmp_dir)
            except zipfile.BadZipFile as exc:
                # A corrupt or non-ZIP upload is a client error, not a server fault
                raise HTTPException(
                    status_code=400, detail="Uploaded file is not a valid ZIP archive"
                ) from exc

            log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)

            # Run ketos train (transfer learning from existing model)
            from kraken import ketos

            # Only the top level of the archive is searched for ground truth
            ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")

            log.info("Training on %d ground-truth pairs", len(ground_truth))
            output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")

            # NOTE(review): confirm that the installed Kraken version supports
            # calling ketos.train with these keyword arguments.
            training_result = ketos.train(
                ground_truth=ground_truth,
                load=KRAKEN_MODEL_PATH,
                output=output_model_path,
                format_type="path",
            )

            # Metrics are read defensively: a missing attribute yields None (or 0 epochs)
            epochs = getattr(training_result, "epochs", None) or 0
            loss = getattr(training_result, "best_loss", None)
            accuracy = getattr(training_result, "best_accuracy", None)
            # CER = 1 - character accuracy (Kraken reports character-level accuracy)
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None

            log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)

            # Backup existing model and replace
            if os.path.exists(KRAKEN_MODEL_PATH):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
                shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
                log.info("Backed up model to %s", backup_path)
                _rotate_backups(KRAKEN_MODEL_PATH, keep=3)

            shutil.move(output_model_path, KRAKEN_MODEL_PATH)
            log.info("Replaced model at %s", KRAKEN_MODEL_PATH)

            # Reload model in-process
            kraken_engine.load_models()
            log.info("Reloaded Kraken model in-process")

            return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    # Training is blocking work; run it off the event loop
    return await asyncio.to_thread(_run_training)
|
|
|
|
|
|
@app.post("/segtrain")
async def segtrain_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the blla segmentation model with uploaded PAGE XML training data.

    Accepts a ZIP archive containing .png/.xml (PAGE XML) training pairs exported
    by the Java backend. The model file at ``BLLA_MODEL_PATH`` (default
    ``/app/models/blla.mlmodel``) is backed up and replaced on disk — not safe
    if the service is replicated. No in-process model reload is performed here.

    Returns:
        A dict with the keys ``loss``, ``accuracy``, ``cer`` (always None for
        segmentation training) and ``epochs``.

    Raises:
        HTTPException: 403 for a bad training token; 503 while models are
            still loading; 400 for an invalid ZIP or an unsafe ZIP entry;
            422 when the archive contains no ``*.xml`` files.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting segmentation training run %s", training_run_id)

    blla_model_path = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")

    def _run_segtrain() -> dict:
        """Blocking part of the run; executed in a worker thread."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Every entry is validated before anything is written to disk
            try:
                with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                    for entry in zf.namelist():
                        _validate_zip_entry(entry, tmp_dir)
                    zf.extractall(tmp_dir)
            except zipfile.BadZipFile as exc:
                # A corrupt or non-ZIP upload is a client error, not a server fault
                raise HTTPException(
                    status_code=400, detail="Uploaded file is not a valid ZIP archive"
                ) from exc

            log.info("Extracted %d ZIP entries for segmentation training", len(os.listdir(tmp_dir)))

            # Only the top level of the archive is searched for PAGE XML
            xml_files = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not xml_files:
                raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")

            log.info("Training on %d PAGE XML files", len(xml_files))
            output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")

            from kraken import ketos

            # NOTE(review): confirm that the installed Kraken version supports
            # calling ketos.segtrain with these keyword arguments.
            training_result = ketos.segtrain(
                ground_truth=xml_files,
                # Fine-tune from the existing model when there is one, otherwise pass no base model
                load=blla_model_path if os.path.exists(blla_model_path) else None,
                output=output_model_path,
                format_type="path",
            )

            # Metrics are read defensively: a missing attribute yields None (or 0 epochs)
            epochs = getattr(training_result, "epochs", None) or 0
            loss = getattr(training_result, "best_loss", None)
            accuracy = getattr(training_result, "best_accuracy", None)

            log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)

            # Backup existing model and replace
            if os.path.exists(blla_model_path):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{blla_model_path}.{timestamp}.bak"
                shutil.copy2(blla_model_path, backup_path)
                log.info("Backed up model to %s", backup_path)
                _rotate_backups(blla_model_path, keep=3)

            shutil.move(output_model_path, blla_model_path)
            log.info("Replaced blla model at %s", blla_model_path)
            # NOTE(review): unlike /train, the new model is not reloaded in-process
            # here — confirm whether the segmentation model is read from disk per request.

            return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}

    # Training is blocking work; run it off the event loop
    return await asyncio.to_thread(_run_segtrain)
|
|
|
|
|
|
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
    """Fetch the PDF behind a presigned URL and render every page to a PIL image.

    The URL is checked by ``_validate_url`` first. Redirects are not followed,
    so only the validated URL itself is ever requested. A non-success HTTP
    status raises via ``raise_for_status()``.

    Returns:
        One PIL image per PDF page, in page order.
    """
    _validate_url(url)

    http_timeout = httpx.Timeout(300.0)
    async with httpx.AsyncClient(timeout=http_timeout, follow_redirects=False) as client:
        response = await client.get(url)
        response.raise_for_status()

    pdf = pdfium.PdfDocument(io.BytesIO(response.content))

    # Render at 200 DPI — balances OCR quality vs memory usage
    # (Surya 0.17 models use ~5GB idle; 300 DPI causes OOM on multi-page docs)
    render_scale = 200 / 72
    return [
        pdf[page_idx].render(scale=render_scale).to_pil()
        for page_idx in range(len(pdf))
    ]
|