feat(ocr): add /train endpoint to OCR service and OcrClient.trainModel()
- POST /train in ocr-service with ZIP Slip validation, TemporaryDirectory, ketos transfer learning, timestamped backups (keep last 3), in-process reload - X-Training-Token auth (no-op in dev when TRAINING_TOKEN env is empty) - trainModel() in OcrClient interface + RestClientOcrClient (10-min timeout, multipart upload, forwards X-Training-Token when configured) - TRAINING_TOKEN env var wired in docker-compose; --workers 2 in Dockerfile so /health stays responsive during synchronous training Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,16 @@ import java.util.function.Consumer;
|
|||||||
public interface OcrClient {
|
public interface OcrClient {
|
||||||
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
|
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
|
||||||
|
*
|
||||||
|
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService
|
||||||
|
* @return training result metrics (loss, accuracy, epochs)
|
||||||
|
*/
|
||||||
|
TrainingResult trainModel(byte[] trainingDataZip);
|
||||||
|
|
||||||
|
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
||||||
* this method. The default exists only for backward compatibility during migration
|
* this method. The default exists only for backward compatibility during migration
|
||||||
|
|||||||
@@ -9,9 +9,14 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.raddatz.familienarchiv.model.ScriptType;
|
import org.raddatz.familienarchiv.model.ScriptType;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.core.ParameterizedTypeReference;
|
import org.springframework.core.ParameterizedTypeReference;
|
||||||
|
import org.springframework.core.io.ByteArrayResource;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.http.client.JdkClientHttpRequestFactory;
|
import org.springframework.http.client.JdkClientHttpRequestFactory;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.util.LinkedMultiValueMap;
|
||||||
|
import org.springframework.util.MultiValueMap;
|
||||||
import org.springframework.web.client.RestClient;
|
import org.springframework.web.client.RestClient;
|
||||||
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
@@ -36,11 +41,16 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
|
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
|
||||||
|
|
||||||
private final RestClient restClient;
|
private final RestClient restClient;
|
||||||
|
private final RestClient trainingRestClient;
|
||||||
private final HttpClient streamingHttpClient;
|
private final HttpClient streamingHttpClient;
|
||||||
private final String baseUrl;
|
private final String baseUrl;
|
||||||
|
private final String trainingToken;
|
||||||
|
|
||||||
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) {
|
public RestClientOcrClient(
|
||||||
|
@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl,
|
||||||
|
@Value("${app.ocr.training-token:}") String trainingToken) {
|
||||||
this.baseUrl = baseUrl;
|
this.baseUrl = baseUrl;
|
||||||
|
this.trainingToken = trainingToken;
|
||||||
|
|
||||||
HttpClient httpClient = HttpClient.newBuilder()
|
HttpClient httpClient = HttpClient.newBuilder()
|
||||||
.version(HttpClient.Version.HTTP_1_1)
|
.version(HttpClient.Version.HTTP_1_1)
|
||||||
@@ -54,6 +64,17 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
.requestFactory(requestFactory)
|
.requestFactory(requestFactory)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
HttpClient trainingHttpClient = HttpClient.newBuilder()
|
||||||
|
.version(HttpClient.Version.HTTP_1_1)
|
||||||
|
.connectTimeout(Duration.ofSeconds(10))
|
||||||
|
.build();
|
||||||
|
JdkClientHttpRequestFactory trainingRequestFactory = new JdkClientHttpRequestFactory(trainingHttpClient);
|
||||||
|
trainingRequestFactory.setReadTimeout(Duration.ofMinutes(10));
|
||||||
|
this.trainingRestClient = RestClient.builder()
|
||||||
|
.baseUrl(baseUrl)
|
||||||
|
.requestFactory(trainingRequestFactory)
|
||||||
|
.build();
|
||||||
|
|
||||||
this.streamingHttpClient = HttpClient.newBuilder()
|
this.streamingHttpClient = HttpClient.newBuilder()
|
||||||
.version(HttpClient.Version.HTTP_1_1)
|
.version(HttpClient.Version.HTTP_1_1)
|
||||||
.connectTimeout(Duration.ofSeconds(10))
|
.connectTimeout(Duration.ofSeconds(10))
|
||||||
@@ -81,6 +102,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
.toList();
|
.toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
|
||||||
|
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
|
||||||
|
@Override
|
||||||
|
public String getFilename() { return "training-data.zip"; }
|
||||||
|
};
|
||||||
|
|
||||||
|
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||||
|
HttpHeaders partHeaders = new HttpHeaders();
|
||||||
|
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||||
|
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||||
|
|
||||||
|
var spec = trainingRestClient.post()
|
||||||
|
.uri("/train")
|
||||||
|
.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||||
|
|
||||||
|
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||||
|
spec = spec.header("X-Training-Token", trainingToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
TrainingResultJson result = spec
|
||||||
|
.body(body)
|
||||||
|
.retrieve()
|
||||||
|
.body(TrainingResultJson.class);
|
||||||
|
|
||||||
|
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
||||||
|
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isHealthy() {
|
public boolean isHealthy() {
|
||||||
try {
|
try {
|
||||||
@@ -171,6 +221,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
|
||||||
|
|
||||||
record OcrBlockJson(
|
record OcrBlockJson(
|
||||||
@JsonProperty("pageNumber") int pageNumber,
|
@JsonProperty("pageNumber") int pageNumber,
|
||||||
double x,
|
double x,
|
||||||
|
|||||||
@@ -12,10 +12,15 @@ class OcrClientDefaultStreamTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
|
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
|
||||||
OcrClient client = (pdfUrl, scriptType) -> List.of(
|
OcrClient client = new OcrClient() {
|
||||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
||||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
return List.of(
|
||||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||||
|
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
||||||
|
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||||
|
}
|
||||||
|
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||||
|
};
|
||||||
|
|
||||||
List<OcrStreamEvent> events = new ArrayList<>();
|
List<OcrStreamEvent> events = new ArrayList<>();
|
||||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||||
@@ -42,7 +47,10 @@ class OcrClientDefaultStreamTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void defaultStreamBlocksHandlesEmptyResults() {
|
void defaultStreamBlocksHandlesEmptyResults() {
|
||||||
OcrClient client = (pdfUrl, scriptType) -> List.of();
|
OcrClient client = new OcrClient() {
|
||||||
|
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
|
||||||
|
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||||
|
};
|
||||||
|
|
||||||
List<OcrStreamEvent> events = new ArrayList<>();
|
List<OcrStreamEvent> events = new ArrayList<>();
|
||||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ services:
|
|||||||
- ocr_cache:/root/.cache
|
- ocr_cache:/root/.cache
|
||||||
environment:
|
environment:
|
||||||
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
|
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
|
||||||
|
TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
|
||||||
OCR_CONFIDENCE_THRESHOLD: "0.3"
|
OCR_CONFIDENCE_THRESHOLD: "0.3"
|
||||||
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
|
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
|
||||||
RECOGNITION_BATCH_SIZE: "16"
|
RECOGNITION_BATCH_SIZE: "16"
|
||||||
|
|||||||
@@ -23,4 +23,4 @@ COPY . .
|
|||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
|
||||||
|
|||||||
@@ -1,16 +1,21 @@
|
|||||||
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
|
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import glob
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime, timezone
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pypdfium2 as pdfium
|
import pypdfium2 as pdfium
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, Header, HTTPException, UploadFile
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -19,6 +24,9 @@ from engines import kraken as kraken_engine
|
|||||||
from engines import surya as surya_engine
|
from engines import surya as surya_engine
|
||||||
from models import OcrBlock, OcrRequest
|
from models import OcrBlock, OcrRequest
|
||||||
|
|
||||||
|
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
|
||||||
|
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -169,6 +177,107 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_training_token(x_training_token: str | None) -> None:
|
||||||
|
"""Validate training token if TRAINING_TOKEN env var is set."""
|
||||||
|
if TRAINING_TOKEN and x_training_token != TRAINING_TOKEN:
|
||||||
|
raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_zip_entry(name: str, extract_dir: str) -> None:
|
||||||
|
"""Reject ZIP Slip attacks: path traversal and absolute paths."""
|
||||||
|
if os.path.isabs(name) or name.startswith(".."):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")
|
||||||
|
resolved = os.path.realpath(os.path.join(extract_dir, name))
|
||||||
|
if not resolved.startswith(os.path.realpath(extract_dir)):
|
||||||
|
raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _rotate_backups(model_path: str, keep: int = 3) -> None:
|
||||||
|
"""Keep only the last `keep` timestamped backups of the model."""
|
||||||
|
pattern = model_path + ".*.bak"
|
||||||
|
backups = sorted(glob.glob(pattern))
|
||||||
|
for old in backups[:-keep]:
|
||||||
|
try:
|
||||||
|
os.remove(old)
|
||||||
|
except OSError:
|
||||||
|
logger.warning("Could not remove old backup: %s", old)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/train")
|
||||||
|
async def train_model(
|
||||||
|
file: UploadFile,
|
||||||
|
x_training_token: str | None = Header(default=None),
|
||||||
|
):
|
||||||
|
"""Fine-tune the Kurrent recognition model with uploaded training data.
|
||||||
|
|
||||||
|
Accepts a ZIP archive containing .png/.gt.txt training pairs exported
|
||||||
|
by the Java backend. Training mutates in-process model state — not safe
|
||||||
|
if the service is replicated.
|
||||||
|
"""
|
||||||
|
_check_training_token(x_training_token)
|
||||||
|
|
||||||
|
if not _models_ready:
|
||||||
|
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||||
|
|
||||||
|
zip_bytes = await file.read()
|
||||||
|
|
||||||
|
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||||
|
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
|
||||||
|
log.info("Starting training run %s", training_run_id)
|
||||||
|
|
||||||
|
def _run_training() -> dict:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Extract ZIP with safety checks
|
||||||
|
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||||
|
for entry in zf.namelist():
|
||||||
|
_validate_zip_entry(entry, tmp_dir)
|
||||||
|
zf.extractall(tmp_dir)
|
||||||
|
|
||||||
|
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
|
||||||
|
|
||||||
|
# Run ketos train (transfer learning from existing model)
|
||||||
|
from kraken import ketos
|
||||||
|
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
|
||||||
|
if not ground_truth:
|
||||||
|
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
|
||||||
|
|
||||||
|
log.info("Training on %d ground-truth pairs", len(ground_truth))
|
||||||
|
output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")
|
||||||
|
|
||||||
|
result = ketos.train(
|
||||||
|
ground_truth=ground_truth,
|
||||||
|
load=KRAKEN_MODEL_PATH,
|
||||||
|
output=output_model_path,
|
||||||
|
format_type="path",
|
||||||
|
)
|
||||||
|
|
||||||
|
epochs = getattr(result, "epochs", None) or 0
|
||||||
|
loss = getattr(result, "best_loss", None)
|
||||||
|
accuracy = getattr(result, "best_accuracy", None)
|
||||||
|
|
||||||
|
log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
|
||||||
|
|
||||||
|
# Backup existing model and replace
|
||||||
|
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||||
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||||
|
backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
|
||||||
|
shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
|
||||||
|
log.info("Backed up model to %s", backup_path)
|
||||||
|
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
|
||||||
|
|
||||||
|
shutil.move(output_model_path, KRAKEN_MODEL_PATH)
|
||||||
|
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
|
||||||
|
|
||||||
|
# Reload model in-process
|
||||||
|
kraken_engine.load_models()
|
||||||
|
log.info("Reloaded Kraken model in-process")
|
||||||
|
|
||||||
|
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
|
||||||
|
|
||||||
|
result = await asyncio.to_thread(_run_training)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
||||||
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
||||||
_validate_url(url)
|
_validate_url(url)
|
||||||
|
|||||||
Reference in New Issue
Block a user