When a document has manually drawn annotation boxes, the user can now enable "Nur annotierte Bereiche" in the OCR trigger panel. The engine skips layout detection entirely and runs recognition only within the pre-drawn bounding boxes, preserving manual transcription blocks. - Python: adds OcrRegion model, extends OcrRequest/OcrBlock; guided branch in /ocr/stream groups regions by page and crops each region - Engines: adds extract_region_text() to both Kraken and Surya - Java: adds OcrBlockResult.annotationId, OcrClient.OcrRegion, TriggerOcrDTO.useExistingAnnotations; OcrAsyncRunner dispatches to upsertGuidedBlock when annotationId is present; OcrService threads the flag through to runSingleDocument - TranscriptionService: adds upsertGuidedBlock (creates, updates OCR, or preserves MANUAL blocks) - Frontend: guided OCR toggle in OcrTrigger shown when blocks exist; skips destructive-replace confirmation in guided mode Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
450 lines
16 KiB
Python
450 lines
16 KiB
Python
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
|
|
|
|
import asyncio
|
|
import glob
|
|
import io
|
|
import json
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import zipfile
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime, timezone
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
import pypdfium2 as pdfium
|
|
from fastapi import FastAPI, Header, HTTPException, UploadFile
|
|
from fastapi.responses import StreamingResponse
|
|
from PIL import Image
|
|
|
|
from confidence import apply_confidence_markers, get_threshold
|
|
from engines import kraken as kraken_engine
|
|
from engines import surya as surya_engine
|
|
from models import OcrBlock, OcrRequest
|
|
|
|
# Service configuration, read from the environment once at import time.

# Shared secret expected in the X-Training-Token header of the training
# endpoints; when empty, the token check is skipped.
TRAINING_TOKEN = os.getenv("TRAINING_TOKEN", "")

