From bc97a2dade088832a1e92835e945ce794a92fb57 Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 13 Apr 2026 14:40:53 +0200 Subject: [PATCH] feat(ocr): add /train endpoint to OCR service and OcrClient.trainModel() - POST /train in ocr-service with ZIP Slip validation, TemporaryDirectory, ketos transfer learning, timestamped backups (keep last 3), in-process reload - X-Training-Token auth (no-op in dev when TRAINING_TOKEN env is empty) - trainModel() in OcrClient interface + RestClientOcrClient (10-min timeout, multipart upload, forwards X-Training-Token when configured) - TRAINING_TOKEN env var wired in docker-compose; --workers 2 in Dockerfile so /health stays responsive during synchronous training Co-Authored-By: Claude Sonnet 4.6 --- .../familienarchiv/service/OcrClient.java | 10 ++ .../service/RestClientOcrClient.java | 54 ++++++++- .../service/OcrClientDefaultStreamTest.java | 18 ++- docker-compose.yml | 1 + ocr-service/Dockerfile | 2 +- ocr-service/main.py | 111 +++++++++++++++++- 6 files changed, 188 insertions(+), 8 deletions(-) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java index 9cf7c886..92330947 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java @@ -10,6 +10,16 @@ import java.util.function.Consumer; public interface OcrClient { List extractBlocks(String pdfUrl, ScriptType scriptType); + /** + * Send a training ZIP to the OCR service for fine-tuning the Kurrent model. + * + * @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService + * @return training result metrics (loss, accuracy, epochs) + */ + TrainingResult trainModel(byte[] trainingDataZip); + + record TrainingResult(Double loss, Double accuracy, Integer epochs) {} + /** * Stream OCR results page-by-page via NDJSON. Implementations should override * this method. The default exists only for backward compatibility during migration diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java index a0f7ccf3..f2c68187 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java @@ -9,9 +9,14 @@ import lombok.extern.slf4j.Slf4j; import org.raddatz.familienarchiv.model.ScriptType; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.stereotype.Component; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.client.RestClient; import java.io.BufferedReader; @@ -36,11 +41,16 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true); private final RestClient restClient; + private final RestClient trainingRestClient; private final HttpClient streamingHttpClient; private final String baseUrl; + private final String trainingToken; - public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) { + public RestClientOcrClient( + @Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl, + @Value("${app.ocr.training-token:}") String trainingToken) { this.baseUrl = baseUrl; + this.trainingToken = trainingToken; HttpClient httpClient = HttpClient.newBuilder() .version(HttpClient.Version.HTTP_1_1) @@ -54,6 +64,17 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { .requestFactory(requestFactory) .build(); + HttpClient trainingHttpClient = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)) + .build(); + JdkClientHttpRequestFactory trainingRequestFactory = new JdkClientHttpRequestFactory(trainingHttpClient); + trainingRequestFactory.setReadTimeout(Duration.ofMinutes(10)); + this.trainingRestClient = RestClient.builder() + .baseUrl(baseUrl) + .requestFactory(trainingRequestFactory) + .build(); + this.streamingHttpClient = HttpClient.newBuilder() .version(HttpClient.Version.HTTP_1_1) .connectTimeout(Duration.ofSeconds(10)) @@ -81,6 +102,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { .toList(); } + @Override + public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) { + ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) { + @Override + public String getFilename() { return "training-data.zip"; } + }; + + MultiValueMap body = new LinkedMultiValueMap<>(); + HttpHeaders partHeaders = new HttpHeaders(); + partHeaders.setContentType(MediaType.parseMediaType("application/zip")); + body.add("file", new HttpEntity<>(zipResource, partHeaders)); + + var spec = trainingRestClient.post() + .uri("/train") + .contentType(MediaType.MULTIPART_FORM_DATA); + + if (trainingToken != null && !trainingToken.isBlank()) { + spec = spec.header("X-Training-Token", trainingToken); + } + + TrainingResultJson result = spec + .body(body) + .retrieve() + .body(TrainingResultJson.class); + + if (result == null) return new OcrClient.TrainingResult(null, null, null); + return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs()); + } + @Override public boolean isHealthy() { try { @@ -171,6 +221,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { } } + record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {} + record OcrBlockJson( @JsonProperty("pageNumber") int pageNumber, double x, diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java index 42219299..25d129b3 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java @@ -12,10 +12,15 @@ class OcrClientDefaultStreamTest { @Test void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() { - OcrClient client = (pdfUrl, scriptType) -> List.of( - new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"), - new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"), - new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3")); + OcrClient client = new OcrClient() { + @Override public List extractBlocks(String pdfUrl, ScriptType scriptType) { + return List.of( + new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"), + new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"), + new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3")); + } + @Override public TrainingResult trainModel(byte[] zip) { return null; } + }; List events = new ArrayList<>(); client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add); @@ -42,7 +47,10 @@ class OcrClientDefaultStreamTest { @Test void defaultStreamBlocksHandlesEmptyResults() { - OcrClient client = (pdfUrl, scriptType) -> List.of(); + OcrClient client = new OcrClient() { + @Override public List extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); } + @Override public TrainingResult trainModel(byte[] zip) { return null; } + }; List events = new ArrayList<>(); client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add); diff --git a/docker-compose.yml b/docker-compose.yml index 46ed94b2..5e06094b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -87,6 +87,7 @@ services: - ocr_cache:/root/.cache environment: KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel + TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}" OCR_CONFIDENCE_THRESHOLD: "0.3" OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5" RECOGNITION_BATCH_SIZE: "16" diff --git a/ocr-service/Dockerfile b/ocr-service/Dockerfile index 0c90503c..7e5cc939 100644 --- a/ocr-service/Dockerfile +++ b/ocr-service/Dockerfile @@ -23,4 +23,4 @@ COPY . . EXPOSE 8000 -CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"] diff --git a/ocr-service/main.py b/ocr-service/main.py index 11f8b520..9f92ca97 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -1,16 +1,21 @@ """OCR microservice — FastAPI app with Surya and Kraken engine support.""" import asyncio +import glob import io import json import logging import os +import shutil +import tempfile +import zipfile from contextlib import asynccontextmanager +from datetime import datetime, timezone from urllib.parse import urlparse import httpx import pypdfium2 as pdfium -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, Header, HTTPException, UploadFile from fastapi.responses import StreamingResponse from PIL import Image @@ -19,6 +24,9 @@ from engines import kraken as kraken_engine from engines import surya as surya_engine from models import OcrBlock, OcrRequest +TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "") +KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel") + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -169,6 +177,107 @@ async def run_ocr_stream(request: OcrRequest): ) +def _check_training_token(x_training_token: str | None) -> None: + """Validate training token if TRAINING_TOKEN env var is set.""" + if TRAINING_TOKEN and x_training_token != TRAINING_TOKEN: + raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token") + + +def _validate_zip_entry(name: str, extract_dir: str) -> None: + """Reject ZIP Slip attacks: path traversal and absolute paths.""" + if os.path.isabs(name) or name.startswith(".."): + raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}") + resolved = os.path.realpath(os.path.join(extract_dir, name)) + if not resolved.startswith(os.path.realpath(extract_dir)): + raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}") + + +def _rotate_backups(model_path: str, keep: int = 3) -> None: + """Keep only the last `keep` timestamped backups of the model.""" + pattern = model_path + ".*.bak" + backups = sorted(glob.glob(pattern)) + for old in backups[:-keep]: + try: + os.remove(old) + except OSError: + logger.warning("Could not remove old backup: %s", old) + + +@app.post("/train") +async def train_model( + file: UploadFile, + x_training_token: str | None = Header(default=None), +): + """Fine-tune the Kurrent recognition model with uploaded training data. + + Accepts a ZIP archive containing .png/.gt.txt training pairs exported + by the Java backend. Training mutates in-process model state — not safe + if the service is replicated. + """ + _check_training_token(x_training_token) + + if not _models_ready: + raise HTTPException(status_code=503, detail="Models not loaded yet") + + zip_bytes = await file.read() + + training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id}) + log.info("Starting training run %s", training_run_id) + + def _run_training() -> dict: + with tempfile.TemporaryDirectory() as tmp_dir: + # Extract ZIP with safety checks + with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + for entry in zf.namelist(): + _validate_zip_entry(entry, tmp_dir) + zf.extractall(tmp_dir) + + log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir) + + # Run ketos train (transfer learning from existing model) + from kraken import ketos + ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt")) + if not ground_truth: + raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP") + + log.info("Training on %d ground-truth pairs", len(ground_truth)) + output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel") + + result = ketos.train( + ground_truth=ground_truth, + load=KRAKEN_MODEL_PATH, + output=output_model_path, + format_type="path", + ) + + epochs = getattr(result, "epochs", None) or 0 + loss = getattr(result, "best_loss", None) + accuracy = getattr(result, "best_accuracy", None) + + log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy) + + # Backup existing model and replace + if os.path.exists(KRAKEN_MODEL_PATH): + timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak" + shutil.copy2(KRAKEN_MODEL_PATH, backup_path) + log.info("Backed up model to %s", backup_path) + _rotate_backups(KRAKEN_MODEL_PATH, keep=3) + + shutil.move(output_model_path, KRAKEN_MODEL_PATH) + log.info("Replaced model at %s", KRAKEN_MODEL_PATH) + + # Reload model in-process + kraken_engine.load_models() + log.info("Reloaded Kraken model in-process") + + return {"loss": loss, "accuracy": accuracy, "epochs": epochs} + + result = await asyncio.to_thread(_run_training) + return result + + async def _download_and_convert_pdf(url: str) -> list[Image.Image]: """Download a PDF from a presigned URL and convert each page to a PIL Image.""" _validate_url(url)