feat(ocr): add /train endpoint to OCR service and OcrClient.trainModel()
- POST /train in ocr-service with ZIP Slip validation, TemporaryDirectory, ketos transfer learning, timestamped backups (keep last 3), in-process reload - X-Training-Token auth (no-op in dev when TRAINING_TOKEN env is empty) - trainModel() in OcrClient interface + RestClientOcrClient (10-min timeout, multipart upload, forwards X-Training-Token when configured) - TRAINING_TOKEN env var wired in docker-compose; --workers 2 in Dockerfile so /health stays responsive during synchronous training Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,16 @@ import java.util.function.Consumer;
|
||||
public interface OcrClient {
|
||||
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
|
||||
|
||||
/**
|
||||
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
|
||||
*
|
||||
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService
|
||||
* @return training result metrics (loss, accuracy, epochs)
|
||||
*/
|
||||
TrainingResult trainModel(byte[] trainingDataZip);
|
||||
|
||||
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
|
||||
|
||||
/**
|
||||
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
||||
* this method. The default exists only for backward compatibility during migration
|
||||
|
||||
@@ -9,9 +9,14 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.io.ByteArrayResource;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.client.JdkClientHttpRequestFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
@@ -36,11 +41,16 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
|
||||
|
||||
private final RestClient restClient;
|
||||
private final RestClient trainingRestClient;
|
||||
private final HttpClient streamingHttpClient;
|
||||
private final String baseUrl;
|
||||
private final String trainingToken;
|
||||
|
||||
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) {
|
||||
public RestClientOcrClient(
|
||||
@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl,
|
||||
@Value("${app.ocr.training-token:}") String trainingToken) {
|
||||
this.baseUrl = baseUrl;
|
||||
this.trainingToken = trainingToken;
|
||||
|
||||
HttpClient httpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
@@ -54,6 +64,17 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.requestFactory(requestFactory)
|
||||
.build();
|
||||
|
||||
HttpClient trainingHttpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.connectTimeout(Duration.ofSeconds(10))
|
||||
.build();
|
||||
JdkClientHttpRequestFactory trainingRequestFactory = new JdkClientHttpRequestFactory(trainingHttpClient);
|
||||
trainingRequestFactory.setReadTimeout(Duration.ofMinutes(10));
|
||||
this.trainingRestClient = RestClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.requestFactory(trainingRequestFactory)
|
||||
.build();
|
||||
|
||||
this.streamingHttpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.connectTimeout(Duration.ofSeconds(10))
|
||||
@@ -81,6 +102,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
|
||||
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
|
||||
@Override
|
||||
public String getFilename() { return "training-data.zip"; }
|
||||
};
|
||||
|
||||
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||
HttpHeaders partHeaders = new HttpHeaders();
|
||||
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||
|
||||
var spec = trainingRestClient.post()
|
||||
.uri("/train")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||
|
||||
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||
spec = spec.header("X-Training-Token", trainingToken);
|
||||
}
|
||||
|
||||
TrainingResultJson result = spec
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isHealthy() {
|
||||
try {
|
||||
@@ -171,6 +221,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
}
|
||||
}
|
||||
|
||||
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
|
||||
|
||||
record OcrBlockJson(
|
||||
@JsonProperty("pageNumber") int pageNumber,
|
||||
double x,
|
||||
|
||||
@@ -12,10 +12,15 @@ class OcrClientDefaultStreamTest {
|
||||
|
||||
@Test
|
||||
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
|
||||
OcrClient client = (pdfUrl, scriptType) -> List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||
OcrClient client = new OcrClient() {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
||||
return List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||
}
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
@@ -42,7 +47,10 @@ class OcrClientDefaultStreamTest {
|
||||
|
||||
@Test
|
||||
void defaultStreamBlocksHandlesEmptyResults() {
|
||||
OcrClient client = (pdfUrl, scriptType) -> List.of();
|
||||
OcrClient client = new OcrClient() {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
|
||||
Reference in New Issue
Block a user