# Filesystem location of the Kraken recognition model.
KRAKEN_MODEL_PATH = os.getenv("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Readiness flag: set to True by the lifespan handler after startup loading.
_models_ready = False
|
|
|
|
# Hosts the service is allowed to download PDFs from (SSRF allow-list), given
# as a comma-separated list in the ALLOWED_PDF_HOSTS environment variable.
# Entries are trimmed and lower-cased because urlparse() reports hostnames in
# lower case. Blank entries (e.g. from a trailing comma) are dropped so that a
# URL without any host can never match the allow-list.
ALLOWED_PDF_HOSTS = {
    host.strip().lower()
    for host in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
    if host.strip()
}
|
|
|
|
|
|
def _validate_url(url: str) -> None:
    """Validate that the PDF URL is safe to fetch (SSRF protection).

    Args:
        url: URL the caller asked the service to download.

    Raises:
        HTTPException: 400 if the URL cannot be parsed, uses a scheme other
            than http/https, or names a host not in ALLOWED_PDF_HOSTS.
    """
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse() rejects e.g. malformed bracketed IPv6 hosts; report that
        # as a client error instead of letting it surface as a 500.
        raise HTTPException(status_code=400, detail="PDF URL is malformed") from None

    # urlparse() lower-cases the scheme, so a plain membership test suffices.
    if parsed.scheme not in ("http", "https"):
        raise HTTPException(
            status_code=400,
            detail=f"PDF URL scheme not allowed: {parsed.scheme}",
        )

    hostname = parsed.hostname or ""
    if hostname not in ALLOWED_PDF_HOSTS:
        raise HTTPException(status_code=400, detail=f"PDF host not allowed: {hostname}")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Startup loads the Kraken model eagerly and then marks the service as
    ready; per the startup log message, Surya is loaded lazily on the first
    OCR request. Shutdown only writes a log line.
    """
    global _models_ready

    logger.info("Loading Kraken model at startup (Surya loads lazily on first OCR request)...")
    kraken_engine.load_models()

    # From here on the readiness-gated endpoints stop answering with 503.
    _models_ready = True
    logger.info("Startup complete — ready to accept requests")

    yield  # the application serves requests while suspended here

    logger.info("Shutting down OCR service")


app = FastAPI(title="Familienarchiv OCR Service", lifespan=lifespan)
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report service readiness.

    Raises:
        HTTPException: 503 while startup model loading has not finished.

    Returns:
        A status payload that also reports whether the Kraken engine is
        available.
    """
    if _models_ready:
        return {"status": "ok", "surya": True, "kraken": kraken_engine.is_available()}
    raise HTTPException(status_code=503, detail="Models not loaded yet")
|
|
|
|
|
|
@app.post("/ocr", response_model=list[OcrBlock])
async def run_ocr(request: OcrRequest):
    """Run OCR over a whole PDF and return all recognized blocks at once.

    The PDF is fetched from ``request.pdfUrl`` and rendered to page images,
    then handed to the engine selected by ``request.scriptType``. Extraction
    runs via ``asyncio.to_thread`` so the event loop stays free for /health
    while the engine works.

    Raises:
        HTTPException: 503 while models are not loaded; 400 when Kurrent
            script is requested but the Kraken model is unavailable.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    page_images = await _download_and_convert_pdf(request.pdfUrl)
    script_type = request.scriptType.upper()

    # Kurrent handwriting is handled by Kraken; TYPEWRITER, HANDWRITING_LATIN
    # and UNKNOWN all fall through to Surya.
    if script_type != "HANDWRITING_KURRENT":
        extract = surya_engine.extract_blocks
    elif kraken_engine.is_available():
        extract = kraken_engine.extract_blocks
    else:
        raise HTTPException(
            status_code=400,
            detail="Kraken model not available — cannot process Kurrent script",
        )

    raw_blocks = await asyncio.to_thread(extract, page_images, request.language)

    # Fold per-word confidences into the block text, then drop the word list
    # before the blocks are turned into response models.
    threshold = get_threshold(script_type)
    for raw in raw_blocks:
        if raw.get("words"):
            raw["text"] = apply_confidence_markers(raw["words"], threshold)
        raw.pop("words", None)

    return [OcrBlock(**raw) for raw in raw_blocks]
|
|
|
|
|
|
@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
    """Stream OCR results for a PDF as NDJSON, one event per line.

    The stream opens with a ``start`` event carrying the page count, then
    emits one ``page`` event per page (or an ``error`` event if processing
    that page raised), and ends with a ``done`` event holding the totals of
    emitted blocks and skipped pages.

    When ``request.regions`` is non-empty the endpoint runs in guided mode:
    full-page layout detection is skipped and only the supplied regions are
    recognized. Each resulting block echoes the ``annotationId`` of its
    region; pages without regions produce a ``page`` event with no blocks.

    Raises:
        HTTPException: 503 while models are not loaded; 400 when Kurrent
            script is requested but the Kraken model is unavailable.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    images = await _download_and_convert_pdf(request.pdfUrl)
    script_type = request.scriptType.upper()
    threshold = get_threshold(script_type)

    is_kurrent = script_type == "HANDWRITING_KURRENT"
    if is_kurrent and not kraken_engine.is_available():
        raise HTTPException(
            status_code=400,
            detail="Kraken model not available — cannot process Kurrent script",
        )
    engine = kraken_engine if is_kurrent else surya_engine

    def as_line(event: dict) -> str:
        """Serialize one event as a single NDJSON line."""
        return json.dumps(event) + "\n"

    def start_line() -> str:
        return as_line({"type": "start", "totalPages": len(images)})

    def page_line(page_number: int, blocks: list) -> str:
        return as_line({"type": "page", "pageNumber": page_number, "blocks": blocks})

    def error_line(page_number: int, message: str) -> str:
        return as_line({"type": "error", "pageNumber": page_number, "message": message})

    def done_line(block_count: int, failed_pages: int) -> str:
        return as_line({
            "type": "done",
            "totalBlocks": block_count,
            "skippedPages": failed_pages,
        })

    # Guided mode input: the user-drawn regions, bucketed by page number.
    regions_by_page: dict[int, list] = {}
    for region in request.regions or []:
        regions_by_page.setdefault(region.pageNumber, []).append(region)

    async def generate_guided():
        """Recognize only the requested regions, page by page."""
        yield start_line()

        block_count = 0
        failed_pages = 0

        for page_no, image in enumerate(images, start=1):
            page_regions = regions_by_page.get(page_no, [])
            if not page_regions:
                # Nothing to recognize here — still report the page.
                yield page_line(page_no, [])
                del image
                continue

            try:
                blocks = []
                for region in page_regions:
                    text = await asyncio.to_thread(
                        engine.extract_region_text,
                        image,
                        region.x,
                        region.y,
                        region.width,
                        region.height,
                    )
                    blocks.append({
                        "pageNumber": page_no,
                        "x": region.x,
                        "y": region.y,
                        "width": region.width,
                        "height": region.height,
                        "polygon": None,
                        "text": text,
                        "annotationId": region.annotationId,
                    })

                block_count += len(blocks)
                yield page_line(page_no, blocks)
            except Exception:
                # One failing region fails the whole page; later pages go on.
                logger.exception("Guided OCR failed on page %d", page_no)
                failed_pages += 1
                yield error_line(page_no, f"Guided OCR processing failed on page {page_no}")
            finally:
                # Drop this loop's reference to the page image.
                del image

        yield done_line(block_count, failed_pages)

    async def generate():
        """Run full-page extraction on every page."""
        yield start_line()

        block_count = 0
        failed_pages = 0

        for page_no, image in enumerate(images, start=1):
            try:
                blocks = await asyncio.to_thread(
                    engine.extract_page_blocks, image, page_no, request.language
                )

                # Fold word confidences into the text and drop the word list
                # before the blocks are serialized.
                for entry in blocks:
                    if entry.get("words"):
                        entry["text"] = apply_confidence_markers(entry["words"], threshold)
                    entry.pop("words", None)

                block_count += len(blocks)
                yield page_line(page_no, blocks)
            except Exception:
                logger.exception("OCR failed on page %d", page_no)
                failed_pages += 1
                yield error_line(page_no, f"OCR processing failed on page {page_no}")
            finally:
                # Drop this loop's reference to the page image.
                del image

        yield done_line(block_count, failed_pages)

    event_stream = generate_guided() if request.regions else generate()
    return StreamingResponse(
        event_stream,
        media_type="application/x-ndjson",
        headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
    )
|
|
|
|
|
|
def _check_training_token(x_training_token: str | None) -> None:
    """Validate the X-Training-Token header if TRAINING_TOKEN is configured.

    When TRAINING_TOKEN is empty the check is disabled and every request is
    accepted.

    Args:
        x_training_token: Header value sent by the client, or None if absent.

    Raises:
        HTTPException: 403 if a token is configured and the supplied value is
            missing or different.
    """
    import hmac  # only this check needs it

    if not TRAINING_TOKEN:
        return

    # Compare in constant time so response timing does not reveal how much of
    # a guessed token was correct. compare_digest() only accepts non-ASCII
    # input as bytes, hence the encoding step; "surrogateescape" keeps the
    # encoding from failing on undecodable bytes in the environment value.
    supplied = (x_training_token or "").encode("utf-8", "surrogateescape")
    expected = TRAINING_TOKEN.encode("utf-8", "surrogateescape")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
|
|
|
|
|
|
def _validate_zip_entry(name: str, extract_dir: str) -> None:
    """Reject ZIP Slip attacks: path traversal and absolute paths.

    Args:
        name: Entry name as stored in the archive.
        extract_dir: Directory the archive is going to be extracted into.

    Raises:
        HTTPException: 400 if the entry is absolute, starts with "..", or
            would resolve to a location outside ``extract_dir``.
    """
    if os.path.isabs(name) or name.startswith(".."):
        raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")

    base = os.path.realpath(extract_dir)
    resolved = os.path.realpath(os.path.join(extract_dir, name))

    # Compare whole path components. A plain string-prefix test would accept
    # a sibling directory that merely shares the prefix, e.g. "/tmp/abc_evil"
    # for base "/tmp/abc" (reachable through "x/../../abc_evil/file").
    try:
        contained = os.path.commonpath([base, resolved]) == base
    except ValueError:
        # commonpath() refuses paths that cannot share a root (for example
        # different drives on Windows), so the entry cannot be inside base.
        contained = False

    if not contained:
        raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
|
|
|
|
|
|
def _rotate_backups(model_path: str, keep: int = 3) -> None:
    """Keep only the last `keep` timestamped backups of the model.

    Backups are the files matching ``<model_path>.*.bak``. They are ordered
    by name and the entries sorting last are retained. Files that cannot be
    removed are logged and skipped.

    Args:
        model_path: Path of the live model whose backups are rotated.
        keep: Number of backups to retain; zero or a negative value removes
            all of them.
    """
    # The model path is a literal prefix, not a pattern: escape it so glob
    # metacharacters in the path ("[", "*", "?") are matched verbatim.
    pattern = glob.escape(model_path) + ".*.bak"
    backups = sorted(glob.glob(pattern))

    # Slice by an explicit count: `backups[:-keep]` is empty for keep == 0
    # and would silently retain every backup.
    surplus = max(len(backups) - max(keep, 0), 0)
    for old in backups[:surplus]:
        try:
            os.remove(old)
        except OSError:
            logger.warning("Could not remove old backup: %s", old)
|
|
|
|
|
|
@app.post("/train")
async def train_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the Kurrent recognition model with uploaded training data.

    Accepts a ZIP archive containing .png/.gt.txt training pairs exported
    by the Java backend. Training mutates in-process model state — not safe
    if the service is replicated.

    The archive is unpacked into a temporary directory, the model at
    KRAKEN_MODEL_PATH is fine-tuned, the previous model file is backed up
    (the three newest backups are kept), the new model replaces it, and the
    Kraken engine is reloaded. The work runs via ``asyncio.to_thread``.

    Returns:
        A dict with the ``loss``, ``accuracy`` and ``epochs`` read from the
        training result.

    Raises:
        HTTPException: 403 for a bad training token; 503 while models are not
            loaded; 400 for an invalid or unsafe ZIP archive; 422 if the
            archive contains no ground-truth files.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting training run %s", training_run_id)

    def _run_training() -> dict:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Extract ZIP with safety checks. A corrupt or non-ZIP upload is
            # a client error, not an internal one.
            try:
                with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                    for entry in zf.namelist():
                        _validate_zip_entry(entry, tmp_dir)
                    zf.extractall(tmp_dir)
            except zipfile.BadZipFile as exc:
                raise HTTPException(
                    status_code=400,
                    detail="Uploaded file is not a valid ZIP archive",
                ) from exc

            log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)

            # Run ketos train (transfer learning from existing model)
            from kraken import ketos

            # Only ground-truth files at the top level of the archive count.
            ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")

            log.info("Training on %d ground-truth pairs", len(ground_truth))
            output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")

            result = ketos.train(
                ground_truth=ground_truth,
                load=KRAKEN_MODEL_PATH,
                output=output_model_path,
                format_type="path",
            )

            # The result object's attributes are read defensively; missing
            # metrics are reported as 0 epochs / None.
            epochs = getattr(result, "epochs", None) or 0
            loss = getattr(result, "best_loss", None)
            accuracy = getattr(result, "best_accuracy", None)

            log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)

            # Backup existing model and replace
            if os.path.exists(KRAKEN_MODEL_PATH):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
                shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
                log.info("Backed up model to %s", backup_path)
                _rotate_backups(KRAKEN_MODEL_PATH, keep=3)

            shutil.move(output_model_path, KRAKEN_MODEL_PATH)
            log.info("Replaced model at %s", KRAKEN_MODEL_PATH)

            # Reload model in-process
            kraken_engine.load_models()
            log.info("Reloaded Kraken model in-process")

            return {"loss": loss, "accuracy": accuracy, "epochs": epochs}

    result = await asyncio.to_thread(_run_training)
    return result
