diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 3d75ea9a..389ab11a 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -48,7 +48,7 @@ jobs: path: frontend/test-results/screenshots/ # ─── OCR Service Unit Tests ─────────────────────────────────────────────────── - # Only spell_check.py and test_confidence.py — no ML stack required. + # Only spell_check.py, test_confidence.py, test_sender_registry.py — no ML stack required. ocr-tests: name: OCR Service Tests runs-on: ubuntu-latest @@ -60,11 +60,11 @@ jobs: python-version: '3.11' - name: Install test dependencies - run: pip install "pyspellchecker==0.9.0" pytest + run: pip install "pyspellchecker==0.9.0" pytest pytest-asyncio working-directory: ocr-service - name: Run OCR unit tests (no ML stack required) - run: python -m pytest test_spell_check.py test_confidence.py -v + run: python -m pytest test_spell_check.py test_confidence.py test_sender_registry.py -v working-directory: ocr-service # ─── Backend Unit & Slice Tests ─────────────────────────────────────────────── diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java index e5db5231..c22081aa 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java @@ -59,6 +59,9 @@ public class OcrTrainingRun { @Column(name = "triggered_by") private UUID triggeredBy; + @Column(name = "person_id") + private UUID personId; + @CreationTimestamp @Column(name = "created_at", nullable = false, updatable = false) @Schema(requiredMode = Schema.RequiredMode.REQUIRED) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/SenderModel.java b/backend/src/main/java/org/raddatz/familienarchiv/model/SenderModel.java new file mode 100644 index 00000000..4b8586aa --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/model/SenderModel.java @@ -0,0 +1,56 @@ +package org.raddatz.familienarchiv.model; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.persistence.*; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.hibernate.annotations.CreationTimestamp; +import org.hibernate.annotations.UpdateTimestamp; + +import java.time.Instant; +import java.util.UUID; + +@Entity +@Table(name = "sender_models") +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class SenderModel { + + @Id + @GeneratedValue(strategy = GenerationType.UUID) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private UUID id; + + @Column(name = "person_id", nullable = false, unique = true) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private UUID personId; + + @JsonIgnore + @Column(name = "model_path", nullable = false) + private String modelPath; + + @Column + private Double accuracy; + + @Column + private Double cer; + + @Column(name = "corrected_lines_at_training", nullable = false) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private int correctedLinesAtTraining; + + @CreationTimestamp + @Column(name = "created_at", nullable = false, updatable = false) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private Instant createdAt; + + @UpdateTimestamp + @Column(name = "updated_at", nullable = false) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private Instant updatedAt; +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java b/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java index 7e99dd2f..e91b9967 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java @@ -1,6 +1,7 @@ package org.raddatz.familienarchiv.model; public enum TrainingStatus { + QUEUED, RUNNING, DONE, FAILED diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java index fe7d61d1..b3878164 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java @@ -12,5 +12,11 @@ public interface OcrTrainingRunRepository extends JpaRepository findFirstByStatus(TrainingStatus status); - List findTop10ByOrderByCreatedAtDesc(); + Optional findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status); + + Optional findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status); + + boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status); + + List findTop20ByOrderByCreatedAtDesc(); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/SenderModelRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/SenderModelRepository.java new file mode 100644 index 00000000..0bfb05a5 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/SenderModelRepository.java @@ -0,0 +1,12 @@ +package org.raddatz.familienarchiv.repository; + +import org.raddatz.familienarchiv.model.SenderModel; +import org.springframework.data.jpa.repository.JpaRepository; + +import java.util.Optional; +import java.util.UUID; + +public interface SenderModelRepository extends JpaRepository { + + Optional findByPersonId(UUID personId); +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java index d091f950..c88830ad 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java @@ -3,6 +3,7 @@ package org.raddatz.familienarchiv.repository; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; import java.util.List; import java.util.Optional; @@ -37,4 +38,22 @@ public interface TranscriptionBlockRepository extends JpaRepository findSegmentationBlocks(); + + @Query(""" + SELECT COUNT(b) FROM TranscriptionBlock b + JOIN Document d ON d.id = b.documentId + WHERE b.source = 'MANUAL' + AND d.sender.id = :personId + AND d.scriptType = 'HANDWRITING_KURRENT' + """) + long countManualKurrentBlocksByPerson(@Param("personId") UUID personId); + + @Query(""" + SELECT b FROM TranscriptionBlock b + JOIN Document d ON d.id = b.documentId + WHERE b.source = 'MANUAL' + AND d.sender.id = :personId + AND d.scriptType = 'HANDWRITING_KURRENT' + """) + List findManualKurrentBlocksByPerson(@Param("personId") UUID personId); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java index 3ae853b0..08f38bda 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java @@ -9,10 +9,12 @@ import org.raddatz.familienarchiv.repository.OcrJobRepository; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; @Component @RequiredArgsConstructor @@ -29,6 +31,7 @@ public class OcrAsyncRunner { private final OcrJobRepository ocrJobRepository; private final OcrJobDocumentRepository ocrJobDocumentRepository; private final OcrProgressService ocrProgressService; + private final SenderModelService senderModelService; @Async public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) { @@ -68,12 +71,18 @@ public class OcrAsyncRunner { String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath()); + String senderModelPath = null; + if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) { + senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null); + } + AtomicInteger blockCounter = new AtomicInteger(0); AtomicInteger currentPage = new AtomicInteger(0); AtomicInteger skippedPages = new AtomicInteger(0); AtomicInteger totalPages = new AtomicInteger(0); - ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> { + final String finalSenderModelPath = senderModelPath; + ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, finalSenderModelPath, event -> { switch (event) { case OcrStreamEvent.Start start -> { totalPages.set(start.totalPages()); @@ -207,7 +216,25 @@ public class OcrAsyncRunner { clearExistingBlocks(documentId); String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath()); - List blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType()); + + String senderModelPath = null; + if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) { + senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null); + } + + final AtomicReference> blocksRef = new AtomicReference<>(); + final String finalSenderModelPath = senderModelPath; + ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), null, finalSenderModelPath, event -> { + switch (event) { + case OcrStreamEvent.Page page -> { + blocksRef.compareAndSet(null, new ArrayList<>()); + blocksRef.get().addAll(page.blocks()); + } + default -> {} + } + }); + + List blocks = blocksRef.get() != null ? blocksRef.get() : List.of(); createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash()); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java index 45c2f021..6af68718 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java @@ -1,6 +1,7 @@ package org.raddatz.familienarchiv.service; import org.raddatz.familienarchiv.model.ScriptType; +import org.springframework.lang.Nullable; import java.util.ArrayList; import java.util.LinkedHashMap; @@ -37,15 +38,27 @@ public interface OcrClient { TrainingResult segtrainModel(byte[] trainingDataZip); /** - * Stream OCR results page-by-page via NDJSON. Implementations should override - * this method. The default exists only for backward compatibility during migration - * — it calls extractBlocks() and synthesizes events from the collected result. + * Fine-tune the Kurrent model for a specific sender. * - * @param regions optional list of pre-drawn annotation regions; when non-null, - * the OCR service runs in guided mode (crop + recognize per region) + * @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService.exportForSender() + * @param outputModelPath where to save the trained model (e.g. /app/models/sender_{uuid}.mlmodel) + * @return training result metrics + */ + TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath); + + /** + * Stream OCR results page-by-page via NDJSON, optionally using a sender-specific model. + * The default implementation synthesizes events from extractBlocks() for backward compatibility. + * Implementations that support real streaming (e.g. RestClientOcrClient) override this. + * + * @param regions optional list of pre-drawn annotation regions; when non-null, + * the OCR service runs in guided mode (crop + recognize per region) + * @param senderModelPath optional path to a per-sender model file; null means use base model */ default void streamBlocks(String pdfUrl, ScriptType scriptType, - List regions, Consumer handler) { + List regions, + @Nullable String senderModelPath, + Consumer handler) { List allBlocks = extractBlocks(pdfUrl, scriptType); LinkedHashMap> byPage = new LinkedHashMap<>(); @@ -62,4 +75,9 @@ public interface OcrClient { handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0)); } + + default void streamBlocks(String pdfUrl, ScriptType scriptType, + List regions, Consumer handler) { + streamBlocks(pdfUrl, scriptType, regions, null, handler); + } } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index 65828999..20780985 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -17,9 +17,11 @@ import org.springframework.transaction.support.TransactionTemplate; import java.io.ByteArrayOutputStream; import java.time.Instant; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.UUID; @Service @@ -34,6 +36,7 @@ public class OcrTrainingService { private final OcrHealthClient ocrHealthClient; private final TranscriptionBlockRepository blockRepository; private final TransactionTemplate txTemplate; + private final PersonService personService; public record TrainingInfoResponse( int availableBlocks, @@ -42,7 +45,8 @@ public class OcrTrainingService { int availableSegBlocks, boolean ocrServiceAvailable, OcrTrainingRun lastRun, - List runs + List runs, + Map personNames ) {} private void assertNoRunningTraining() { @@ -195,9 +199,20 @@ public class OcrTrainingService { int totalOcrBlocks = (int) blockRepository.count(); int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size(); - List recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc(); + List recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc(); OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0); + List distinctPersonIds = recentRuns.stream() + .map(OcrTrainingRun::getPersonId) + .filter(Objects::nonNull) + .distinct() + .toList(); + Map personNames = new HashMap<>(); + if (!distinctPersonIds.isEmpty()) { + personService.getAllById(distinctPersonIds) + .forEach(p -> personNames.put(p.getId().toString(), p.getDisplayName())); + } + return new TrainingInfoResponse( eligibleBlocks.size(), totalOcrBlocks, @@ -205,7 +220,8 @@ public class OcrTrainingService { availableSegBlocks, ocrHealthClient.isHealthy(), lastRun, - recentRuns + recentRuns, + personNames ); } @@ -224,15 +240,4 @@ public class OcrTrainingService { }); } - public Map buildTrainingInfoMap(TrainingInfoResponse info) { - return Map.of( - "availableBlocks", info.availableBlocks(), - "totalOcrBlocks", info.totalOcrBlocks(), - "availableDocuments", info.availableDocuments(), - "availableSegBlocks", info.availableSegBlocks(), - "ocrServiceAvailable", info.ocrServiceAvailable(), - "lastRun", info.lastRun() != null ? info.lastRun() : Map.of(), - "runs", info.runs() - ); - } } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java index 9e58a2e1..2bda9892 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java @@ -14,6 +14,7 @@ import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.client.JdkClientHttpRequestFactory; +import org.springframework.lang.Nullable; import org.springframework.stereotype.Component; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -102,6 +103,13 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { .toList(); } + private RestClient.RequestBodySpec addTrainingAuth(RestClient.RequestBodySpec spec) { + if (trainingToken != null && !trainingToken.isBlank()) { + return spec.header("X-Training-Token", trainingToken); + } + return spec; + } + @Override public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) { ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) { @@ -114,15 +122,10 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { partHeaders.setContentType(MediaType.parseMediaType("application/zip")); body.add("file", new HttpEntity<>(zipResource, partHeaders)); - var spec = trainingRestClient.post() - .uri("/train") - .contentType(MediaType.MULTIPART_FORM_DATA); - - if (trainingToken != null && !trainingToken.isBlank()) { - spec = spec.header("X-Training-Token", trainingToken); - } - - TrainingResultJson result = spec + TrainingResultJson result = addTrainingAuth( + trainingRestClient.post() + .uri("/train") + .contentType(MediaType.MULTIPART_FORM_DATA)) .body(body) .retrieve() .body(TrainingResultJson.class); @@ -143,15 +146,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { partHeaders.setContentType(MediaType.parseMediaType("application/zip")); body.add("file", new HttpEntity<>(zipResource, partHeaders)); - var spec = trainingRestClient.post() - .uri("/segtrain") - .contentType(MediaType.MULTIPART_FORM_DATA); + TrainingResultJson result = addTrainingAuth( + trainingRestClient.post() + .uri("/segtrain") + .contentType(MediaType.MULTIPART_FORM_DATA)) + .body(body) + .retrieve() + .body(TrainingResultJson.class); - if (trainingToken != null && !trainingToken.isBlank()) { - spec = spec.header("X-Training-Token", trainingToken); - } + if (result == null) return new OcrClient.TrainingResult(null, null, null, null); + return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs()); + } - TrainingResultJson result = spec + @Override + public OcrClient.TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath) { + ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) { + @Override + public String getFilename() { return "sender-training-data.zip"; } + }; + + MultiValueMap body = new LinkedMultiValueMap<>(); + HttpHeaders partHeaders = new HttpHeaders(); + partHeaders.setContentType(MediaType.parseMediaType("application/zip")); + body.add("file", new HttpEntity<>(zipResource, partHeaders)); + body.add("output_model_path", outputModelPath); + + TrainingResultJson result = addTrainingAuth( + trainingRestClient.post() + .uri("/train-sender") + .contentType(MediaType.MULTIPART_FORM_DATA)) .body(body) .retrieve() .body(TrainingResultJson.class); @@ -176,7 +199,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { @Override public void streamBlocks(String pdfUrl, ScriptType scriptType, - List regions, Consumer handler) { + List regions, @Nullable String senderModelPath, + Consumer handler) { String body; try { var requestMap = new java.util.LinkedHashMap(); @@ -186,6 +210,9 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { if (regions != null && !regions.isEmpty()) { requestMap.put("regions", regions); } + if (senderModelPath != null) { + requestMap.put("senderModelPath", senderModelPath); + } body = NDJSON_MAPPER.writeValueAsString(requestMap); } catch (IOException e) { throw new RuntimeException("Failed to serialize OCR request", e); @@ -204,7 +231,12 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { if (response.statusCode() == 404) { log.info("OCR service does not support /ocr/stream (404), falling back to /ocr"); - OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler); + List allBlocks = extractBlocks(pdfUrl, scriptType); + handler.accept(new OcrStreamEvent.Start(0)); + for (OcrBlockResult block : allBlocks) { + handler.accept(new OcrStreamEvent.Page(block.pageNumber(), List.of(block))); + } + handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0)); return; } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java new file mode 100644 index 00000000..be9318a7 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -0,0 +1,197 @@ +package org.raddatz.familienarchiv.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.model.TrainingStatus; +import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; +import org.raddatz.familienarchiv.repository.SenderModelRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.slf4j.MDC; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.TransactionTemplate; + +import java.io.ByteArrayOutputStream; +import java.time.Instant; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; + +@Service +@RequiredArgsConstructor +@Slf4j +public class SenderModelService { + + private final SenderModelRepository senderModelRepository; + private final TranscriptionBlockRepository blockRepository; + private final OcrTrainingRunRepository trainingRunRepository; + private final OcrClient ocrClient; + private final TransactionTemplate txTemplate; + private final TrainingDataExportService trainingDataExportService; + + @Value("${ocr.sender-model.activation-threshold:100}") + private int activationThreshold; + + @Value("${ocr.sender-model.retrain-delta:50}") + private int retrainDelta; + + /** Returns the model path if a trained sender model exists for this person. */ + public Optional maybeGetModelPath(UUID personId) { + return senderModelRepository.findByPersonId(personId) + .map(SenderModel::getModelPath); + } + + /** + * Called after every MANUAL block save for HANDWRITING_KURRENT documents. + * Checks activation and retrain thresholds; enqueues or starts sender training when met. + */ + @Async + public void checkAndTriggerTraining(UUID personId) { + long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId); + Optional existing = senderModelRepository.findByPersonId(personId); + + boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold; + boolean shouldRetrain = existing.isPresent() + && (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta; + + if (!shouldActivate && !shouldRetrain) { + return; + } + + log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})", + personId, correctedLines, shouldActivate, shouldRetrain); + + boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines); + if (runNow) { + triggerSenderTraining(personId, (int) correctedLines); + } + } + + /** + * Atomically checks the queue state and either creates a RUNNING row (returns true) or a + * QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and + * RUNNING row creation — happen in one transaction, eliminating the race window that would + * otherwise exist between the check and a separate RUNNING row creation. + */ + @Transactional + public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { + if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) { + log.info("Sender training already queued for person {} — skipping duplicate trigger", personId); + return false; + } + + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId); + trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.QUEUED) + .personId(personId) + .blockCount(blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + log.info("Queued sender training for person {} — training already running", personId); + return false; + } + + long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); + trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .personId(personId) + .blockCount((int) blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + return true; + } + + /** + * Executes sender training synchronously. Caller must run this on a background thread. + * The RUNNING row is expected to already exist — created atomically by + * runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs). + */ + public void triggerSenderTraining(UUID personId, int correctedLines) { + String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; + + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> + trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING) + .orElseThrow(() -> new IllegalStateException( + "Expected RUNNING row for person " + personId + " but none found")))); + + String runId = run.getId().toString(); + MDC.put("trainingRunId", runId); + log.info("Started sender training run {} for person {}", runId, personId); + + try { + byte[] zipBytes = exportSenderData(personId); + log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length); + OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath); + + txTemplate.execute(status -> { + SenderModel model = senderModelRepository.findByPersonId(personId) + .orElseGet(() -> SenderModel.builder().personId(personId).build()); + model.setModelPath(outputModelPath); + model.setCer(result.cer()); + model.setAccuracy(result.accuracy()); + model.setCorrectedLinesAtTraining(correctedLines); + senderModelRepository.save(model); + + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setCer(result.cer()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + trainingRunRepository.save(run); + log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer()); + return null; + }); + } catch (Exception e) { + txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + trainingRunRepository.save(run); + log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e); + return null; + }); + } finally { + MDC.remove("trainingRunId"); + promoteNextQueuedRun(); + } + } + + private byte[] exportSenderData(UUID personId) throws java.io.IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + trainingDataExportService.exportForSender(personId).writeTo(baos); + return baos.toByteArray(); + } + + /** + * Promotes the oldest QUEUED sender run to RUNNING and triggers its training. + * Called in the finally block of triggerSenderTraining, creating a sequential chain: + * each run promotes the next only after it fully completes (success or failure). + * + * This is intentionally tail-recursive via the @Async thread: the same thread holds the + * full queue drain, serialising all sender training runs naturally without an external + * scheduler. With N queued runs the thread stays occupied for N sequential trainings — + * acceptable because the @Async executor is dedicated to long-running background work. + */ + private void promoteNextQueuedRun() { + Optional queuedOpt = txTemplate.execute(status -> + trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED) + .map(queued -> { + queued.setStatus(TrainingStatus.RUNNING); + return trainingRunRepository.save(queued); + })); + + if (queuedOpt != null && queuedOpt.isPresent()) { + OcrTrainingRun promoted = queuedOpt.get(); + log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId()); + long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId()); + triggerSenderTraining(promoted.getPersonId(), (int) freshCount); + } + } +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java index cf0b10e3..86c81053 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java @@ -38,10 +38,20 @@ public class TrainingDataExportService { } public StreamingResponseBody exportToZip() { - // Collect all data before entering the lambda — no open DB txn during streaming - List blocks = queryEligibleBlocks(); + return exportBlocksToZip(queryEligibleBlocks()); + } + + public List queryBlocksForSender(UUID personId) { + return blockRepository.findManualKurrentBlocksByPerson(personId); + } + + public StreamingResponseBody exportForSender(UUID personId) { + return exportBlocksToZip(queryBlocksForSender(personId)); + } + + private StreamingResponseBody exportBlocksToZip(List blocks) { if (blocks.isEmpty()) { - return out -> {}; // caller checks isEmpty() for 204 response + return out -> {}; } // Group blocks by documentId so we only download each PDF once diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java index bfb02253..081f588d 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java @@ -11,6 +11,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.DocumentAnnotation; +import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; import org.raddatz.familienarchiv.repository.AnnotationRepository; @@ -35,6 +36,7 @@ public class TranscriptionService { private final AnnotationRepository annotationRepository; private final AnnotationService annotationService; private final DocumentService documentService; + private final SenderModelService senderModelService; public List listBlocks(UUID documentId) { return blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId); @@ -122,6 +124,7 @@ public class TranscriptionService { String text = sanitizeText(dto.getText()); block.setText(text); + block.setSource(BlockSource.MANUAL); if (dto.getLabel() != null) { block.setLabel(dto.getLabel()); } @@ -129,6 +132,12 @@ public class TranscriptionService { TranscriptionBlock saved = blockRepository.save(block); saveVersion(saved, userId); + + Document doc = documentService.getDocumentById(documentId); + if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) { + senderModelService.checkAndTriggerTraining(doc.getSender().getId()); + } + return saved; } diff --git a/backend/src/main/resources/application.yaml b/backend/src/main/resources/application.yaml index aa4c8ceb..d9bbe9d0 100644 --- a/backend/src/main/resources/application.yaml +++ b/backend/src/main/resources/application.yaml @@ -77,3 +77,8 @@ app: tags: 10 summary: 11 transcription: 13 + +ocr: + sender-model: + activation-threshold: 100 + retrain-delta: 50 diff --git a/backend/src/main/resources/db/migration/V40__add_sender_models.sql b/backend/src/main/resources/db/migration/V40__add_sender_models.sql new file mode 100644 index 00000000..2386bbe2 --- /dev/null +++ b/backend/src/main/resources/db/migration/V40__add_sender_models.sql @@ -0,0 +1,10 @@ +CREATE TABLE sender_models ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + person_id UUID NOT NULL UNIQUE REFERENCES persons(id) ON DELETE CASCADE, + model_path TEXT NOT NULL, + accuracy DOUBLE PRECISION, + cer DOUBLE PRECISION, + corrected_lines_at_training INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/backend/src/main/resources/db/migration/V41__add_person_to_training_runs.sql b/backend/src/main/resources/db/migration/V41__add_person_to_training_runs.sql new file mode 100644 index 00000000..b95a70b6 --- /dev/null +++ b/backend/src/main/resources/db/migration/V41__add_person_to_training_runs.sql @@ -0,0 +1,2 @@ +ALTER TABLE ocr_training_runs + ADD COLUMN person_id UUID REFERENCES persons(id) ON DELETE SET NULL; diff --git a/backend/src/main/resources/db/migration/V42__add_queued_per_person_unique_index.sql b/backend/src/main/resources/db/migration/V42__add_queued_per_person_unique_index.sql new file mode 100644 index 00000000..5e4194bd --- /dev/null +++ b/backend/src/main/resources/db/migration/V42__add_queued_per_person_unique_index.sql @@ -0,0 +1,3 @@ +CREATE UNIQUE INDEX idx_training_runs_queued_per_person + ON ocr_training_runs(person_id) + WHERE status = 'QUEUED'; diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java index 94e67df1..e65e490c 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -23,6 +23,8 @@ import org.springframework.test.web.servlet.MockMvc; import java.util.List; import java.util.UUID; +import java.util.Map; + import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; @@ -219,7 +221,7 @@ class OcrControllerTest { @WithMockUser(authorities = "ADMIN") void getTrainingInfo_returns200_withInfo() throws Exception { OcrTrainingService.TrainingInfoResponse info = - new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of()); + new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of()); when(ocrTrainingService.getTrainingInfo()).thenReturn(info); mockMvc.perform(get("/api/ocr/training-info")) @@ -228,6 +230,52 @@ class OcrControllerTest { .andExpect(jsonPath("$.ocrServiceAvailable").value(true)); } + @Test + @WithMockUser(authorities = "ADMIN") + void getTrainingInfo_returns200_and_omits_personName_when_service_provides_empty_map() throws Exception { + UUID personId = UUID.randomUUID(); + OcrTrainingRun runWithPerson = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.DONE) + .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); + OcrTrainingService.TrainingInfoResponse info = + new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, + List.of(runWithPerson), Map.of()); + when(ocrTrainingService.getTrainingInfo()).thenReturn(info); + + mockMvc.perform(get("/api/ocr/training-info")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.personNames").isEmpty()); + } + + @Test + @WithMockUser(authorities = "ADMIN") + void getTrainingInfo_includesPersonName_whenPersonIdResolves() throws Exception { + UUID personId = UUID.randomUUID(); + OcrTrainingRun runWithPerson = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.DONE) + .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); + OcrTrainingService.TrainingInfoResponse info = + new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, + List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann")); + when(ocrTrainingService.getTrainingInfo()).thenReturn(info); + + mockMvc.perform(get("/api/ocr/training-info")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.personNames." + personId).value("Max Mustermann")); + } + + @Test + @WithMockUser(authorities = "ADMIN") + void getTrainingInfo_serializes_null_lastRun_as_json_null() throws Exception { + OcrTrainingService.TrainingInfoResponse info = + new OcrTrainingService.TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of()); + when(ocrTrainingService.getTrainingInfo()).thenReturn(info); + + mockMvc.perform(get("/api/ocr/training-info")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.lastRun").doesNotExist()); + } + @Test @WithMockUser(authorities = "READ_ALL") void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { diff --git a/backend/src/test/java/org/raddatz/familienarchiv/repository/MigrationIntegrationTest.java b/backend/src/test/java/org/raddatz/familienarchiv/repository/MigrationIntegrationTest.java index 78ff9861..da1ab7f4 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/repository/MigrationIntegrationTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/repository/MigrationIntegrationTest.java @@ -217,8 +217,71 @@ class MigrationIntegrationTest { assertThat(stored).isEqualTo("sage"); } + // ─── V42: idx_training_runs_queued_per_person partial unique index ──────── + + @Test + @Transactional(propagation = Propagation.NOT_SUPPORTED) + void v42_partialUniqueIndex_preventsTwoQueuedRunsForSamePerson() { + UUID personId = createPerson("V42Test", "SamePerson"); + jdbc.update(""" + INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id) + VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_x', ?) + """, personId); + + assertThatThrownBy(() -> + jdbc.update(""" + INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id) + VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_x', ?) + """, personId) + ).isInstanceOf(DataIntegrityViolationException.class); + + jdbc.update("DELETE FROM ocr_training_runs WHERE person_id = ?", personId); + jdbc.update("DELETE FROM persons WHERE id = ?", personId); + } + + @Test + void v42_partialUniqueIndex_allowsQueuedRunsForDifferentPersons() { + UUID person1 = createPerson("V42Test", "PersonA"); + UUID person2 = createPerson("V42Test", "PersonB"); + + int rows1 = jdbc.update(""" + INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id) + VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_a', ?) + """, person1); + int rows2 = jdbc.update(""" + INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id) + VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_b', ?) + """, person2); + + assertThat(rows1).isEqualTo(1); + assertThat(rows2).isEqualTo(1); + } + + @Test + void v42_partialUniqueIndex_allowsMultipleDoneRunsForSamePerson() { + UUID personId = createPerson("V42Test", "DonePerson"); + + int rows1 = jdbc.update(""" + INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id) + VALUES (gen_random_uuid(), 'DONE', 10, 2, 'sender_x', ?) + """, personId); + int rows2 = jdbc.update(""" + INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id) + VALUES (gen_random_uuid(), 'DONE', 15, 3, 'sender_x', ?) + """, personId); + + assertThat(rows1).isEqualTo(1); + assertThat(rows2).isEqualTo(1); + } + // ─── helpers ───────────────────────────────────────────────────────────── + private UUID createPerson(String firstName, String lastName) { + UUID id = UUID.randomUUID(); + jdbc.update("INSERT INTO persons (id, first_name, last_name) VALUES (?, ?, ?)", id, firstName, lastName); + return id; + } + private UUID createTag(String name) { UUID id = UUID.randomUUID(); jdbc.update("INSERT INTO tag (id, name) VALUES (?, ?)", id, name); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java index 75e9d3e4..ddafd4c9 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java @@ -31,6 +31,7 @@ class OcrAsyncRunnerTest { @Mock OcrJobRepository ocrJobRepository; @Mock OcrJobDocumentRepository ocrJobDocumentRepository; @Mock OcrProgressService ocrProgressService; + @Mock SenderModelService senderModelService; @InjectMocks OcrAsyncRunner ocrAsyncRunner; @@ -42,7 +43,12 @@ class OcrAsyncRunnerTest { .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build(); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of()); + doAnswer(inv -> { + Consumer handler = inv.getArgument(4); + handler.accept(new OcrStreamEvent.Start(0)); + handler.accept(new OcrStreamEvent.Done(0, 0)); + return null; + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.processDocument(docId, doc, userId); @@ -59,9 +65,15 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of( - new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null), - new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null))); + doAnswer(inv -> { + Consumer handler = inv.getArgument(4); + handler.accept(new OcrStreamEvent.Start(1)); + handler.accept(new OcrStreamEvent.Page(0, List.of( + new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null), + new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)))); + handler.accept(new OcrStreamEvent.Done(2, 0)); + return null; + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build(); when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); @@ -83,8 +95,14 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of( - new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null))); + doAnswer(inv -> { + Consumer handler = inv.getArgument(4); + handler.accept(new OcrStreamEvent.Start(1)); + handler.accept(new OcrStreamEvent.Page(0, List.of( + new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)))); + handler.accept(new OcrStreamEvent.Done(1, 0)); + return null; + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build(); when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); @@ -112,12 +130,12 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(1)); handler.accept(new OcrStreamEvent.Page(0, List.of())); handler.accept(new OcrStreamEvent.Done(0, 0)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -142,7 +160,7 @@ class OcrAsyncRunnerTest { when(documentService.getDocumentById(docId)).thenReturn(doc); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any()); + doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -174,7 +192,7 @@ class OcrAsyncRunnerTest { List progressMessages = new ArrayList<>(); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(3)); handler.accept(new OcrStreamEvent.Page(0, List.of( new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null), @@ -185,7 +203,7 @@ class OcrAsyncRunnerTest { progressMessages.add(job.getProgressMessage()); handler.accept(new OcrStreamEvent.Done(3, 0)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -215,14 +233,14 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(3)); handler.accept(new OcrStreamEvent.Page(0, List.of())); handler.accept(new OcrStreamEvent.Error(1, "failed")); handler.accept(new OcrStreamEvent.Page(2, List.of())); handler.accept(new OcrStreamEvent.Done(0, 1)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -251,14 +269,14 @@ class OcrAsyncRunnerTest { List progressMessages = new ArrayList<>(); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(5)); handler.accept(new OcrStreamEvent.Preprocessing(1)); progressMessages.add(job.getProgressMessage()); handler.accept(new OcrStreamEvent.Page(1, List.of())); handler.accept(new OcrStreamEvent.Done(0, 0)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -287,13 +305,13 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(2)); handler.accept(new OcrStreamEvent.Error(0, "some python traceback details")); handler.accept(new OcrStreamEvent.Page(1, List.of())); handler.accept(new OcrStreamEvent.Done(0, 1)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java index 7be27c40..f865f5d6 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java @@ -21,6 +21,7 @@ class OcrClientDefaultStreamTest { } @Override public TrainingResult trainModel(byte[] zip) { return null; } @Override public TrainingResult segtrainModel(byte[] zip) { return null; } + @Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; } }; List events = new ArrayList<>(); @@ -52,6 +53,7 @@ class OcrClientDefaultStreamTest { @Override public List extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); } @Override public TrainingResult trainModel(byte[] zip) { return null; } @Override public TrainingResult segtrainModel(byte[] zip) { return null; } + @Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; } }; List events = new ArrayList<>(); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index 62ef00f6..2e09eb69 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -4,11 +4,12 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.OcrTrainingRun; -import org.raddatz.familienarchiv.model.TrainingLabel; +import org.raddatz.familienarchiv.model.Person; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.raddatz.familienarchiv.service.PersonService; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionTemplate; @@ -32,6 +33,7 @@ class OcrTrainingServiceTest { OcrHealthClient healthClient; TranscriptionBlockRepository blockRepository; TransactionTemplate txTemplate; + PersonService personService; OcrTrainingService service; @BeforeEach @@ -43,6 +45,7 @@ class OcrTrainingServiceTest { healthClient = mock(OcrHealthClient.class); blockRepository = mock(TranscriptionBlockRepository.class); txTemplate = mock(TransactionTemplate.class); + personService = mock(PersonService.class); // Execute transaction callbacks inline so unit tests run without a real DataSource when(txTemplate.execute(any())).thenAnswer(inv -> { @@ -50,10 +53,10 @@ class OcrTrainingServiceTest { return callback.doInTransaction(null); }); - service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService); when(blockRepository.count()).thenReturn(0L); - when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of()); + when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); } @@ -230,6 +233,37 @@ class OcrTrainingServiceTest { run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); } + // ─── getTrainingInfo: batch person name resolution ──────────────────────── + + @Test + void getTrainingInfo_resolves_person_names_in_single_batch_call() { + UUID personA = UUID.randomUUID(); + UUID personB = UUID.randomUUID(); + List runs = List.of( + OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE) + .personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build(), + OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE) + .personId(personB).blockCount(5).documentCount(1).modelName("sender_b").build(), + OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE) + .personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build() + ); + when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(runs); + when(exportService.queryEligibleBlocks()).thenReturn(List.of()); + + Person pa = Person.builder().id(personA).firstName("Anna").lastName("Müller").build(); + Person pb = Person.builder().id(personB).firstName("Karl").lastName("Bauer").build(); + when(personService.getAllById(argThat(ids -> ids.size() == 2 && ids.containsAll(List.of(personA, personB))))) + .thenReturn(List.of(pa, pb)); + when(healthClient.isHealthy()).thenReturn(true); + + OcrTrainingService.TrainingInfoResponse info = service.getTrainingInfo(); + + verify(personService, never()).getById(any()); + verify(personService, times(1)).getAllById(any()); + assertThat(info.personNames()).containsKey(personA.toString()); + assertThat(info.personNames()).containsKey(personB.toString()); + } + // ─── Orphan recovery ────────────────────────────────────────────────────── @Test diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java new file mode 100644 index 00000000..17f084fa --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java @@ -0,0 +1,41 @@ +package org.raddatz.familienarchiv.service; + +import com.sun.net.httpserver.HttpServer; +import org.junit.jupiter.api.Test; + +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import static org.assertj.core.api.Assertions.assertThat; + +class RestClientOcrClientTrainingTest { + + @Test + void trainSenderModel_postsToTrainSenderEndpoint() throws Exception { + List capturedPaths = new CopyOnWriteArrayList<>(); + HttpServer server = HttpServer.create(new InetSocketAddress(0), 0); + server.createContext("/", exchange -> { + capturedPaths.add(exchange.getRequestURI().getPath()); + byte[] response = "{\"loss\":0.1,\"accuracy\":0.95,\"cer\":0.05,\"epochs\":3}" + .getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, response.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(response); + } + }); + server.start(); + + try { + int port = server.getAddress().getPort(); + RestClientOcrClient client = new RestClientOcrClient("http://localhost:" + port, ""); + client.trainSenderModel(new byte[]{1, 2, 3}, "/app/models/sender_test.mlmodel"); + assertThat(capturedPaths).contains("/train-sender"); + } finally { + server.stop(0); + } + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java new file mode 100644 index 00000000..57318758 --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -0,0 +1,302 @@ +package org.raddatz.familienarchiv.service; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.model.TrainingStatus; +import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; +import org.raddatz.familienarchiv.repository.SenderModelRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionTemplate; + +import java.util.Optional; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class SenderModelServiceTest { + + SenderModelRepository senderModelRepository; + TranscriptionBlockRepository blockRepository; + OcrTrainingRunRepository trainingRunRepository; + OcrClient ocrClient; + TransactionTemplate txTemplate; + TrainingDataExportService trainingDataExportService; + SenderModelService service; + + UUID personId = UUID.randomUUID(); + + @BeforeEach + void setUp() { + senderModelRepository = mock(SenderModelRepository.class); + blockRepository = mock(TranscriptionBlockRepository.class); + trainingRunRepository = mock(OcrTrainingRunRepository.class); + ocrClient = mock(OcrClient.class); + txTemplate = mock(TransactionTemplate.class); + trainingDataExportService = mock(TrainingDataExportService.class); + + // Execute transaction callbacks inline so unit tests run without a real DataSource. + // lenient: not every test hits the txTemplate path, but the setup is shared. + lenient().when(txTemplate.execute(any())).thenAnswer(inv -> { + TransactionCallback callback = inv.getArgument(0); + return callback.doInTransaction(null); + }); + + service = new SenderModelService(senderModelRepository, blockRepository, + trainingRunRepository, ocrClient, txTemplate, trainingDataExportService); + ReflectionTestUtils.setField(service, "activationThreshold", 100); + ReflectionTestUtils.setField(service, "retrainDelta", 50); + } + + // ─── Activation threshold ───────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_doesNothing_belowActivationThreshold() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + SenderModelService spy = spy(service); + spy.checkAndTriggerTraining(personId); + + verify(spy, never()).runOrQueueSenderTraining(any(), anyInt()); + } + + @Test + void checkAndTriggerTraining_triggersTraining_atActivationThreshold() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100); + + spy.checkAndTriggerTraining(personId); + + verify(spy).runOrQueueSenderTraining(personId, 100); + } + + // ─── Retrain delta ──────────────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_doesNothing_belowRetrainDelta() { + SenderModel existing = SenderModel.builder().personId(personId) + .correctedLinesAtTraining(100).build(); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L); + + SenderModelService spy = spy(service); + spy.checkAndTriggerTraining(personId); + + verify(spy, never()).runOrQueueSenderTraining(any(), anyInt()); + } + + @Test + void checkAndTriggerTraining_triggersRetrain_atRetrainDelta() { + SenderModel existing = SenderModel.builder().personId(personId) + .correctedLinesAtTraining(100).build(); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L); + + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150); + + spy.checkAndTriggerTraining(personId); + + verify(spy).runOrQueueSenderTraining(personId, 150); + } + + // ─── runNow flag ────────────────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_callsTrigger_whenRunNow() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + SenderModelService spy = spy(service); + doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100); + doNothing().when(spy).triggerSenderTraining(personId, 100); + + spy.checkAndTriggerTraining(personId); + + verify(spy).triggerSenderTraining(personId, 100); + } + + @Test + void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100); + + spy.checkAndTriggerTraining(personId); + + verify(spy, never()).triggerSenderTraining(any(), anyInt()); + } + + // ─── runOrQueueSenderTraining ───────────────────────────────────────────── + + @Test + void runOrQueueSenderTraining_returnsTrue_whenIdle() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED)); + } + + @Test + void runOrQueueSenderTraining_savesQueued_whenRunning() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn( + Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(1).modelName("german_kurrent").build())); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(trainingRunRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId()))); + } + + @Test + void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(trainingRunRepository, never()).save(any()); + } + + @Test + void runOrQueueSenderTraining_createsRunningRow_atomically() { + // RUNNING row must be created inside the same @Transactional boundary as the idle check, + // eliminating the race window between the check and a separate triggerSenderTraining call. + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> { + OcrTrainingRun r = inv.getArgument(0); + if (r.getId() == null) r.setId(UUID.randomUUID()); + return r; + }); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(trainingRunRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId()))); + } + + // ─── triggerSenderTraining — happy path ────────────────────────────────── + + @Test + void triggerSenderTraining_savesModelRecord_onSuccess() throws Exception { + OcrTrainingRun runningRun = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(10).documentCount(0) + .modelName("sender_" + personId).build(); + when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)) + .thenReturn(Optional.of(runningRun)); + when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {}); + when(ocrClient.trainSenderModel(any(), any())) + .thenReturn(new OcrClient.TrainingResult(0.08, 0.92, 0.08, 5)); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(senderModelRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)) + .thenReturn(Optional.empty()); + + service.triggerSenderTraining(personId, 120); + + verify(senderModelRepository).save(argThat(m -> + ("/app/models/sender_" + personId + ".mlmodel").equals(m.getModelPath()) + && Double.valueOf(0.08).equals(m.getCer()) + && personId.equals(m.getPersonId()))); + verify(trainingRunRepository, atLeastOnce()).save(argThat(r -> + r.getStatus() == TrainingStatus.DONE + && Double.valueOf(0.08).equals(r.getCer()) + && r.getCompletedAt() != null)); + } + + // ─── triggerSenderTraining — failure path ──────────────────────────────── + + @Test + void triggerSenderTraining_marksRunFailed_whenOcrClientThrows() throws Exception { + OcrTrainingRun runningRun = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(10).documentCount(0) + .modelName("sender_" + personId).build(); + when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)) + .thenReturn(Optional.of(runningRun)); + when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {}); + when(ocrClient.trainSenderModel(any(), any())) + .thenThrow(new RuntimeException("OCR service timeout")); + when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)) + .thenReturn(Optional.empty()); + + service.triggerSenderTraining(personId, 120); + + verify(trainingRunRepository, atLeastOnce()).save(argThat(r -> + r.getStatus() == TrainingStatus.FAILED + && r.getErrorMessage() != null + && r.getCompletedAt() != null)); + verify(senderModelRepository, never()).save(any()); + } + + // ─── triggerSenderTraining — queue promotion ────────────────────────────── + + @Test + void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception { + UUID nextPersonId = UUID.randomUUID(); + + OcrTrainingRun runningRun = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(10).documentCount(0) + .modelName("sender_" + personId).build(); + when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)) + .thenReturn(Optional.of(runningRun)); + when(trainingRunRepository.save(any())).thenAnswer(inv -> { + OcrTrainingRun r = inv.getArgument(0); + if (r.getId() == null) r.setId(UUID.randomUUID()); + return r; + }); + when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {}); + when(ocrClient.trainSenderModel(any(), any())) + .thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3)); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + OcrTrainingRun queued = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.QUEUED) + .personId(nextPersonId).blockCount(5).documentCount(0) + .modelName("sender_" + nextPersonId).build(); + when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)) + .thenReturn(Optional.of(queued)); + when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L); + + SenderModelService spy = spy(service); + // Stub the recursive call to stop the chain after one promotion + doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt()); + + spy.triggerSenderTraining(personId, 10); + + // Queued run promoted to RUNNING + verify(trainingRunRepository).save(argThat(r -> + nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING)); + // Next training was kicked off + verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt()); + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java index e8d53cad..59eed034 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java @@ -22,6 +22,7 @@ class TranscriptionServiceGuidedTest { AnnotationRepository annotationRepository; AnnotationService annotationService; DocumentService documentService; + SenderModelService senderModelService; TranscriptionService service; UUID docId = UUID.randomUUID(); @@ -35,9 +36,10 @@ class TranscriptionServiceGuidedTest { annotationRepository = mock(AnnotationRepository.class); annotationService = mock(AnnotationService.class); documentService = mock(DocumentService.class); + senderModelService = mock(SenderModelService.class); service = new TranscriptionService(blockRepository, versionRepository, - annotationRepository, annotationService, documentService); + annotationRepository, annotationService, documentService, senderModelService); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java index f8ca7753..7fd2aee7 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java @@ -13,6 +13,8 @@ import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.DocumentAnnotation; +import org.raddatz.familienarchiv.model.Person; +import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; import org.raddatz.familienarchiv.repository.AnnotationRepository; @@ -39,6 +41,7 @@ class TranscriptionServiceTest { @Mock AnnotationRepository annotationRepository; @Mock AnnotationService annotationService; @Mock DocumentService documentService; + @Mock SenderModelService senderModelService; @InjectMocks TranscriptionService transcriptionService; // ─── getBlock ──────────────────────────────────────────────────────────────── @@ -156,6 +159,8 @@ class TranscriptionServiceTest { .id(blockId).documentId(docId).text("old").build(); when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.TYPEWRITER).build()); UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("new text", null); @@ -175,6 +180,8 @@ class TranscriptionServiceTest { .id(blockId).documentId(docId).text("text").label("old label").build(); when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.TYPEWRITER).build()); UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("text", "Anrede"); @@ -183,6 +190,57 @@ class TranscriptionServiceTest { assertThat(result.getLabel()).isEqualTo("Anrede"); } + @Test + void updateBlock_setsSourceToManual() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").source(BlockSource.OCR).build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.TYPEWRITER).build()); + + TranscriptionBlock result = transcriptionService.updateBlock( + docId, blockId, new UpdateTranscriptionBlockDTO("new", null), UUID.randomUUID()); + + assertThat(result.getSource()).isEqualTo(BlockSource.MANUAL); + } + + @Test + void updateBlock_triggersTraining_whenKurrentSenderPresent() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + UUID senderId = UUID.randomUUID(); + Person sender = Person.builder().id(senderId).build(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).sender(sender).build()); + + transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID()); + + verify(senderModelService).checkAndTriggerTraining(senderId); + } + + @Test + void updateBlock_doesNotTriggerTraining_whenNoSender() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).build()); + + transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID()); + + verify(senderModelService, never()).checkAndTriggerTraining(any()); + } + // ─── deleteBlock ───────────────────────────────────────────────────────────── @Test diff --git a/docker-compose.yml b/docker-compose.yml index e9105e3a..5593a105 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -97,6 +97,7 @@ services: DETECTOR_BATCH_SIZE: "8" OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency) OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page) + OCR_MAX_CACHED_MODELS: "2" # LRU cache; each model ~500 MB, so 2 = ~1 GB resident networks: - archive-net healthcheck: diff --git a/frontend/messages/de.json b/frontend/messages/de.json index 72ae0b6b..3354ac57 100644 --- a/frontend/messages/de.json +++ b/frontend/messages/de.json @@ -330,6 +330,7 @@ "comment_time_days": "vor {count} Tag(en)", "comment_panel_title": "Kommentare", "comment_panel_close": "Schließen", + "comp_dismiss": "Schließen", "doc_panel_tab_metadata": "Metadaten", "doc_panel_tab_transcription": "Transkription", "doc_panel_tab_discussion": "Diskussion", @@ -541,6 +542,7 @@ "training_start_btn": "Training starten", "training_in_progress": "…", "training_success": "Training wurde gestartet und abgeschlossen.", + "training_start_failed": "Training konnte nicht gestartet werden. Bitte versuche es erneut.", "training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).", "training_service_down": "OCR-Dienst ist nicht erreichbar.", "training_history_heading": "Verlauf", @@ -561,6 +563,11 @@ "transcription_block_segmentation_only": "Nur Segmentierung", "training_chip_kurrent": "Kurrent-Erkennung", "training_chip_segmentation": "Segmentierung", + "training_col_type": "Typ", + "training_type_base": "Basis", + "training_type_personalized": "Personalisiert", + "training_col_person": "Absender", + "training_status_queued": "Warteschlange", "mission_control_heading": "Was braucht Aufmerksamkeit?", "mission_control_segmentation_heading": "Text markieren", "mission_control_segmentation_description": "Textbereiche markieren — keine Vorkenntnisse nötig", diff --git a/frontend/messages/en.json b/frontend/messages/en.json index c3711b3f..4c338b66 100644 --- a/frontend/messages/en.json +++ b/frontend/messages/en.json @@ -330,6 +330,7 @@ "comment_time_days": "{count} day(s) ago", "comment_panel_title": "Comments", "comment_panel_close": "Close", + "comp_dismiss": "Dismiss", "doc_panel_tab_metadata": "Metadata", "doc_panel_tab_transcription": "Transcription", "doc_panel_tab_discussion": "Discussion", @@ -541,6 +542,7 @@ "training_start_btn": "Start training", "training_in_progress": "…", "training_success": "Training started and completed.", + "training_start_failed": "Failed to start training. Please try again.", "training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).", "training_service_down": "OCR service is unavailable.", "training_history_heading": "History", @@ -561,6 +563,11 @@ "transcription_block_segmentation_only": "Segmentation only", "training_chip_kurrent": "Kurrent recognition", "training_chip_segmentation": "Segmentation", + "training_col_type": "Type", + "training_type_base": "Base", + "training_type_personalized": "Personalized", + "training_col_person": "Sender", + "training_status_queued": "Queued", "mission_control_heading": "What needs attention?", "mission_control_segmentation_heading": "Mark text", "mission_control_segmentation_description": "Mark text areas — no prior knowledge needed", diff --git a/frontend/messages/es.json b/frontend/messages/es.json index c87a2d25..529a7f64 100644 --- a/frontend/messages/es.json +++ b/frontend/messages/es.json @@ -330,6 +330,7 @@ "comment_time_days": "hace {count} día(s)", "comment_panel_title": "Comentarios", "comment_panel_close": "Cerrar", + "comp_dismiss": "Cerrar", "doc_panel_tab_metadata": "Metadatos", "doc_panel_tab_transcription": "Transcripción", "doc_panel_tab_discussion": "Discusión", @@ -541,6 +542,7 @@ "training_start_btn": "Iniciar entrenamiento", "training_in_progress": "…", "training_success": "Entrenamiento iniciado y completado.", + "training_start_failed": "No se pudo iniciar el entrenamiento. Por favor, inténtalo de nuevo.", "training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).", "training_service_down": "El servicio OCR no está disponible.", "training_history_heading": "Historial", @@ -561,6 +563,11 @@ "transcription_block_segmentation_only": "Solo segmentación", "training_chip_kurrent": "Reconocimiento Kurrent", "training_chip_segmentation": "Segmentación", + "training_col_type": "Tipo", + "training_type_base": "Base", + "training_type_personalized": "Personalizado", + "training_col_person": "Remitente", + "training_status_queued": "En cola", "mission_control_heading": "¿Qué necesita atención?", "mission_control_segmentation_heading": "Marcar texto", "mission_control_segmentation_description": "Marcar áreas de texto — sin conocimientos previos", diff --git a/frontend/src/lib/components/OcrTrainingCard.svelte b/frontend/src/lib/components/OcrTrainingCard.svelte index e34793af..77a62a94 100644 --- a/frontend/src/lib/components/OcrTrainingCard.svelte +++ b/frontend/src/lib/components/OcrTrainingCard.svelte @@ -1,26 +1,16 @@