feat(ocr): add /train endpoint to OCR service and OcrClient.trainModel()

- POST /train in ocr-service with ZIP Slip validation, TemporaryDirectory,
  ketos transfer learning, timestamped backups (keep last 3), in-process reload
- X-Training-Token auth (no-op in dev when TRAINING_TOKEN env is empty)
- trainModel() in OcrClient interface + RestClientOcrClient (10-min timeout,
  multipart upload, forwards X-Training-Token when configured)
- TRAINING_TOKEN env var wired in docker-compose; --workers 2 in Dockerfile
  so /health stays responsive during synchronous training

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 14:40:53 +02:00
parent cfa3c4df67
commit bc97a2dade
6 changed files with 188 additions and 8 deletions

View File

@@ -10,6 +10,16 @@ import java.util.function.Consumer;
public interface OcrClient { public interface OcrClient {
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType); List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
/**
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
*
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService
* @return training result metrics (loss, accuracy, epochs)
*/
TrainingResult trainModel(byte[] trainingDataZip);
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
/** /**
* Stream OCR results page-by-page via NDJSON. Implementations should override * Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration * this method. The default exists only for backward compatibility during migration

View File

@@ -9,9 +9,14 @@ import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.ScriptType;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClient; import org.springframework.web.client.RestClient;
import java.io.BufferedReader; import java.io.BufferedReader;
@@ -36,11 +41,16 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true); .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
private final RestClient restClient; private final RestClient restClient;
private final RestClient trainingRestClient;
private final HttpClient streamingHttpClient; private final HttpClient streamingHttpClient;
private final String baseUrl; private final String baseUrl;
private final String trainingToken;
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) { public RestClientOcrClient(
@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl,
@Value("${app.ocr.training-token:}") String trainingToken) {
this.baseUrl = baseUrl; this.baseUrl = baseUrl;
this.trainingToken = trainingToken;
HttpClient httpClient = HttpClient.newBuilder() HttpClient httpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1) .version(HttpClient.Version.HTTP_1_1)
@@ -54,6 +64,17 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.requestFactory(requestFactory) .requestFactory(requestFactory)
.build(); .build();
HttpClient trainingHttpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10))
.build();
JdkClientHttpRequestFactory trainingRequestFactory = new JdkClientHttpRequestFactory(trainingHttpClient);
trainingRequestFactory.setReadTimeout(Duration.ofMinutes(10));
this.trainingRestClient = RestClient.builder()
.baseUrl(baseUrl)
.requestFactory(trainingRequestFactory)
.build();
this.streamingHttpClient = HttpClient.newBuilder() this.streamingHttpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1) .version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10)) .connectTimeout(Duration.ofSeconds(10))
@@ -81,6 +102,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.toList(); .toList();
} }
@Override
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@Override
public String getFilename() { return "training-data.zip"; }
};
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
HttpHeaders partHeaders = new HttpHeaders();
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post()
.uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA);
if (trainingToken != null && !trainingToken.isBlank()) {
spec = spec.header("X-Training-Token", trainingToken);
}
TrainingResultJson result = spec
.body(body)
.retrieve()
.body(TrainingResultJson.class);
if (result == null) return new OcrClient.TrainingResult(null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
}
@Override @Override
public boolean isHealthy() { public boolean isHealthy() {
try { try {
@@ -171,6 +221,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
} }
} }
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
record OcrBlockJson( record OcrBlockJson(
@JsonProperty("pageNumber") int pageNumber, @JsonProperty("pageNumber") int pageNumber,
double x, double x,

View File

@@ -12,10 +12,15 @@ class OcrClientDefaultStreamTest {
@Test @Test
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() { void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
OcrClient client = (pdfUrl, scriptType) -> List.of( OcrClient client = new OcrClient() {
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"), @Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"), return List.of(
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3")); new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
}
@Override public TrainingResult trainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>(); List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add); client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
@@ -42,7 +47,10 @@ class OcrClientDefaultStreamTest {
@Test @Test
void defaultStreamBlocksHandlesEmptyResults() { void defaultStreamBlocksHandlesEmptyResults() {
OcrClient client = (pdfUrl, scriptType) -> List.of(); OcrClient client = new OcrClient() {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
@Override public TrainingResult trainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>(); List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add); client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);

View File

@@ -87,6 +87,7 @@ services:
- ocr_cache:/root/.cache - ocr_cache:/root/.cache
environment: environment:
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
OCR_CONFIDENCE_THRESHOLD: "0.3" OCR_CONFIDENCE_THRESHOLD: "0.3"
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5" OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
RECOGNITION_BATCH_SIZE: "16" RECOGNITION_BATCH_SIZE: "16"

View File

@@ -23,4 +23,4 @@ COPY . .
EXPOSE 8000 EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]

View File

@@ -1,16 +1,21 @@
"""OCR microservice — FastAPI app with Surya and Kraken engine support.""" """OCR microservice — FastAPI app with Surya and Kraken engine support."""
import asyncio import asyncio
import glob
import io import io
import json import json
import logging import logging
import os import os
import shutil
import tempfile
import zipfile
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timezone
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
import pypdfium2 as pdfium import pypdfium2 as pdfium
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from PIL import Image from PIL import Image
@@ -19,6 +24,9 @@ from engines import kraken as kraken_engine
from engines import surya as surya_engine from engines import surya as surya_engine
from models import OcrBlock, OcrRequest from models import OcrBlock, OcrRequest
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -169,6 +177,107 @@ async def run_ocr_stream(request: OcrRequest):
) )
def _check_training_token(x_training_token: str | None) -> None:
"""Validate training token if TRAINING_TOKEN env var is set."""
if TRAINING_TOKEN and x_training_token != TRAINING_TOKEN:
raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
def _validate_zip_entry(name: str, extract_dir: str) -> None:
"""Reject ZIP Slip attacks: path traversal and absolute paths."""
if os.path.isabs(name) or name.startswith(".."):
raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")
resolved = os.path.realpath(os.path.join(extract_dir, name))
if not resolved.startswith(os.path.realpath(extract_dir)):
raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
def _rotate_backups(model_path: str, keep: int = 3) -> None:
"""Keep only the last `keep` timestamped backups of the model."""
pattern = model_path + ".*.bak"
backups = sorted(glob.glob(pattern))
for old in backups[:-keep]:
try:
os.remove(old)
except OSError:
logger.warning("Could not remove old backup: %s", old)
@app.post("/train")
async def train_model(
file: UploadFile,
x_training_token: str | None = Header(default=None),
):
"""Fine-tune the Kurrent recognition model with uploaded training data.
Accepts a ZIP archive containing .png/.gt.txt training pairs exported
by the Java backend. Training mutates in-process model state — not safe
if the service is replicated.
"""
_check_training_token(x_training_token)
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
zip_bytes = await file.read()
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
log.info("Starting training run %s", training_run_id)
def _run_training() -> dict:
with tempfile.TemporaryDirectory() as tmp_dir:
# Extract ZIP with safety checks
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
for entry in zf.namelist():
_validate_zip_entry(entry, tmp_dir)
zf.extractall(tmp_dir)
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
# Run ketos train (transfer learning from existing model)
from kraken import ketos
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
if not ground_truth:
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
log.info("Training on %d ground-truth pairs", len(ground_truth))
output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")
result = ketos.train(
ground_truth=ground_truth,
load=KRAKEN_MODEL_PATH,
output=output_model_path,
format_type="path",
)
epochs = getattr(result, "epochs", None) or 0
loss = getattr(result, "best_loss", None)
accuracy = getattr(result, "best_accuracy", None)
log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
# Backup existing model and replace
if os.path.exists(KRAKEN_MODEL_PATH):
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
log.info("Backed up model to %s", backup_path)
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
shutil.move(output_model_path, KRAKEN_MODEL_PATH)
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
# Reload model in-process
kraken_engine.load_models()
log.info("Reloaded Kraken model in-process")
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
result = await asyncio.to_thread(_run_training)
return result
async def _download_and_convert_pdf(url: str) -> list[Image.Image]: async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
"""Download a PDF from a presigned URL and convert each page to a PIL Image.""" """Download a PDF from a presigned URL and convert each page to a PIL Image."""
_validate_url(url) _validate_url(url)