|
|
|
|
|
|
@app.post("/segtrain")
async def segtrain_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the blla segmentation model with uploaded PAGE XML training data.

    Accepts a ZIP archive containing .png/.xml (PAGE XML) training pairs exported
    by the Java backend. Training mutates in-process model state — not safe
    if the service is replicated.

    The archive is unpacked into a temporary directory and the model at
    BLLA_MODEL_PATH (default ``/app/models/blla.mlmodel``) is fine-tuned if
    that file exists; otherwise no model is passed to load. An existing model
    file is backed up first (the three newest backups are kept) and then
    replaced. The work runs via ``asyncio.to_thread``.

    Returns:
        A dict with the ``loss``, ``accuracy`` and ``epochs`` read from the
        training result.

    Raises:
        HTTPException: 403 for a bad training token; 503 while models are not
            loaded; 400 for an invalid or unsafe ZIP archive; 422 if the
            archive contains no PAGE XML files.
    """
    _check_training_token(x_training_token)

    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    zip_bytes = await file.read()

    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting segmentation training run %s", training_run_id)

    blla_model_path = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")

    def _run_segtrain() -> dict:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Extract with ZIP Slip checks. A corrupt or non-ZIP upload is a
            # client error, not an internal one.
            try:
                with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                    for entry in zf.namelist():
                        _validate_zip_entry(entry, tmp_dir)
                    zf.extractall(tmp_dir)
            except zipfile.BadZipFile as exc:
                raise HTTPException(
                    status_code=400,
                    detail="Uploaded file is not a valid ZIP archive",
                ) from exc

            log.info("Extracted %d ZIP entries for segmentation training", len(os.listdir(tmp_dir)))

            # Only XML files at the top level of the archive count.
            xml_files = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not xml_files:
                raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")

            log.info("Training on %d PAGE XML files", len(xml_files))
            output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")

            from kraken import ketos

            result = ketos.segtrain(
                ground_truth=xml_files,
                # Continue from the current model when there is one.
                load=blla_model_path if os.path.exists(blla_model_path) else None,
                output=output_model_path,
                format_type="path",
            )

            # The result object's attributes are read defensively; missing
            # metrics are reported as 0 epochs / None.
            epochs = getattr(result, "epochs", None) or 0
            loss = getattr(result, "best_loss", None)
            accuracy = getattr(result, "best_accuracy", None)

            log.info("Segmentation training complete — epochs=%s loss=%s", epochs, loss)

            # Back up the current model (if any) before it is overwritten.
            if os.path.exists(blla_model_path):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{blla_model_path}.{timestamp}.bak"
                shutil.copy2(blla_model_path, backup_path)
                _rotate_backups(blla_model_path, keep=3)

            shutil.move(output_model_path, blla_model_path)
            log.info("Replaced blla model at %s", blla_model_path)

            # NOTE(review): unlike /train, the engine is not reloaded here —
            # confirm how the new segmentation model gets picked up.
            return {"loss": loss, "accuracy": accuracy, "epochs": epochs}

    result = await asyncio.to_thread(_run_segtrain)
    return result
|
|
|
|
|
|
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
    """Fetch a PDF from a presigned URL and render every page to a PIL image.

    The URL is checked with ``_validate_url`` first. The download does not
    follow redirects and uses a 300-second timeout.

    Args:
        url: Location of the PDF to download.

    Returns:
        One PIL image per PDF page, in page order.
    """
    _validate_url(url)

    download_timeout = httpx.Timeout(300.0)
    async with httpx.AsyncClient(timeout=download_timeout, follow_redirects=False) as client:
        response = await client.get(url)
        response.raise_for_status()

    document = pdfium.PdfDocument(io.BytesIO(response.content))

    # Render at 200 DPI — balances OCR quality vs memory usage
    # (Surya 0.17 models use ~5GB idle; 300 DPI causes OOM on multi-page docs).
    # The render scale is expressed relative to 72 units per inch.
    render_scale = 200 / 72

    rendered_pages = []
    for index in range(len(document)):
        page = document[index]
        bitmap = page.render(scale=render_scale)
        rendered_pages.append(bitmap.to_pil())

    return rendered_pages
|