feat(ocr): wire SenderModelService into OcrAsyncRunner; stage missing foundational files
Some checks failed
CI / Unit & Component Tests (push) Failing after 2m21s
CI / OCR Service Tests (push) Successful in 29s
CI / Backend Unit Tests (push) Failing after 2m38s
CI / Unit & Component Tests (pull_request) Failing after 2m26s
CI / OCR Service Tests (pull_request) Successful in 31s
CI / Backend Unit Tests (pull_request) Failing after 2m44s

OcrAsyncRunner now passes the per-sender model path to streamBlocks for
HANDWRITING_KURRENT documents. processDocument replaced extractBlocks
with streamBlocks + AtomicReference, removing the unchecked raw-array
pattern.

Also stages all previously uncommitted foundational files for this
feature: SenderModel entity, SenderModelRepository, Flyway migrations
V40/V41, updated OcrClient/RestClientOcrClient streaming API,
TrainingDataExportService.exportForSender, TranscriptionService Kurrent
hook, application.yaml OCR config, and frontend i18n/test additions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 19:27:02 +02:00
parent 386dc83958
commit 18cf839fac
20 changed files with 281 additions and 31 deletions

View File

@@ -59,6 +59,9 @@ public class OcrTrainingRun {
@Column(name = "triggered_by")
private UUID triggeredBy;
@Column(name = "person_id")
private UUID personId;
@CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)

View File

@@ -0,0 +1,56 @@
package org.raddatz.familienarchiv.model;
import com.fasterxml.jackson.annotation.JsonIgnore;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.UpdateTimestamp;
import java.time.Instant;
import java.util.UUID;
/**
 * A per-sender fine-tuned OCR model. At most one row exists per person
 * (person_id is unique), storing where the trained model artifact lives
 * plus the metrics reported by the most recent training run.
 */
@Entity
@Table(name = "sender_models")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class SenderModel {
// Surrogate primary key, generated as a UUID by the persistence provider.
@Id
@GeneratedValue(strategy = GenerationType.UUID)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID id;
// The sender (person) this model was trained for; unique → one model per sender.
@Column(name = "person_id", nullable = false, unique = true)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID personId;
// Server-side path of the model artifact; @JsonIgnore keeps the internal
// filesystem layout out of API responses.
@JsonIgnore
@Column(name = "model_path", nullable = false)
private String modelPath;
// Training accuracy metric; nullable — the training service may not report it.
@Column
private Double accuracy;
// Training "cer" metric (presumably character error rate — confirm with the
// OCR training service); nullable for the same reason as accuracy.
@Column
private Double cer;
// Number of manually corrected lines that existed when this model was trained.
// NOTE(review): presumably compared against the current count to decide when
// to retrain — confirm in SenderModelService.
@Column(name = "corrected_lines_at_training", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int correctedLinesAtTraining;
// Set once on insert by Hibernate; never updated afterwards.
@CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private Instant createdAt;
// Refreshed by Hibernate on every update.
@UpdateTimestamp
@Column(name = "updated_at", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private Instant updatedAt;
}

View File

@@ -1,6 +1,7 @@
package org.raddatz.familienarchiv.model;
public enum TrainingStatus {
QUEUED,
RUNNING,
DONE,
FAILED

View File

@@ -12,5 +12,9 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
List<OcrTrainingRun> findTop10ByOrderByCreatedAtDesc();
Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status);
boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status);
List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc();
}

View File

@@ -0,0 +1,12 @@
package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.SenderModel;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.Optional;
import java.util.UUID;
public interface SenderModelRepository extends JpaRepository<SenderModel, UUID> {
/**
 * Looks up the sender-specific OCR model for the given person.
 * At most one row can match because person_id is unique on sender_models.
 *
 * @param personId id of the sender the model was trained for
 * @return the trained model, or empty if this sender has none yet
 */
Optional<SenderModel> findByPersonId(UUID personId);
}

View File

