feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232

Merged
marcel merged 40 commits from feat/issue-226-227-ocr-pipeline-polygon into main 2026-04-14 10:31:35 +02:00
6 changed files with 188 additions and 8 deletions
Showing only changes of commit bc97a2dade - Show all commits

View File

@@ -10,6 +10,16 @@ import java.util.function.Consumer;
public interface OcrClient {
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
/**
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
*
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService
* @return training result metrics (loss, accuracy, epochs)
*/
TrainingResult trainModel(byte[] trainingDataZip);
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
/**
* Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration

View File

@@ -9,9 +9,14 @@ import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.ScriptType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClient;
import java.io.BufferedReader;
@@ -36,11 +41,16 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
private final RestClient restClient;
private final RestClient trainingRestClient;
private final HttpClient streamingHttpClient;
private final String baseUrl;
private final String trainingToken;
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) {
public RestClientOcrClient(
@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl,
@Value("${app.ocr.training-token:}") String trainingToken) {
this.baseUrl = baseUrl;
this.trainingToken = trainingToken;
HttpClient httpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
@@ -54,6 +64,17 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.requestFactory(requestFactory)
.build();
HttpClient trainingHttpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10))
.build();
JdkClientHttpRequestFactory trainingRequestFactory = new JdkClientHttpRequestFactory(trainingHttpClient);
trainingRequestFactory.setReadTimeout(Duration.ofMinutes(10));
this.trainingRestClient = RestClient.builder()
.baseUrl(baseUrl)
.requestFactory(trainingRequestFactory)
.build();
this.streamingHttpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10))
@@ -81,6 +102,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.toList();
}
@Override
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@Override
public String getFilename() { return "training-data.zip"; }
};
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
HttpHeaders partHeaders = new HttpHeaders();
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post()
.uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA);
if (trainingToken != null && !trainingToken.isBlank()) {
spec = spec.header("X-Training-Token", trainingToken);
}
TrainingResultJson result = spec
.body(body)
.retrieve()
.body(TrainingResultJson.class);
if (result == null) return new OcrClient.TrainingResult(null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
}
@Override
public boolean isHealthy() {
try {
@@ -171,6 +221,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
}
}
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
record OcrBlockJson(
@JsonProperty("pageNumber") int pageNumber,
double x,

View File

@@ -12,10 +12,15 @@ class OcrClientDefaultStreamTest {
@Test
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
OcrClient client = (pdfUrl, scriptType) -> List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
OcrClient client = new OcrClient() {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
return List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
}
@Override public TrainingResult trainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
@@ -42,7 +47,10 @@ class OcrClientDefaultStreamTest {
@Test
void defaultStreamBlocksHandlesEmptyResults() {
OcrClient client = (pdfUrl, scriptType) -> List.of();
OcrClient client = new OcrClient() {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
@Override public TrainingResult trainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);

View File

@@ -87,6 +87,7 @@ services:
- ocr_cache:/root/.cache
environment:
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
OCR_CONFIDENCE_THRESHOLD: "0.3"
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
RECOGNITION_BATCH_SIZE: "16"

View File

@@ -23,4 +23,4 @@ COPY . .
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]

View File

@@ -1,16 +1,21 @@
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
import asyncio
import glob
import io
import json
import logging
import os
import shutil
import tempfile
import zipfile
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from urllib.parse import urlparse
import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
@@ -19,6 +24,9 @@ from engines import kraken as kraken_engine
from engines import surya as surya_engine
from models import OcrBlock, OcrRequest
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -169,6 +177,107 @@ async def run_ocr_stream(request: OcrRequest):
)
def _check_training_token(x_training_token: str | None) -> None:
"""Validate training token if TRAINING_TOKEN env var is set."""
if TRAINING_TOKEN and x_training_token != TRAINING_TOKEN:
raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
def _validate_zip_entry(name: str, extract_dir: str) -> None:
"""Reject ZIP Slip attacks: path traversal and absolute paths."""
if os.path.isabs(name) or name.startswith(".."):
raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")
resolved = os.path.realpath(os.path.join(extract_dir, name))
if not resolved.startswith(os.path.realpath(extract_dir)):
raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
def _rotate_backups(model_path: str, keep: int = 3) -> None:
"""Keep only the last `keep` timestamped backups of the model."""
pattern = model_path + ".*.bak"
backups = sorted(glob.glob(pattern))
for old in backups[:-keep]:
try:
os.remove(old)
except OSError:
logger.warning("Could not remove old backup: %s", old)
@app.post("/train")
async def train_model(
file: UploadFile,
x_training_token: str | None = Header(default=None),
):
"""Fine-tune the Kurrent recognition model with uploaded training data.
Accepts a ZIP archive containing .png/.gt.txt training pairs exported
by the Java backend. Training mutates in-process model state — not safe
if the service is replicated.
"""
_check_training_token(x_training_token)
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
zip_bytes = await file.read()
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
log.info("Starting training run %s", training_run_id)
def _run_training() -> dict:
with tempfile.TemporaryDirectory() as tmp_dir:
# Extract ZIP with safety checks
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
for entry in zf.namelist():
_validate_zip_entry(entry, tmp_dir)
zf.extractall(tmp_dir)
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
# Run ketos train (transfer learning from existing model)
from kraken import ketos
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
if not ground_truth:
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
log.info("Training on %d ground-truth pairs", len(ground_truth))
output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")
result = ketos.train(
ground_truth=ground_truth,
load=KRAKEN_MODEL_PATH,
output=output_model_path,
format_type="path",
)
epochs = getattr(result, "epochs", None) or 0
loss = getattr(result, "best_loss", None)
accuracy = getattr(result, "best_accuracy", None)
log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
# Backup existing model and replace
if os.path.exists(KRAKEN_MODEL_PATH):
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
log.info("Backed up model to %s", backup_path)
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
shutil.move(output_model_path, KRAKEN_MODEL_PATH)
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
# Reload model in-process
kraken_engine.load_models()
log.info("Reloaded Kraken model in-process")
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
result = await asyncio.to_thread(_run_training)
return result
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
_validate_url(url)