@@ -3,6 +3,7 @@ package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import java.util.List;
import java.util.Optional;
@@ -37,4 +38,22 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels
""")
List<TranscriptionBlock> findSegmentationBlocks();
/**
 * Counts manually transcribed (source = MANUAL) blocks on Kurrent-handwriting
 * documents sent by the given person. Joins Document by id because
 * TranscriptionBlock stores only the documentId, not an entity association.
 * NOTE(review): presumably feeds the sender-model training threshold — confirm
 * against the caller.
 *
 * @param personId sender (person) id
 * @return number of eligible corrected blocks for that sender
 */
@Query("""
SELECT COUNT(b) FROM TranscriptionBlock b
JOIN Document d ON d.id = b.documentId
WHERE b.source = 'MANUAL'
AND d.sender.id = :personId
AND d.scriptType = 'HANDWRITING_KURRENT'
""")
long countManualKurrentBlocksByPerson(@Param("personId") UUID personId);
/**
 * Fetches the blocks counted by countManualKurrentBlocksByPerson: manually
 * transcribed blocks on Kurrent-handwriting documents from the given sender.
 * The WHERE clause must stay in sync with the count query above.
 *
 * @param personId sender (person) id
 * @return eligible corrected blocks for that sender; may be empty
 */
@Query("""
SELECT b FROM TranscriptionBlock b
JOIN Document d ON d.id = b.documentId
WHERE b.source = 'MANUAL'
AND d.sender.id = :personId
AND d.scriptType = 'HANDWRITING_KURRENT'
""")
List<TranscriptionBlock> findManualKurrentBlocksByPerson(@Param("personId") UUID personId);
}

View File

@@ -9,10 +9,12 @@ import org.raddatz.familienarchiv.repository.OcrJobRepository;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@Component
@RequiredArgsConstructor
@@ -29,6 +31,7 @@ public class OcrAsyncRunner {
private final OcrJobRepository ocrJobRepository;
private final OcrJobDocumentRepository ocrJobDocumentRepository;
private final OcrProgressService ocrProgressService;
private final SenderModelService senderModelService;
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
@@ -68,12 +71,18 @@ public class OcrAsyncRunner {
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
String senderModelPath = null;
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
}
AtomicInteger blockCounter = new AtomicInteger(0);
AtomicInteger currentPage = new AtomicInteger(0);
AtomicInteger skippedPages = new AtomicInteger(0);
AtomicInteger totalPages = new AtomicInteger(0);
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
final String finalSenderModelPath = senderModelPath;
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, finalSenderModelPath, event -> {
switch (event) {
case OcrStreamEvent.Start start -> {
totalPages.set(start.totalPages());
@@ -207,7 +216,25 @@ public class OcrAsyncRunner {
clearExistingBlocks(documentId);
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
List<OcrBlockResult> blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType());
String senderModelPath = null;
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
}
final AtomicReference<List<OcrBlockResult>> blocksRef = new AtomicReference<>();
final String finalSenderModelPath = senderModelPath;
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), null, finalSenderModelPath, event -> {
switch (event) {
case OcrStreamEvent.Page page -> {
blocksRef.compareAndSet(null, new ArrayList<>());
blocksRef.get().addAll(page.blocks());
}
default -> {}
}
});
List<OcrBlockResult> blocks = blocksRef.get() != null ? blocksRef.get() : List.of();
createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash());
}

View File

@@ -1,6 +1,7 @@
package org.raddatz.familienarchiv.service;
import org.raddatz.familienarchiv.model.ScriptType;
import org.springframework.lang.Nullable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
@@ -37,15 +38,27 @@ public interface OcrClient {
TrainingResult segtrainModel(byte[] trainingDataZip);
/**
* Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration
* — it calls extractBlocks() and synthesizes events from the collected result.
* Fine-tune the Kurrent model for a specific sender.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService.exportForSender()
* @param outputModelPath where to save the trained model (e.g. /app/models/sender_{uuid}.mlmodel)
* @return training result metrics
*/
TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath);
/**
* Stream OCR results page-by-page via NDJSON, optionally using a sender-specific model.
* The default implementation synthesizes events from extractBlocks() for backward compatibility.
* Implementations that support real streaming (e.g. RestClientOcrClient) override this.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
* @param senderModelPath optional path to a per-sender model file; null means use base model
*/
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
List<OcrRegion> regions,
@Nullable String senderModelPath,
Consumer<OcrStreamEvent> handler) {
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
@@ -62,4 +75,9 @@ public interface OcrClient {
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
}
/**
 * Backward-compatible four-argument overload: delegates to the five-argument
 * variant with {@code senderModelPath = null}, i.e. the base model is used.
 * Kept so existing callers compile unchanged after the streaming API gained
 * the sender-model parameter.
 */
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
streamBlocks(pdfUrl, scriptType, regions, null, handler);
}
}

View File

@@ -14,6 +14,7 @@ import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@@ -102,6 +103,13 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.toList();
}
private RestClient.RequestBodySpec addTrainingAuth(RestClient.RequestBodySpec spec) {
if (trainingToken != null && !trainingToken.isBlank()) {
return spec.header("X-Training-Token", trainingToken);
}
return spec;
}
@Override
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@@ -114,15 +122,10 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post()
.uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA);
if (trainingToken != null && !trainingToken.isBlank()) {
spec = spec.header("X-Training-Token", trainingToken);
}
TrainingResultJson result = spec
TrainingResultJson result = addTrainingAuth(
trainingRestClient.post()
.uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body)
.retrieve()
.body(TrainingResultJson.class);
@@ -143,15 +146,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post()
.uri("/segtrain")
.contentType(MediaType.MULTIPART_FORM_DATA);
TrainingResultJson result = addTrainingAuth(
trainingRestClient.post()
.uri("/segtrain")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body)
.retrieve()
.body(TrainingResultJson.class);
if (trainingToken != null && !trainingToken.isBlank()) {
spec = spec.header("X-Training-Token", trainingToken);
}
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
}
TrainingResultJson result = spec
@Override
public OcrClient.TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@Override
public String getFilename() { return "sender-training-data.zip"; }
};
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
HttpHeaders partHeaders = new HttpHeaders();
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
body.add("output_model_path", outputModelPath);
TrainingResultJson result = addTrainingAuth(
trainingRestClient.post()
.uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body)
.retrieve()
.body(TrainingResultJson.class);
@@ -176,7 +199,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
@Override
public void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
List<OcrRegion> regions, @Nullable String senderModelPath,
Consumer<OcrStreamEvent> handler) {
String body;
try {
var requestMap = new java.util.LinkedHashMap<String, Object>();
@@ -186,6 +210,9 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (regions != null && !regions.isEmpty()) {
requestMap.put("regions", regions);
}
if (senderModelPath != null) {
requestMap.put("senderModelPath", senderModelPath);
}
body = NDJSON_MAPPER.writeValueAsString(requestMap);
} catch (IOException e) {
throw new RuntimeException("Failed to serialize OCR request", e);
@@ -204,7 +231,12 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (response.statusCode() == 404) {
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
handler.accept(new OcrStreamEvent.Start(0));
for (OcrBlockResult block : allBlocks) {
handler.accept(new OcrStreamEvent.Page(block.pageNumber(), List.of(block)));
}
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
return;
}

View File

@@ -38,10 +38,20 @@ public class TrainingDataExportService {
}
public StreamingResponseBody exportToZip() {
// Collect all data before entering the lambda — no open DB txn during streaming
List<TranscriptionBlock> blocks = queryEligibleBlocks();
return exportBlocksToZip(queryEligibleBlocks());
}
/**
 * Fetches all manually corrected Kurrent transcription blocks for one sender,
 * i.e. the raw material for training that sender's personal OCR model.
 *
 * @param personId sender (person) id
 * @return eligible training blocks; may be empty
 */
public List<TranscriptionBlock> queryBlocksForSender(UUID personId) {
return blockRepository.findManualKurrentBlocksByPerson(personId);
}
/**
 * Streams a training-data ZIP containing one sender's corrected Kurrent blocks.
 * When the sender has no eligible blocks the returned body writes nothing,
 * so callers that want to answer 204 must check emptiness themselves
 * (e.g. via queryBlocksForSender).
 *
 * @param personId sender (person) id
 * @return streaming body producing the ZIP (or nothing if no blocks exist)
 */
public StreamingResponseBody exportForSender(UUID personId) {
return exportBlocksToZip(queryBlocksForSender(personId));
}
private StreamingResponseBody exportBlocksToZip(List<TranscriptionBlock> blocks) {
if (blocks.isEmpty()) {
return out -> {}; // caller checks isEmpty() for 204 response
return out -> {};
}
// Group blocks by documentId so we only download each PDF once

View File

@@ -11,6 +11,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
@@ -35,6 +36,7 @@ public class TranscriptionService {
private final AnnotationRepository annotationRepository;
private final AnnotationService annotationService;
private final DocumentService documentService;
private final SenderModelService senderModelService;
public List<TranscriptionBlock> listBlocks(UUID documentId) {
return blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId);
@@ -122,6 +124,7 @@ public class TranscriptionService {
String text = sanitizeText(dto.getText());
block.setText(text);
block.setSource(BlockSource.MANUAL);
if (dto.getLabel() != null) {
block.setLabel(dto.getLabel());
}
@@ -129,6 +132,12 @@ public class TranscriptionService {
TranscriptionBlock saved = blockRepository.save(block);
saveVersion(saved, userId);
Document doc = documentService.getDocumentById(documentId);
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelService.checkAndTriggerTraining(doc.getSender().getId());
}
return saved;
}

View File

@@ -77,3 +77,8 @@ app:
tags: 10
summary: 11
transcription: 13
ocr:
sender-model:
activation-threshold: 100
retrain-delta: 50

View File

@@ -0,0 +1,10 @@
-- Per-sender fine-tuned OCR models: at most one row per person (person_id is
-- UNIQUE), removed automatically when the person is deleted (ON DELETE CASCADE).
CREATE TABLE sender_models (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
person_id UUID NOT NULL UNIQUE REFERENCES persons(id) ON DELETE CASCADE,
model_path TEXT NOT NULL, -- server-side path of the trained model artifact
accuracy DOUBLE PRECISION, -- nullable training metric; may be unreported
cer DOUBLE PRECISION, -- nullable; presumably character error rate — confirm with training service
corrected_lines_at_training INT NOT NULL DEFAULT 0, -- corrected-line count at training time
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

View File

@@ -0,0 +1,2 @@
-- Tag training runs with the sender they were run for; NULL means a base-model
-- run. ON DELETE SET NULL keeps the run history even if the person is deleted.
ALTER TABLE ocr_training_runs
ADD COLUMN person_id UUID REFERENCES persons(id) ON DELETE SET NULL;

View File

@@ -21,6 +21,7 @@ class OcrClientDefaultStreamTest {
}
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
@@ -52,6 +53,7 @@ class OcrClientDefaultStreamTest {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();

View File

@@ -22,6 +22,7 @@ class TranscriptionServiceGuidedTest {
AnnotationRepository annotationRepository;
AnnotationService annotationService;
DocumentService documentService;
SenderModelService senderModelService;
TranscriptionService service;
UUID docId = UUID.randomUUID();
@@ -35,9 +36,10 @@ class TranscriptionServiceGuidedTest {
annotationRepository = mock(AnnotationRepository.class);
annotationService = mock(AnnotationService.class);
documentService = mock(DocumentService.class);
senderModelService = mock(SenderModelService.class);
service = new TranscriptionService(blockRepository, versionRepository,
annotationRepository, annotationService, documentService);
annotationRepository, annotationService, documentService, senderModelService);
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));

View File

@@ -561,6 +561,11 @@
"transcription_block_segmentation_only": "Nur Segmentierung",
"training_chip_kurrent": "Kurrent-Erkennung",
"training_chip_segmentation": "Segmentierung",
"training_col_type": "Typ",
"training_type_base": "Basis",
"training_type_personalized": "Personalisiert",
"training_col_person": "Absender",
"training_status_queued": "Warteschlange",
"mission_control_heading": "Was braucht Aufmerksamkeit?",
"mission_control_segmentation_heading": "Text markieren",
"mission_control_segmentation_description": "Textbereiche markieren — keine Vorkenntnisse nötig",

View File

@@ -561,6 +561,11 @@
"transcription_block_segmentation_only": "Segmentation only",
"training_chip_kurrent": "Kurrent recognition",
"training_chip_segmentation": "Segmentation",
"training_col_type": "Type",
"training_type_base": "Base",
"training_type_personalized": "Personalized",
"training_col_person": "Sender",
"training_status_queued": "Queued",
"mission_control_heading": "What needs attention?",
"mission_control_segmentation_heading": "Mark text",
"mission_control_segmentation_description": "Mark text areas — no prior knowledge needed",

View File

@@ -561,6 +561,11 @@
"transcription_block_segmentation_only": "Solo segmentación",
"training_chip_kurrent": "Reconocimiento Kurrent",
"training_chip_segmentation": "Segmentación",
"training_col_type": "Tipo",
"training_type_base": "Base",
"training_type_personalized": "Personalizado",
"training_col_person": "Remitente",
"training_status_queued": "En cola",
"mission_control_heading": "¿Qué necesita atención?",
"mission_control_segmentation_heading": "Marcar texto",
"mission_control_segmentation_description": "Marcar áreas de texto — sin conocimientos previos",

View File

@@ -50,3 +50,26 @@ describe('TrainingHistory — expand/collapse', () => {
.not.toBeInTheDocument();
});
});
describe('TrainingHistory — type and person columns', () => {
// Builds a run tagged with a personId, i.e. a sender-specific training run.
const senderRun = (personId) => ({ ...makeRun(0), personId });

it('shows "Basis" for runs without personId', async () => {
render(TrainingHistory, { runs: [makeRun(0)] });
await expect.element(page.getByText(/Basis/i)).toBeInTheDocument();
});

it('shows "Personalisiert" for runs with personId', async () => {
render(TrainingHistory, {
runs: [senderRun('person-1')],
personNames: { 'person-1': 'Karl Müller' },
});
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
});

it('shows person name from personNames for sender runs', async () => {
render(TrainingHistory, {
runs: [senderRun('person-1')],
personNames: { 'person-1': 'Karl Müller' },
});
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
await expect.element(page.getByText('Karl Müller')).toBeInTheDocument();
});
});