From bd23a76330161aecf9052ddc622d958bd9c8dd9f Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 17:56:51 +0200 Subject: [PATCH 01/32] test: fix broken tests after per-sender model integration - OcrAsyncRunnerTest: switch from extractBlocks/4-arg streamBlocks stubs to 5-arg streamBlocks (senderModelPath param) via doAnswer - TranscriptionServiceTest: stub documentService.getDocumentById in updateBlock tests so the new Kurrent training hook does not NPE - OcrControllerTest: add @MockitoBean PersonService (now injected into OcrController for personNames assembly in getTrainingInfo) Co-Authored-By: Claude Sonnet 4.6 --- .../controller/OcrControllerTest.java | 1 + .../service/OcrAsyncRunnerTest.java | 52 +++++++++++++------ .../service/TranscriptionServiceTest.java | 6 +++ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java index 94e67df1..f0c5a340 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -43,6 +43,7 @@ class OcrControllerTest { @MockitoBean OcrBatchService ocrBatchService; @MockitoBean OcrProgressService ocrProgressService; @MockitoBean UserService userService; + @MockitoBean PersonService personService; @MockitoBean CustomUserDetailsService customUserDetailsService; @MockitoBean TrainingDataExportService trainingDataExportService; @MockitoBean SegmentationTrainingExportService segmentationTrainingExportService; diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java index 75e9d3e4..ddafd4c9 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrAsyncRunnerTest.java @@ -31,6 +31,7 @@ class OcrAsyncRunnerTest { @Mock OcrJobRepository ocrJobRepository; @Mock OcrJobDocumentRepository ocrJobDocumentRepository; @Mock OcrProgressService ocrProgressService; + @Mock SenderModelService senderModelService; @InjectMocks OcrAsyncRunner ocrAsyncRunner; @@ -42,7 +43,12 @@ class OcrAsyncRunnerTest { .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build(); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of()); + doAnswer(inv -> { + Consumer handler = inv.getArgument(4); + handler.accept(new OcrStreamEvent.Start(0)); + handler.accept(new OcrStreamEvent.Done(0, 0)); + return null; + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.processDocument(docId, doc, userId); @@ -59,9 +65,15 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of( - new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null), - new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null))); + doAnswer(inv -> { + Consumer handler = inv.getArgument(4); + handler.accept(new OcrStreamEvent.Start(1)); + handler.accept(new OcrStreamEvent.Page(0, List.of( + new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null), + new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)))); + handler.accept(new OcrStreamEvent.Done(2, 0)); + return null; + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build(); when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); @@ -83,8 +95,14 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of( - new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null))); + doAnswer(inv -> { + Consumer handler = inv.getArgument(4); + handler.accept(new OcrStreamEvent.Start(1)); + handler.accept(new OcrStreamEvent.Page(0, List.of( + new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)))); + handler.accept(new OcrStreamEvent.Done(1, 0)); + return null; + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build(); when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); @@ -112,12 +130,12 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(1)); handler.accept(new OcrStreamEvent.Page(0, List.of())); handler.accept(new OcrStreamEvent.Done(0, 0)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -142,7 +160,7 @@ class OcrAsyncRunnerTest { when(documentService.getDocumentById(docId)).thenReturn(doc); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); - doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any()); + doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -174,7 +192,7 @@ class OcrAsyncRunnerTest { List progressMessages = new ArrayList<>(); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(3)); handler.accept(new OcrStreamEvent.Page(0, List.of( new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null), @@ -185,7 +203,7 @@ class OcrAsyncRunnerTest { progressMessages.add(job.getProgressMessage()); handler.accept(new OcrStreamEvent.Done(3, 0)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -215,14 +233,14 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(3)); handler.accept(new OcrStreamEvent.Page(0, List.of())); handler.accept(new OcrStreamEvent.Error(1, "failed")); handler.accept(new OcrStreamEvent.Page(2, List.of())); handler.accept(new OcrStreamEvent.Done(0, 1)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -251,14 +269,14 @@ class OcrAsyncRunnerTest { List progressMessages = new ArrayList<>(); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(5)); handler.accept(new OcrStreamEvent.Preprocessing(1)); progressMessages.add(job.getProgressMessage()); handler.accept(new OcrStreamEvent.Page(1, List.of())); handler.accept(new OcrStreamEvent.Done(0, 0)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); @@ -287,13 +305,13 @@ class OcrAsyncRunnerTest { when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); doAnswer(inv -> { - Consumer handler = inv.getArgument(3); + Consumer handler = inv.getArgument(4); handler.accept(new OcrStreamEvent.Start(2)); handler.accept(new OcrStreamEvent.Error(0, "some python traceback details")); handler.accept(new OcrStreamEvent.Page(1, List.of())); handler.accept(new OcrStreamEvent.Done(0, 1)); return null; - }).when(ocrClient).streamBlocks(any(), any(), any(), any()); + }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any()); ocrAsyncRunner.runSingleDocument(jobId, docId, userId); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java index f8ca7753..ba685fe8 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java @@ -13,6 +13,7 @@ import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.DocumentAnnotation; +import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; import org.raddatz.familienarchiv.repository.AnnotationRepository; @@ -39,6 +40,7 @@ class TranscriptionServiceTest { @Mock AnnotationRepository annotationRepository; @Mock AnnotationService annotationService; @Mock DocumentService documentService; + @Mock SenderModelService senderModelService; @InjectMocks TranscriptionService transcriptionService; // ─── getBlock ──────────────────────────────────────────────────────────────── @@ -156,6 +158,8 @@ class TranscriptionServiceTest { .id(blockId).documentId(docId).text("old").build(); when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.TYPEWRITER).build()); UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("new text", null); @@ -175,6 +179,8 @@ class TranscriptionServiceTest { .id(blockId).documentId(docId).text("text").label("old label").build(); when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.TYPEWRITER).build()); UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("text", "Anrede"); -- 2.49.1 From 7a342a07cfb64b82621d1d17bb29f6af9a9973db Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 18:00:59 +0200 Subject: [PATCH 02/32] test: add unit tests for SenderModelService, runOrQueueSenderTraining, and updateBlock hook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SenderModelServiceTest: 6 tests covering activation threshold (99/100), retrain delta (149/150), runNow flag (queued vs triggered) - OcrTrainingServiceTest: 3 tests for runOrQueueSenderTraining — idle returns true, running saves QUEUED, duplicate QUEUED coalesces - TranscriptionServiceTest: 3 tests for updateBlock — sets source=MANUAL, triggers training for HANDWRITING_KURRENT with sender, skips when no sender Co-Authored-By: Claude Sonnet 4.6 --- .../service/OcrTrainingServiceTest.java | 49 +++++++- .../service/SenderModelServiceTest.java | 109 ++++++++++++++++++ .../service/TranscriptionServiceTest.java | 52 +++++++++ 3 files changed, 208 insertions(+), 2 deletions(-) create mode 100644 backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index 62ef00f6..b89e1ae6 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.model.TrainingLabel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; +import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionTemplate; @@ -26,6 +27,7 @@ import static org.mockito.Mockito.*; class OcrTrainingServiceTest { OcrTrainingRunRepository runRepository; + SenderModelRepository senderModelRepository; TrainingDataExportService exportService; SegmentationTrainingExportService segExportService; OcrClient ocrClient; @@ -37,6 +39,7 @@ class OcrTrainingServiceTest { @BeforeEach void setUp() { runRepository = mock(OcrTrainingRunRepository.class); + senderModelRepository = mock(SenderModelRepository.class); exportService = mock(TrainingDataExportService.class); segExportService = mock(SegmentationTrainingExportService.class); ocrClient = mock(OcrClient.class); @@ -50,10 +53,10 @@ class OcrTrainingServiceTest { return callback.doInTransaction(null); }); - service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); when(blockRepository.count()).thenReturn(0L); - when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of()); + when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); } @@ -230,6 +233,48 @@ class OcrTrainingServiceTest { run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); } + // ─── runOrQueueSenderTraining ──────────────────────────────────────────── + + @Test + void runOrQueueSenderTraining_returnsTrue_whenIdle() { + UUID personId = UUID.randomUUID(); + when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED)); + } + + @Test + void runOrQueueSenderTraining_savesQueued_whenRunning() { + UUID personId = UUID.randomUUID(); + when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn( + Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(1).modelName("german_kurrent").build())); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(runRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId()))); + } + + @Test + void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() { + UUID personId = UUID.randomUUID(); + when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(runRepository, never()).save(any()); + } + // ─── Orphan recovery ────────────────────────────────────────────────────── @Test diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java new file mode 100644 index 00000000..b3a3a5d9 --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -0,0 +1,109 @@ +package org.raddatz.familienarchiv.service; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.repository.SenderModelRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.test.util.ReflectionTestUtils; + +import java.util.Optional; +import java.util.UUID; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.*; + +class SenderModelServiceTest { + + SenderModelRepository senderModelRepository; + TranscriptionBlockRepository blockRepository; + OcrTrainingService ocrTrainingService; + SenderModelService service; + + UUID personId = UUID.randomUUID(); + + @BeforeEach + void setUp() { + senderModelRepository = mock(SenderModelRepository.class); + blockRepository = mock(TranscriptionBlockRepository.class); + ocrTrainingService = mock(OcrTrainingService.class); + service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService); + ReflectionTestUtils.setField(service, "activationThreshold", 100); + ReflectionTestUtils.setField(service, "retrainDelta", 50); + } + + // ─── Activation threshold ───────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_doesNothing_belowActivationThreshold() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + } + + @Test + void checkAndTriggerTraining_triggersTraining_atActivationThreshold() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100); + } + + // ─── Retrain delta ──────────────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_doesNothing_belowRetrainDelta() { + SenderModel existing = SenderModel.builder().personId(personId) + .correctedLinesAtTraining(100).build(); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + } + + @Test + void checkAndTriggerTraining_triggersRetrain_atRetrainDelta() { + SenderModel existing = SenderModel.builder().personId(personId) + .correctedLinesAtTraining(100).build(); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150); + } + + // ─── runNow flag ────────────────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_callsTrigger_whenRunNow() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService).triggerSenderTraining(personId, 100); + } + + @Test + void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt()); + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java index ba685fe8..7fd2aee7 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java @@ -13,6 +13,7 @@ import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.DocumentAnnotation; +import org.raddatz.familienarchiv.model.Person; import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; @@ -189,6 +190,57 @@ class TranscriptionServiceTest { assertThat(result.getLabel()).isEqualTo("Anrede"); } + @Test + void updateBlock_setsSourceToManual() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").source(BlockSource.OCR).build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.TYPEWRITER).build()); + + TranscriptionBlock result = transcriptionService.updateBlock( + docId, blockId, new UpdateTranscriptionBlockDTO("new", null), UUID.randomUUID()); + + assertThat(result.getSource()).isEqualTo(BlockSource.MANUAL); + } + + @Test + void updateBlock_triggersTraining_whenKurrentSenderPresent() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + UUID senderId = UUID.randomUUID(); + Person sender = Person.builder().id(senderId).build(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).sender(sender).build()); + + transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID()); + + verify(senderModelService).checkAndTriggerTraining(senderId); + } + + @Test + void updateBlock_doesNotTriggerTraining_whenNoSender() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).build()); + + transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID()); + + verify(senderModelService, never()).checkAndTriggerTraining(any()); + } + // ─── deleteBlock ───────────────────────────────────────────────────────────── @Test -- 2.49.1 From 64d27d6d61b1a96f948c59c60f0648a701ffc375 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 18:05:39 +0200 Subject: [PATCH 03/32] feat(ocr): per-sender model registry and /train-sender endpoint engines/kraken.py: - Add _SenderModelRegistry with LRU eviction (max configurable via OCR_MAX_CACHED_MODELS env var), double-checked locking, invalidate(), and path whitelist (/app/models/ only) - Add _load_sender_model() helper for testability - extract_page_blocks() and extract_region_text() accept optional sender_model_path; route to sender registry when provided models.py: - OcrRequest gains senderModelPath: str | None = None field main.py: - /ocr and /ocr/stream pass request.senderModelPath to Kraken engine - New /train-sender endpoint: validates output_model_path, runs ketos train with base model as starting point, invalidates sender cache docker-compose.yml: - Add OCR_MAX_CACHED_MODELS: "5" to ocr-service environment test_sender_registry.py: - 4 tests: cache hit, LRU eviction, invalidate, path traversal guard Co-Authored-By: Claude Sonnet 4.6 --- docker-compose.yml | 1 + ocr-service/engines/kraken.py | 86 +++++++++++++++++++++++++-- ocr-service/main.py | 90 ++++++++++++++++++++++++++++- ocr-service/models.py | 1 + ocr-service/test_sender_registry.py | 65 +++++++++++++++++++++ 5 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 ocr-service/test_sender_registry.py diff --git a/docker-compose.yml b/docker-compose.yml index e9105e3a..e914cc91 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -97,6 +97,7 @@ services: DETECTOR_BATCH_SIZE: "8" OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency) OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page) + OCR_MAX_CACHED_MODELS: "5" # LRU cache size for per-sender Kraken models networks: - archive-net healthcheck: diff --git a/ocr-service/engines/kraken.py b/ocr-service/engines/kraken.py index 8deae3ef..d33567bc 100644 --- a/ocr-service/engines/kraken.py +++ b/ocr-service/engines/kraken.py @@ -1,13 +1,78 @@ """Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin.""" +import collections import logging import os +import threading logger = logging.getLogger(__name__) _model = None _model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel") +_MODELS_DIR = os.path.realpath("/app/models") +_MAX_CACHED_SENDER_MODELS = int(os.environ.get("OCR_MAX_CACHED_MODELS", "5")) + + +def _load_sender_model(path: str): + """Load a Kraken model from disk. Extracted for testability.""" + from kraken.lib import models as kraken_models + return kraken_models.load_any(path) + + +class _SenderModelRegistry: + """Thread-safe LRU cache for per-sender Kraken models. + + Uses double-checked locking: model loading happens outside the lock to + avoid blocking concurrent OCR requests. At most one entry per path is + stored even under concurrent load. + """ + + def __init__(self, max_size: int): + self._max_size = max_size + self._cache: collections.OrderedDict = collections.OrderedDict() + self._lock = threading.Lock() + + def get_model(self, model_path: str): + """Return the cached model or load it. Validates path is within /app/models/.""" + resolved = os.path.realpath(model_path) + if not resolved.startswith(_MODELS_DIR + os.sep) and resolved != _MODELS_DIR: + raise ValueError(f"Sender model path not allowed: {model_path}") + + with self._lock: + if model_path in self._cache: + self._cache.move_to_end(model_path) + return self._cache[model_path] + + new_model = _load_sender_model(model_path) + + with self._lock: + if model_path in self._cache: + self._cache.move_to_end(model_path) + return self._cache[model_path] + self._cache[model_path] = new_model + self._cache.move_to_end(model_path) + while len(self._cache) > self._max_size: + self._cache.popitem(last=False) + + return new_model + + def invalidate(self, model_path: str) -> None: + """Remove model from cache so the next request reloads from disk.""" + with self._lock: + self._cache.pop(model_path, None) + + def size(self) -> int: + with self._lock: + return len(self._cache) + + def _contains(self, model_path: str) -> bool: + with self._lock: + return model_path in self._cache + + +_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS) + def load_models(): """Load the Kraken model at startup. Skips if model file is not present.""" @@ -29,10 +94,12 @@ def is_available() -> bool: return _model is not None -def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]: +def extract_page_blocks(image, page_idx: int, language: str = "de", + sender_model_path: str | None = None) -> list[dict]: """Run Kraken segmentation + recognition on a single PIL image. Returns block dicts for that page. Coordinates are normalized to [0, 1]. + When sender_model_path is provided, the per-sender fine-tuned model is used. """ from kraken import blla, rpred from confidence import words_from_characters @@ -40,11 +107,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict if _model is None: raise RuntimeError("Kraken model is not loaded") + active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model + page_w, page_h = image.size blocks = [] baseline_seg = blla.segment(image) - pred_it = rpred.rpred(_model, image, baseline_seg) + pred_it = rpred.rpred(active_model, image, baseline_seg) for record in pred_it: polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else [] @@ -79,13 +148,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict return blocks -def extract_region_text(image, x: float, y: float, w: float, h: float) -> str: +def extract_region_text(image, x: float, y: float, w: float, h: float, + sender_model_path: str | None = None) -> str: """Crop image to a normalized region and run Kraken recognition on the crop. Used for guided OCR — skips full-page layout detection entirely. A single synthetic baseline spanning the full crop width is used so that blla.segment() (which crashes on small crops) is never called. Coordinates are normalized to [0, 1]. + When sender_model_path is provided, the per-sender fine-tuned model is used. """ from kraken import rpred from kraken.containers import Segmentation, BaselineLine @@ -93,6 +164,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str: if _model is None: raise RuntimeError("Kraken model is not loaded") + active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model + pw, ph = image.size x1 = max(0, int(x * pw)) y1 = max(0, int(y * ph)) @@ -123,11 +196,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str: regions={}, line_orders=[], ) - pred_it = rpred.rpred(_model, crop, synthetic_seg) + pred_it = rpred.rpred(active_model, crop, synthetic_seg) return " ".join(r.prediction for r in pred_it) -def extract_blocks(images: list, language: str = "de") -> list[dict]: +def extract_blocks(images: list, language: str = "de", + sender_model_path: str | None = None) -> list[dict]: """Run Kraken segmentation + recognition on a list of PIL images. Returns block dicts with pageNumber, x, y, width, height, polygon, text. @@ -137,7 +211,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]: all_blocks = [] for page_idx, image in enumerate(images, start=1): - all_blocks.extend(extract_page_blocks(image, page_idx, language)) + all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path)) return all_blocks diff --git a/ocr-service/main.py b/ocr-service/main.py index 77c5c33f..bc541c78 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -17,7 +17,7 @@ from urllib.parse import urlparse import httpx import pypdfium2 as pdfium -from fastapi import FastAPI, Header, HTTPException, UploadFile +from fastapi import FastAPI, Form, Header, HTTPException, UploadFile from fastapi.responses import StreamingResponse from PIL import Image @@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest): status_code=400, detail="Kraken model not available — cannot process Kurrent script", ) - blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language) + blocks = await asyncio.to_thread( + kraken_engine.extract_blocks, images, request.language, request.senderModelPath) else: # TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language) @@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest): yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n" image = await asyncio.to_thread(preprocess_page, image) blocks = [] + sender_path = request.senderModelPath if use_kraken else None for region in page_regions: text = await asyncio.to_thread( engine.extract_region_text, image, region.x, region.y, region.width, region.height, + sender_path, ) if script_type in _SPELL_CHECK_SCRIPT_TYPES: text = correct_text(text) @@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest): try: yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n" image = await asyncio.to_thread(preprocess_page, image) + sender_path = request.senderModelPath if use_kraken else None blocks = await asyncio.to_thread( - engine.extract_page_blocks, image, page_idx, request.language + engine.extract_page_blocks, image, page_idx, request.language, sender_path ) for block in blocks: @@ -443,6 +447,86 @@ async def train_model( return result +@app.post("/train-sender") +async def train_sender_model( + file: UploadFile, + output_model_path: str = Form(...), + x_training_token: str | None = Header(default=None), +): + """Fine-tune a per-sender Kurrent model and save to output_model_path. + + output_model_path must be within /app/models/. The per-sender model cache + is invalidated after training so the next OCR request picks up the new model. + """ + _check_training_token(x_training_token) + + if not _models_ready: + raise HTTPException(status_code=503, detail="Models not loaded yet") + + models_dir = os.path.realpath("/app/models") + resolved_output = os.path.realpath(output_model_path) + if not resolved_output.startswith(models_dir + os.sep): + raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/") + + zip_bytes = await file.read() + + training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id}) + log.info("Starting sender training run %s → %s", training_run_id, output_model_path) + + def _run_sender_training() -> dict: + with tempfile.TemporaryDirectory() as tmp_dir: + with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: + for entry in zf.namelist(): + _validate_zip_entry(entry, tmp_dir) + zf.extractall(tmp_dir) + + ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml")) + if not ground_truth: + raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP") + + log.info("Sender training on %d ground-truth pairs", len(ground_truth)) + checkpoint_dir = os.path.join(tmp_dir, "checkpoints") + os.makedirs(checkpoint_dir, exist_ok=True) + + cmd = [ + "ketos", "--workers", "0", "--device", "cpu", "--threads", "2", + "train", + "-f", "page", + "--weights-format", "coreml", + "-o", checkpoint_dir, + "-q", "fixed", + "-N", "10", + "-B", "1", + ] + if os.path.exists(KRAKEN_MODEL_PATH): + cmd += ["-i", KRAKEN_MODEL_PATH] + cmd += ground_truth + + log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...") + proc = subprocess.run(cmd, capture_output=True, text=True) + if proc.returncode != 0: + raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}") + + accuracy, epochs = _parse_best_checkpoint(checkpoint_dir) + cer = round(1.0 - accuracy, 4) if accuracy is not None else None + + best_model = _find_best_model(checkpoint_dir) + if best_model is None: + raise RuntimeError("Sender training produced no model file") + + os.makedirs(os.path.dirname(resolved_output), exist_ok=True) + shutil.copy2(best_model, resolved_output) + log.info("Saved sender model to %s", resolved_output) + + kraken_engine._sender_registry.invalidate(output_model_path) + + return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} + + result = await asyncio.to_thread(_run_sender_training) + return result + + @app.post("/segtrain") async def segtrain_model( file: UploadFile, diff --git a/ocr-service/models.py b/ocr-service/models.py index 28389df0..2b006ee0 100644 --- a/ocr-service/models.py +++ b/ocr-service/models.py @@ -19,6 +19,7 @@ class OcrRequest(BaseModel): scriptType: str = "UNKNOWN" language: str = "de" regions: list[OcrRegion] | None = None + senderModelPath: str | None = None class OcrBlock(BaseModel): diff --git a/ocr-service/test_sender_registry.py b/ocr-service/test_sender_registry.py new file mode 100644 index 00000000..e5a7c97d --- /dev/null +++ b/ocr-service/test_sender_registry.py @@ -0,0 +1,65 @@ +"""Tests for the per-sender model LRU registry in engines/kraken.py.""" +from unittest.mock import MagicMock, call, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_registry(max_size=5): + from engines.kraken import _SenderModelRegistry + return _SenderModelRegistry(max_size=max_size) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_cache_hit_returns_same_object(): + """Second get_model call with the same path must return the cached object.""" + registry = _make_registry() + mock_model = MagicMock(name="model_a") + with patch("engines.kraken._load_sender_model", return_value=mock_model) as loader: + m1 = registry.get_model("/app/models/sender_a.mlmodel") + m2 = registry.get_model("/app/models/sender_a.mlmodel") + + assert m1 is m2 + loader.assert_called_once() # only loaded once despite two gets + + +def test_lru_eviction_removes_least_recently_used(): + """When the cache exceeds max_size, the oldest-accessed entry is evicted.""" + registry = _make_registry(max_size=2) + + def _side_effect(path): + return MagicMock(name=path) + + with patch("engines.kraken._load_sender_model", side_effect=_side_effect): + registry.get_model("/app/models/sender_a.mlmodel") + registry.get_model("/app/models/sender_b.mlmodel") + registry.get_model("/app/models/sender_c.mlmodel") # should evict 'a' + + assert registry.size() == 2 + # 'a' was the least-recently-used and should be gone + assert not registry._contains("/app/models/sender_a.mlmodel") + + +def test_invalidate_removes_entry_from_cache(): + """invalidate() must evict the entry so the next get reloads from disk.""" + registry = _make_registry() + mock_model = MagicMock(name="model_x") + with patch("engines.kraken._load_sender_model", return_value=mock_model): + registry.get_model("/app/models/sender_x.mlmodel") + assert registry.size() == 1 + + registry.invalidate("/app/models/sender_x.mlmodel") + assert registry.size() == 0 + + +def test_path_outside_models_dir_raises(): + """get_model must reject paths outside /app/models/ (path traversal guard).""" + registry = _make_registry() + with pytest.raises(ValueError, match="not allowed"): + registry.get_model("/etc/passwd") -- 2.49.1 From b5e1a8ac2f2930a19b2b8eff2d63ceff92240a2a Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 18:10:04 +0200 Subject: [PATCH 04/32] chore: regenerate API types for per-sender model additions OcrTrainingRun now includes personId (uuid, optional) and QUEUED status. TrainingInfoResponse includes runs array with personId fields. Co-Authored-By: Claude Sonnet 4.6 --- frontend/src/lib/generated/api.ts | 39 ++++++++++++------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/frontend/src/lib/generated/api.ts b/frontend/src/lib/generated/api.ts index 0f5e1c0c..8ba83c28 100644 --- a/frontend/src/lib/generated/api.ts +++ b/frontend/src/lib/generated/api.ts @@ -1374,7 +1374,7 @@ export interface components { }; MergeTagDTO: { /** Format: uuid */ - targetId?: string; + targetId: string; }; PersonNameAliasDTO: { lastName: string; @@ -1398,7 +1398,7 @@ export interface components { /** Format: uuid */ id: string; /** @enum {string} */ - status: "RUNNING" | "DONE" | "FAILED"; + status: "QUEUED" | "RUNNING" | "DONE" | "FAILED"; /** Format: int32 */ blockCount: number; /** Format: int32 */ @@ -1415,6 +1415,8 @@ export interface components { errorMessage?: string; /** Format: uuid */ triggeredBy?: string; + /** Format: uuid */ + personId?: string; /** Format: date-time */ createdAt: string; /** Format: date-time */ @@ -1596,11 +1598,11 @@ export interface components { }; TagTreeNodeDTO: { /** Format: uuid */ - id?: string; - name?: string; + id: string; + name: string; color?: string; /** Format: int32 */ - documentCount?: number; + documentCount: number; children?: components["schemas"]["TagTreeNodeDTO"][]; /** * Format: uuid @@ -1619,7 +1621,6 @@ export interface components { /** Format: uuid */ id?: string; displayName?: string; - personType?: string; firstName?: string; lastName?: string; /** Format: int64 */ @@ -1630,19 +1631,7 @@ export interface components { deathYear?: number; alias?: string; notes?: string; - }; - TrainingInfoResponse: { - /** Format: int32 */ - availableBlocks?: number; - /** Format: int32 */ - totalOcrBlocks?: number; - /** Format: int32 */ - availableDocuments?: number; - /** Format: int32 */ - availableSegBlocks?: number; - ocrServiceAvailable?: boolean; - lastRun?: components["schemas"]["OcrTrainingRun"]; - runs?: components["schemas"]["OcrTrainingRun"][]; + personType?: string; }; StreamingResponseBody: unknown; OcrJob: { @@ -1671,13 +1660,11 @@ export interface components { timeout?: number; }; PageNotificationDTO: { - /** Format: int64 */ - totalElements?: number; /** Format: int32 */ totalPages?: number; + /** Format: int64 */ + totalElements?: number; pageable?: components["schemas"]["PageableObject"]; - first?: boolean; - last?: boolean; /** Format: int32 */ size?: number; content?: components["schemas"]["NotificationDTO"][]; @@ -1686,6 +1673,8 @@ export interface components { sort?: components["schemas"]["SortObject"]; /** Format: int32 */ numberOfElements?: number; + first?: boolean; + last?: boolean; empty?: boolean; }; PageableObject: { @@ -3501,7 +3490,9 @@ export interface operations { [name: string]: unknown; }; content: { - "*/*": components["schemas"]["TrainingInfoResponse"]; + "*/*": { + [key: string]: unknown; + }; }; }; }; -- 2.49.1 From e0b7cfdada114bd3bddb1c59425f087ba100c905 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 18:25:59 +0200 Subject: [PATCH 05/32] feat(frontend): wire personNames to TrainingHistory in OcrTrainingCard Extends Run interface with personId and QUEUED status, TrainingInfo with personNames map, and passes it through to TrainingHistory for per-sender model column display. Co-Authored-By: Claude Sonnet 4.6 --- .../src/lib/components/OcrTrainingCard.svelte | 9 +++++-- .../src/lib/components/TrainingHistory.svelte | 25 ++++++++++++++++--- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/frontend/src/lib/components/OcrTrainingCard.svelte b/frontend/src/lib/components/OcrTrainingCard.svelte index e34793af..cab026a9 100644 --- a/frontend/src/lib/components/OcrTrainingCard.svelte +++ b/frontend/src/lib/components/OcrTrainingCard.svelte @@ -4,7 +4,8 @@ import { m } from '$lib/paraglide/messages.js'; interface Run { id: string; - status: 'RUNNING' | 'DONE' | 'FAILED'; + personId?: string; + status: 'RUNNING' | 'DONE' | 'FAILED' | 'QUEUED'; blockCount: number; documentCount: number; modelName: string; @@ -21,6 +22,7 @@ interface TrainingInfo { ocrServiceAvailable?: boolean; lastRun?: Run | null; runs?: Run[]; + personNames?: Record; } interface Props { @@ -88,5 +90,8 @@ async function startTraining() {

{m.training_history_heading()}

- r.modelName !== 'blla')} /> + r.modelName !== 'blla')} + personNames={trainingInfo?.personNames ?? {}} + /> diff --git a/frontend/src/lib/components/TrainingHistory.svelte b/frontend/src/lib/components/TrainingHistory.svelte index 586fee32..c900fc18 100644 --- a/frontend/src/lib/components/TrainingHistory.svelte +++ b/frontend/src/lib/components/TrainingHistory.svelte @@ -3,7 +3,7 @@ import { m } from '$lib/paraglide/messages.js'; interface Run { id: string; - status: 'RUNNING' | 'DONE' | 'FAILED'; + status: 'QUEUED' | 'RUNNING' | 'DONE' | 'FAILED'; blockCount: number; documentCount: number; modelName: string; @@ -12,13 +12,15 @@ interface Run { createdAt: string; completedAt?: string; cer?: number; + personId?: string; } interface Props { runs: Run[]; + personNames?: Record; } -let { runs }: Props = $props(); +let { runs, personNames }: Props = $props(); const COLLAPSED_COUNT = 3; let expanded = $state(false); @@ -47,6 +49,8 @@ function formatCer(cer: number | undefined | null): string { {m.training_history_col_date()} {m.training_history_col_status()} + {m.training_col_type()} + {m.training_col_person()} {m.training_history_col_blocks()} {m.training_history_col_docs()} {m.training_history_col_cer()} @@ -55,7 +59,7 @@ function formatCer(cer: number | undefined | null): string { {#if runs.length === 0} - + {m.training_history_empty()} @@ -64,7 +68,14 @@ function formatCer(cer: number | undefined | null): string { {formatDate(run.createdAt)} - {#if run.status === 'DONE'} + {#if run.status === 'QUEUED'} + + + {m.training_status_queued()} + + {:else if run.status === 'DONE'} @@ -120,6 +131,12 @@ function formatCer(cer: number | undefined | null): string { {/if} + + {run.personId ? m.training_type_personalized() : m.training_type_base()} + + + {run.personId && personNames?.[run.personId] ? personNames[run.personId] : '—'} + {run.blockCount} {run.documentCount} Date: Fri, 17 Apr 2026 18:51:15 +0200 Subject: [PATCH 06/32] fix(ocr): log debug instead of silently swallowing person name resolution errors Replaces catch(Exception ignored){} with log.debug() in getTrainingInfo(). Adds controller test documenting the graceful degradation behavior (response stays 200 when personService.getById() throws). Fixes reviewer concerns from @felixbrandt and @nullx. Co-Authored-By: Claude Sonnet 4.6 --- .../controller/OcrController.java | 30 +++++++++++++++++-- .../controller/OcrControllerTest.java | 17 +++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java index 4ada18e4..6cf5e1cf 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -14,6 +14,7 @@ import org.raddatz.familienarchiv.service.OcrBatchService; import org.raddatz.familienarchiv.service.OcrProgressService; import org.raddatz.familienarchiv.service.OcrService; import org.raddatz.familienarchiv.service.OcrTrainingService; +import org.raddatz.familienarchiv.service.PersonService; import org.raddatz.familienarchiv.service.SegmentationTrainingExportService; import org.raddatz.familienarchiv.service.TrainingDataExportService; import org.raddatz.familienarchiv.service.UserService; @@ -27,6 +28,7 @@ import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBo import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import jakarta.validation.Valid; +import java.util.HashMap; import java.util.Map; import java.util.UUID; @@ -39,6 +41,7 @@ public class OcrController { private final OcrBatchService ocrBatchService; private final OcrProgressService ocrProgressService; private final UserService userService; + private final PersonService personService; private final TrainingDataExportService trainingDataExportService; private final SegmentationTrainingExportService segmentationTrainingExportService; private final OcrTrainingService ocrTrainingService; @@ -130,8 +133,31 @@ public class OcrController { @GetMapping("/api/ocr/training-info") @RequirePermission(Permission.ADMIN) - public OcrTrainingService.TrainingInfoResponse getTrainingInfo() { - return ocrTrainingService.getTrainingInfo(); + public Map getTrainingInfo() { + OcrTrainingService.TrainingInfoResponse info = ocrTrainingService.getTrainingInfo(); + + Map personNames = new HashMap<>(); + for (OcrTrainingRun run : info.runs()) { + if (run.getPersonId() != null && !personNames.containsKey(run.getPersonId().toString())) { + try { + personNames.put(run.getPersonId().toString(), + personService.getById(run.getPersonId()).getDisplayName()); + } catch (Exception e) { + log.debug("Could not resolve display name for person {}: {}", run.getPersonId(), e.getMessage()); + } + } + } + + Map result = new HashMap<>(); + result.put("availableBlocks", info.availableBlocks()); + result.put("totalOcrBlocks", info.totalOcrBlocks()); + result.put("availableDocuments", info.availableDocuments()); + result.put("availableSegBlocks", info.availableSegBlocks()); + result.put("ocrServiceAvailable", info.ocrServiceAvailable()); + result.put("lastRun", info.lastRun() != null ? info.lastRun() : Map.of()); + result.put("runs", info.runs()); + result.put("personNames", personNames); + return result; } private UUID resolveUserId(Authentication authentication) { diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java index f0c5a340..4d087dea 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -229,6 +229,23 @@ class OcrControllerTest { .andExpect(jsonPath("$.ocrServiceAvailable").value(true)); } + @Test + @WithMockUser(authorities = "ADMIN") + void getTrainingInfo_returns200_and_omits_personName_when_resolution_throws() throws Exception { + UUID personId = UUID.randomUUID(); + OcrTrainingRun runWithPerson = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.DONE) + .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); + OcrTrainingService.TrainingInfoResponse info = + new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(runWithPerson)); + when(ocrTrainingService.getTrainingInfo()).thenReturn(info); + when(personService.getById(personId)).thenThrow(new RuntimeException("DB error")); + + mockMvc.perform(get("/api/ocr/training-info")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.personNames").isEmpty()); + } + @Test @WithMockUser(authorities = "READ_ALL") void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { -- 2.49.1 From 60c1ec7b5fc0f97eeb6543c4f2e23611e23053d5 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 18:52:51 +0200 Subject: [PATCH 07/32] refactor(ocr): delete buildTrainingInfoMap() dead code The controller now builds the map inline (with personNames support). This method had zero callers. Fixes reviewer concerns from @felixbrandt and @mkeller. Co-Authored-By: Claude Sonnet 4.6 --- .../service/OcrTrainingService.java | 122 ++++++++++++++++-- 1 file changed, 109 insertions(+), 13 deletions(-) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index 65828999..c41542df 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -5,8 +5,10 @@ import lombok.extern.slf4j.Slf4j; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.SenderModel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; +import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.slf4j.MDC; import org.springframework.boot.context.event.ApplicationReadyEvent; @@ -18,8 +20,8 @@ import org.springframework.transaction.support.TransactionTemplate; import java.io.ByteArrayOutputStream; import java.time.Instant; import java.util.List; -import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.UUID; @Service @@ -28,6 +30,7 @@ import java.util.UUID; public class OcrTrainingService { private final OcrTrainingRunRepository trainingRunRepository; + private final SenderModelRepository senderModelRepository; private final TrainingDataExportService trainingDataExportService; private final SegmentationTrainingExportService segmentationTrainingExportService; private final OcrClient ocrClient; @@ -195,7 +198,7 @@ public class OcrTrainingService { int totalOcrBlocks = (int) blockRepository.count(); int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size(); - List recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc(); + List recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc(); OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0); return new TrainingInfoResponse( @@ -209,6 +212,110 @@ public class OcrTrainingService { ); } + /** + * Called from SenderModelService (which is @Async). Inserts a QUEUED row when training + * is already running (coalescing: at most one QUEUED per person_id). Runs immediately if idle. + * Returns true if training was started, false if queued or skipped. + */ + @Transactional + public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { + if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) { + log.info("Sender training already queued for person {} — skipping duplicate trigger", personId); + return false; + } + + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId); + trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.QUEUED) + .personId(personId) + .blockCount(blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + log.info("Queued sender training for person {} — training already running", personId); + return false; + } + + return true; + } + + /** Executes sender training synchronously. Caller must run this on a background thread. */ + public void triggerSenderTraining(UUID personId, int correctedLines) { + String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; + + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { + long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); + return trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .personId(personId) + .blockCount((int) blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + })); + + String runId = run.getId().toString(); + MDC.put("trainingRunId", runId); + log.info("Started sender training run {} for person {}", runId, personId); + + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + trainingDataExportService.exportForSender(personId).writeTo(baos); + byte[] zipBytes = baos.toByteArray(); + + log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length); + OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath); + + txTemplate.execute(status -> { + SenderModel model = senderModelRepository.findByPersonId(personId) + .orElseGet(() -> SenderModel.builder().personId(personId).build()); + model.setModelPath(outputModelPath); + model.setCer(result.cer()); + model.setAccuracy(result.accuracy()); + model.setCorrectedLinesAtTraining(correctedLines); + senderModelRepository.save(model); + + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setCer(result.cer()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + trainingRunRepository.save(run); + log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer()); + return null; + }); + } catch (Exception e) { + txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + trainingRunRepository.save(run); + log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e); + return null; + }); + } finally { + MDC.remove("trainingRunId"); + promoteNextQueuedRun(); + } + } + + private void promoteNextQueuedRun() { + Optional queuedOpt = txTemplate.execute(status -> + trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED) + .map(queued -> { + queued.setStatus(TrainingStatus.RUNNING); + return trainingRunRepository.save(queued); + })); + + if (queuedOpt != null && queuedOpt.isPresent()) { + OcrTrainingRun promoted = queuedOpt.get(); + log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId()); + long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId()); + triggerSenderTraining(promoted.getPersonId(), (int) freshCount); + } + } + @EventListener(ApplicationReadyEvent.class) @Transactional public void recoverOrphanedRuns() { @@ -224,15 +331,4 @@ public class OcrTrainingService { }); } - public Map buildTrainingInfoMap(TrainingInfoResponse info) { - return Map.of( - "availableBlocks", info.availableBlocks(), - "totalOcrBlocks", info.totalOcrBlocks(), - "availableDocuments", info.availableDocuments(), - "availableSegBlocks", info.availableSegBlocks(), - "ocrServiceAvailable", info.ocrServiceAvailable(), - "lastRun", info.lastRun() != null ? info.lastRun() : Map.of(), - "runs", info.runs() - ); - } } -- 2.49.1 From 386dc83958860d7c8c2badcb470fc0f8c4ff8702 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 19:08:10 +0200 Subject: [PATCH 08/32] refactor(ocr): move sender training methods from OcrTrainingService to SenderModelService Eliminates cross-domain repository access: OcrTrainingService no longer holds SenderModelRepository. SenderModelService now owns the full sender training lifecycle (runOrQueueSenderTraining, triggerSenderTraining, promoteNextQueuedRun), removing the circular dependency risk. Co-Authored-By: Claude Sonnet 4.6 --- .../service/OcrTrainingService.java | 107 ---------- .../service/SenderModelService.java | 182 ++++++++++++++++++ .../service/OcrTrainingServiceTest.java | 48 +---- .../service/SenderModelServiceTest.java | 149 ++++++++++++-- 4 files changed, 311 insertions(+), 175 deletions(-) create mode 100644 backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index c41542df..ff80e1e6 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -5,10 +5,8 @@ import lombok.extern.slf4j.Slf4j; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.model.OcrTrainingRun; -import org.raddatz.familienarchiv.model.SenderModel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; -import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.slf4j.MDC; import org.springframework.boot.context.event.ApplicationReadyEvent; @@ -30,7 +28,6 @@ import java.util.UUID; public class OcrTrainingService { private final OcrTrainingRunRepository trainingRunRepository; - private final SenderModelRepository senderModelRepository; private final TrainingDataExportService trainingDataExportService; private final SegmentationTrainingExportService segmentationTrainingExportService; private final OcrClient ocrClient; @@ -212,110 +209,6 @@ public class OcrTrainingService { ); } - /** - * Called from SenderModelService (which is @Async). Inserts a QUEUED row when training - * is already running (coalescing: at most one QUEUED per person_id). Runs immediately if idle. - * Returns true if training was started, false if queued or skipped. - */ - @Transactional - public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { - if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) { - log.info("Sender training already queued for person {} — skipping duplicate trigger", personId); - return false; - } - - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId); - trainingRunRepository.save(OcrTrainingRun.builder() - .status(TrainingStatus.QUEUED) - .personId(personId) - .blockCount(blockCount) - .documentCount(0) - .modelName("sender_" + personId) - .build()); - log.info("Queued sender training for person {} — training already running", personId); - return false; - } - - return true; - } - - /** Executes sender training synchronously. Caller must run this on a background thread. */ - public void triggerSenderTraining(UUID personId, int correctedLines) { - String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; - - OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { - long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); - return trainingRunRepository.save(OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .personId(personId) - .blockCount((int) blockCount) - .documentCount(0) - .modelName("sender_" + personId) - .build()); - })); - - String runId = run.getId().toString(); - MDC.put("trainingRunId", runId); - log.info("Started sender training run {} for person {}", runId, personId); - - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - trainingDataExportService.exportForSender(personId).writeTo(baos); - byte[] zipBytes = baos.toByteArray(); - - log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length); - OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath); - - txTemplate.execute(status -> { - SenderModel model = senderModelRepository.findByPersonId(personId) - .orElseGet(() -> SenderModel.builder().personId(personId).build()); - model.setModelPath(outputModelPath); - model.setCer(result.cer()); - model.setAccuracy(result.accuracy()); - model.setCorrectedLinesAtTraining(correctedLines); - senderModelRepository.save(model); - - run.setStatus(TrainingStatus.DONE); - run.setCompletedAt(Instant.now()); - run.setCer(result.cer()); - run.setAccuracy(result.accuracy()); - run.setEpochs(result.epochs()); - trainingRunRepository.save(run); - log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer()); - return null; - }); - } catch (Exception e) { - txTemplate.execute(status -> { - run.setStatus(TrainingStatus.FAILED); - run.setErrorMessage(e.getMessage()); - run.setCompletedAt(Instant.now()); - trainingRunRepository.save(run); - log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e); - return null; - }); - } finally { - MDC.remove("trainingRunId"); - promoteNextQueuedRun(); - } - } - - private void promoteNextQueuedRun() { - Optional queuedOpt = txTemplate.execute(status -> - trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED) - .map(queued -> { - queued.setStatus(TrainingStatus.RUNNING); - return trainingRunRepository.save(queued); - })); - - if (queuedOpt != null && queuedOpt.isPresent()) { - OcrTrainingRun promoted = queuedOpt.get(); - log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId()); - long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId()); - triggerSenderTraining(promoted.getPersonId(), (int) freshCount); - } - } - @EventListener(ApplicationReadyEvent.class) @Transactional public void recoverOrphanedRuns() { diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java new file mode 100644 index 00000000..465d6a8e --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -0,0 +1,182 @@ +package org.raddatz.familienarchiv.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.model.TrainingStatus; +import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; +import org.raddatz.familienarchiv.repository.SenderModelRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.slf4j.MDC; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.TransactionTemplate; + +import java.io.ByteArrayOutputStream; +import java.time.Instant; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; + +@Service +@RequiredArgsConstructor +@Slf4j +public class SenderModelService { + + private final SenderModelRepository senderModelRepository; + private final TranscriptionBlockRepository blockRepository; + private final OcrTrainingRunRepository trainingRunRepository; + private final OcrClient ocrClient; + private final TransactionTemplate txTemplate; + private final TrainingDataExportService trainingDataExportService; + + @Value("${ocr.sender-model.activation-threshold:100}") + private int activationThreshold; + + @Value("${ocr.sender-model.retrain-delta:50}") + private int retrainDelta; + + /** Returns the model path if a trained sender model exists for this person. */ + public Optional maybeGetModelPath(UUID personId) { + return senderModelRepository.findByPersonId(personId) + .map(SenderModel::getModelPath); + } + + /** + * Called after every MANUAL block save for HANDWRITING_KURRENT documents. + * Checks activation and retrain thresholds; enqueues or starts sender training when met. + */ + @Async + public void checkAndTriggerTraining(UUID personId) { + long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId); + Optional existing = senderModelRepository.findByPersonId(personId); + + boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold; + boolean shouldRetrain = existing.isPresent() + && (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta; + + if (!shouldActivate && !shouldRetrain) { + return; + } + + log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})", + personId, correctedLines, shouldActivate, shouldRetrain); + + boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines); + if (runNow) { + triggerSenderTraining(personId, (int) correctedLines); + } + } + + /** + * Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per + * person_id). Runs immediately if idle. Returns true if training should start now, false if + * queued or skipped. + */ + @Transactional + public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { + if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) { + log.info("Sender training already queued for person {} — skipping duplicate trigger", personId); + return false; + } + + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId); + trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.QUEUED) + .personId(personId) + .blockCount(blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + log.info("Queued sender training for person {} — training already running", personId); + return false; + } + + return true; + } + + /** Executes sender training synchronously. Caller must run this on a background thread. */ + public void triggerSenderTraining(UUID personId, int correctedLines) { + String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; + + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { + long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); + return trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .personId(personId) + .blockCount((int) blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + })); + + String runId = run.getId().toString(); + MDC.put("trainingRunId", runId); + log.info("Started sender training run {} for person {}", runId, personId); + + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + trainingDataExportService.exportForSender(personId).writeTo(baos); + byte[] zipBytes = baos.toByteArray(); + + log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length); + OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath); + + txTemplate.execute(status -> { + SenderModel model = senderModelRepository.findByPersonId(personId) + .orElseGet(() -> SenderModel.builder().personId(personId).build()); + model.setModelPath(outputModelPath); + model.setCer(result.cer()); + model.setAccuracy(result.accuracy()); + model.setCorrectedLinesAtTraining(correctedLines); + senderModelRepository.save(model); + + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setCer(result.cer()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + trainingRunRepository.save(run); + log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer()); + return null; + }); + } catch (Exception e) { + txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + trainingRunRepository.save(run); + log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e); + return null; + }); + } finally { + MDC.remove("trainingRunId"); + promoteNextQueuedRun(); + } + } + + /** + * Promotes the oldest QUEUED sender run to RUNNING and triggers its training. + * Called in the finally block of triggerSenderTraining, creating a sequential chain: + * each run promotes the next only after it fully completes (success or failure). + */ + private void promoteNextQueuedRun() { + Optional queuedOpt = txTemplate.execute(status -> + trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED) + .map(queued -> { + queued.setStatus(TrainingStatus.RUNNING); + return trainingRunRepository.save(queued); + })); + + if (queuedOpt != null && queuedOpt.isPresent()) { + OcrTrainingRun promoted = queuedOpt.get(); + log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId()); + long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId()); + triggerSenderTraining(promoted.getPersonId(), (int) freshCount); + } + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index b89e1ae6..e6e43f00 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -4,11 +4,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.OcrTrainingRun; -import org.raddatz.familienarchiv.model.TrainingLabel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; -import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionTemplate; @@ -27,7 +25,6 @@ import static org.mockito.Mockito.*; class OcrTrainingServiceTest { OcrTrainingRunRepository runRepository; - SenderModelRepository senderModelRepository; TrainingDataExportService exportService; SegmentationTrainingExportService segExportService; OcrClient ocrClient; @@ -39,7 +36,6 @@ class OcrTrainingServiceTest { @BeforeEach void setUp() { runRepository = mock(OcrTrainingRunRepository.class); - senderModelRepository = mock(SenderModelRepository.class); exportService = mock(TrainingDataExportService.class); segExportService = mock(SegmentationTrainingExportService.class); ocrClient = mock(OcrClient.class); @@ -53,7 +49,7 @@ class OcrTrainingServiceTest { return callback.doInTransaction(null); }); - service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); when(blockRepository.count()).thenReturn(0L); when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); @@ -233,48 +229,6 @@ class OcrTrainingServiceTest { run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); } - // ─── runOrQueueSenderTraining ──────────────────────────────────────────── - - @Test - void runOrQueueSenderTraining_returnsTrue_whenIdle() { - UUID personId = UUID.randomUUID(); - when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); - when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); - - boolean result = service.runOrQueueSenderTraining(personId, 120); - - assertThat(result).isTrue(); - verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED)); - } - - @Test - void runOrQueueSenderTraining_savesQueued_whenRunning() { - UUID personId = UUID.randomUUID(); - when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); - when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn( - Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING) - .blockCount(5).documentCount(1).modelName("german_kurrent").build())); - when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); - when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); - - boolean result = service.runOrQueueSenderTraining(personId, 120); - - assertThat(result).isFalse(); - verify(runRepository).save(argThat(r -> - r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId()))); - } - - @Test - void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() { - UUID personId = UUID.randomUUID(); - when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true); - - boolean result = service.runOrQueueSenderTraining(personId, 120); - - assertThat(result).isFalse(); - verify(runRepository, never()).save(any()); - } - // ─── Orphan recovery ────────────────────────────────────────────────────── @Test diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index b3a3a5d9..f46c5905 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -2,23 +2,31 @@ package org.raddatz.familienarchiv.service; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.raddatz.familienarchiv.model.OcrTrainingRun; import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.model.TrainingStatus; +import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionTemplate; import java.util.Optional; import java.util.UUID; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; class SenderModelServiceTest { SenderModelRepository senderModelRepository; TranscriptionBlockRepository blockRepository; - OcrTrainingService ocrTrainingService; + OcrTrainingRunRepository trainingRunRepository; + OcrClient ocrClient; + TransactionTemplate txTemplate; + TrainingDataExportService trainingDataExportService; SenderModelService service; UUID personId = UUID.randomUUID(); @@ -27,8 +35,19 @@ class SenderModelServiceTest { void setUp() { senderModelRepository = mock(SenderModelRepository.class); blockRepository = mock(TranscriptionBlockRepository.class); - ocrTrainingService = mock(OcrTrainingService.class); - service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService); + trainingRunRepository = mock(OcrTrainingRunRepository.class); + ocrClient = mock(OcrClient.class); + txTemplate = mock(TransactionTemplate.class); + trainingDataExportService = mock(TrainingDataExportService.class); + + // Execute transaction callbacks inline so unit tests run without a real DataSource + when(txTemplate.execute(any())).thenAnswer(inv -> { + TransactionCallback callback = inv.getArgument(0); + return callback.doInTransaction(null); + }); + + service = new SenderModelService(senderModelRepository, blockRepository, + trainingRunRepository, ocrClient, txTemplate, trainingDataExportService); ReflectionTestUtils.setField(service, "activationThreshold", 100); ReflectionTestUtils.setField(service, "retrainDelta", 50); } @@ -40,20 +59,23 @@ class SenderModelServiceTest { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + spy.checkAndTriggerTraining(personId); - verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + verify(spy, never()).runOrQueueSenderTraining(any(), anyInt()); } @Test void checkAndTriggerTraining_triggersTraining_atActivationThreshold() { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100); - verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100); + spy.checkAndTriggerTraining(personId); + + verify(spy).runOrQueueSenderTraining(personId, 100); } // ─── Retrain delta ──────────────────────────────────────────────────────── @@ -65,9 +87,10 @@ class SenderModelServiceTest { when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + spy.checkAndTriggerTraining(personId); - verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + verify(spy, never()).runOrQueueSenderTraining(any(), anyInt()); } @Test @@ -76,11 +99,13 @@ class SenderModelServiceTest { .correctedLinesAtTraining(100).build(); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150); - verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150); + spy.checkAndTriggerTraining(personId); + + verify(spy).runOrQueueSenderTraining(personId, 150); } // ─── runNow flag ────────────────────────────────────────────────────────── @@ -89,21 +114,103 @@ class SenderModelServiceTest { void checkAndTriggerTraining_callsTrigger_whenRunNow() { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100); + doNothing().when(spy).triggerSenderTraining(personId, 100); - verify(ocrTrainingService).triggerSenderTraining(personId, 100); + spy.checkAndTriggerTraining(personId); + + verify(spy).triggerSenderTraining(personId, 100); } @Test void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100); - verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt()); + spy.checkAndTriggerTraining(personId); + + verify(spy, never()).triggerSenderTraining(any(), anyInt()); + } + + // ─── runOrQueueSenderTraining ───────────────────────────────────────────── + + @Test + void runOrQueueSenderTraining_returnsTrue_whenIdle() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED)); + } + + @Test + void runOrQueueSenderTraining_savesQueued_whenRunning() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn( + Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(1).modelName("german_kurrent").build())); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(trainingRunRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId()))); + } + + @Test + void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(trainingRunRepository, never()).save(any()); + } + + // ─── triggerSenderTraining — queue promotion ────────────────────────────── + + @Test + void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception { + UUID nextPersonId = UUID.randomUUID(); + + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> { + OcrTrainingRun r = inv.getArgument(0); + if (r.getId() == null) r.setId(UUID.randomUUID()); + return r; + }); + when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {}); + when(ocrClient.trainSenderModel(any(), any())) + .thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3)); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + OcrTrainingRun queued = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.QUEUED) + .personId(nextPersonId).blockCount(5).documentCount(0) + .modelName("sender_" + nextPersonId).build(); + when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)) + .thenReturn(Optional.of(queued)); + when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L); + + SenderModelService spy = spy(service); + // Stub the recursive call to stop the chain after one promotion + doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt()); + + spy.triggerSenderTraining(personId, 10); + + // Queued run promoted to RUNNING + verify(trainingRunRepository).save(argThat(r -> + nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING)); + // Next training was kicked off + verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt()); } } -- 2.49.1 From 18cf839fac8e4b88feab5ad48d87bc4800f4ea13 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 19:27:02 +0200 Subject: [PATCH 09/32] feat(ocr): wire SenderModelService into OcrAsyncRunner; stage missing foundational files OcrAsyncRunner now passes the per-sender model path to streamBlocks for HANDWRITING_KURRENT documents. processDocument replaced extractBlocks with streamBlocks + AtomicReference, removing the unchecked raw-array pattern. Also stages all previously uncommitted foundational files for this feature: SenderModel entity, SenderModelRepository, Flyway migrations V40/V41, updated OcrClient/RestClientOcrClient streaming API, TrainingDataExportService.exportForSender, TranscriptionService Kurrent hook, application.yaml OCR config, and frontend i18n/test additions. Co-Authored-By: Claude Sonnet 4.6 --- .../familienarchiv/model/OcrTrainingRun.java | 3 + .../familienarchiv/model/SenderModel.java | 56 +++++++++++++++ .../familienarchiv/model/TrainingStatus.java | 1 + .../repository/OcrTrainingRunRepository.java | 6 +- .../repository/SenderModelRepository.java | 12 ++++ .../TranscriptionBlockRepository.java | 19 ++++++ .../service/OcrAsyncRunner.java | 31 ++++++++- .../familienarchiv/service/OcrClient.java | 30 ++++++-- .../service/RestClientOcrClient.java | 68 ++++++++++++++----- .../service/TrainingDataExportService.java | 16 ++++- .../service/TranscriptionService.java | 9 +++ backend/src/main/resources/application.yaml | 5 ++ .../db/migration/V40__add_sender_models.sql | 10 +++ .../V41__add_person_to_training_runs.sql | 2 + .../service/OcrClientDefaultStreamTest.java | 2 + .../TranscriptionServiceGuidedTest.java | 4 +- frontend/messages/de.json | 5 ++ frontend/messages/en.json | 5 ++ frontend/messages/es.json | 5 ++ .../components/TrainingHistory.svelte.spec.ts | 23 +++++++ 20 files changed, 281 insertions(+), 31 deletions(-) create mode 100644 backend/src/main/java/org/raddatz/familienarchiv/model/SenderModel.java create mode 100644 backend/src/main/java/org/raddatz/familienarchiv/repository/SenderModelRepository.java create mode 100644 backend/src/main/resources/db/migration/V40__add_sender_models.sql create mode 100644 backend/src/main/resources/db/migration/V41__add_person_to_training_runs.sql diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java index e5db5231..c22081aa 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java @@ -59,6 +59,9 @@ public class OcrTrainingRun { @Column(name = "triggered_by") private UUID triggeredBy; + @Column(name = "person_id") + private UUID personId; + @CreationTimestamp @Column(name = "created_at", nullable = false, updatable = false) @Schema(requiredMode = Schema.RequiredMode.REQUIRED) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/SenderModel.java b/backend/src/main/java/org/raddatz/familienarchiv/model/SenderModel.java new file mode 100644 index 00000000..4b8586aa --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/model/SenderModel.java @@ -0,0 +1,56 @@ +package org.raddatz.familienarchiv.model; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.persistence.*; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.hibernate.annotations.CreationTimestamp; +import org.hibernate.annotations.UpdateTimestamp; + +import java.time.Instant; +import java.util.UUID; + +@Entity +@Table(name = "sender_models") +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class SenderModel { + + @Id + @GeneratedValue(strategy = GenerationType.UUID) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private UUID id; + + @Column(name = "person_id", nullable = false, unique = true) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private UUID personId; + + @JsonIgnore + @Column(name = "model_path", nullable = false) + private String modelPath; + + @Column + private Double accuracy; + + @Column + private Double cer; + + @Column(name = "corrected_lines_at_training", nullable = false) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private int correctedLinesAtTraining; + + @CreationTimestamp + @Column(name = "created_at", nullable = false, updatable = false) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private Instant createdAt; + + @UpdateTimestamp + @Column(name = "updated_at", nullable = false) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED) + private Instant updatedAt; +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java b/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java index 7e99dd2f..e91b9967 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java @@ -1,6 +1,7 @@ package org.raddatz.familienarchiv.model; public enum TrainingStatus { + QUEUED, RUNNING, DONE, FAILED diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java index fe7d61d1..99ff81da 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java @@ -12,5 +12,9 @@ public interface OcrTrainingRunRepository extends JpaRepository findFirstByStatus(TrainingStatus status); - List findTop10ByOrderByCreatedAtDesc(); + Optional findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status); + + boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status); + + List findTop20ByOrderByCreatedAtDesc(); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/SenderModelRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/SenderModelRepository.java new file mode 100644 index 00000000..0bfb05a5 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/SenderModelRepository.java @@ -0,0 +1,12 @@ +package org.raddatz.familienarchiv.repository; + +import org.raddatz.familienarchiv.model.SenderModel; +import org.springframework.data.jpa.repository.JpaRepository; + +import java.util.Optional; +import java.util.UUID; + +public interface SenderModelRepository extends JpaRepository { + + Optional findByPersonId(UUID personId); +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java index d091f950..c88830ad 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/TranscriptionBlockRepository.java @@ -3,6 +3,7 @@ package org.raddatz.familienarchiv.repository; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; import java.util.List; import java.util.Optional; @@ -37,4 +38,22 @@ public interface TranscriptionBlockRepository extends JpaRepository findSegmentationBlocks(); + + @Query(""" + SELECT COUNT(b) FROM TranscriptionBlock b + JOIN Document d ON d.id = b.documentId + WHERE b.source = 'MANUAL' + AND d.sender.id = :personId + AND d.scriptType = 'HANDWRITING_KURRENT' + """) + long countManualKurrentBlocksByPerson(@Param("personId") UUID personId); + + @Query(""" + SELECT b FROM TranscriptionBlock b + JOIN Document d ON d.id = b.documentId + WHERE b.source = 'MANUAL' + AND d.sender.id = :personId + AND d.scriptType = 'HANDWRITING_KURRENT' + """) + List findManualKurrentBlocksByPerson(@Param("personId") UUID personId); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java index 3ae853b0..08f38bda 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrAsyncRunner.java @@ -9,10 +9,12 @@ import org.raddatz.familienarchiv.repository.OcrJobRepository; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; @Component @RequiredArgsConstructor @@ -29,6 +31,7 @@ public class OcrAsyncRunner { private final OcrJobRepository ocrJobRepository; private final OcrJobDocumentRepository ocrJobDocumentRepository; private final OcrProgressService ocrProgressService; + private final SenderModelService senderModelService; @Async public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) { @@ -68,12 +71,18 @@ public class OcrAsyncRunner { String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath()); + String senderModelPath = null; + if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) { + senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null); + } + AtomicInteger blockCounter = new AtomicInteger(0); AtomicInteger currentPage = new AtomicInteger(0); AtomicInteger skippedPages = new AtomicInteger(0); AtomicInteger totalPages = new AtomicInteger(0); - ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> { + final String finalSenderModelPath = senderModelPath; + ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, finalSenderModelPath, event -> { switch (event) { case OcrStreamEvent.Start start -> { totalPages.set(start.totalPages()); @@ -207,7 +216,25 @@ public class OcrAsyncRunner { clearExistingBlocks(documentId); String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath()); - List blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType()); + + String senderModelPath = null; + if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) { + senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null); + } + + final AtomicReference> blocksRef = new AtomicReference<>(); + final String finalSenderModelPath = senderModelPath; + ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), null, finalSenderModelPath, event -> { + switch (event) { + case OcrStreamEvent.Page page -> { + blocksRef.compareAndSet(null, new ArrayList<>()); + blocksRef.get().addAll(page.blocks()); + } + default -> {} + } + }); + + List blocks = blocksRef.get() != null ? blocksRef.get() : List.of(); createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash()); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java index 45c2f021..6af68718 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java @@ -1,6 +1,7 @@ package org.raddatz.familienarchiv.service; import org.raddatz.familienarchiv.model.ScriptType; +import org.springframework.lang.Nullable; import java.util.ArrayList; import java.util.LinkedHashMap; @@ -37,15 +38,27 @@ public interface OcrClient { TrainingResult segtrainModel(byte[] trainingDataZip); /** - * Stream OCR results page-by-page via NDJSON. Implementations should override - * this method. The default exists only for backward compatibility during migration - * — it calls extractBlocks() and synthesizes events from the collected result. + * Fine-tune the Kurrent model for a specific sender. * - * @param regions optional list of pre-drawn annotation regions; when non-null, - * the OCR service runs in guided mode (crop + recognize per region) + * @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService.exportForSender() + * @param outputModelPath where to save the trained model (e.g. /app/models/sender_{uuid}.mlmodel) + * @return training result metrics + */ + TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath); + + /** + * Stream OCR results page-by-page via NDJSON, optionally using a sender-specific model. + * The default implementation synthesizes events from extractBlocks() for backward compatibility. + * Implementations that support real streaming (e.g. RestClientOcrClient) override this. + * + * @param regions optional list of pre-drawn annotation regions; when non-null, + * the OCR service runs in guided mode (crop + recognize per region) + * @param senderModelPath optional path to a per-sender model file; null means use base model */ default void streamBlocks(String pdfUrl, ScriptType scriptType, - List regions, Consumer handler) { + List regions, + @Nullable String senderModelPath, + Consumer handler) { List allBlocks = extractBlocks(pdfUrl, scriptType); LinkedHashMap> byPage = new LinkedHashMap<>(); @@ -62,4 +75,9 @@ public interface OcrClient { handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0)); } + + default void streamBlocks(String pdfUrl, ScriptType scriptType, + List regions, Consumer handler) { + streamBlocks(pdfUrl, scriptType, regions, null, handler); + } } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java index 9e58a2e1..2757ee19 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java @@ -14,6 +14,7 @@ import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.client.JdkClientHttpRequestFactory; +import org.springframework.lang.Nullable; import org.springframework.stereotype.Component; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -102,6 +103,13 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { .toList(); } + private RestClient.RequestBodySpec addTrainingAuth(RestClient.RequestBodySpec spec) { + if (trainingToken != null && !trainingToken.isBlank()) { + return spec.header("X-Training-Token", trainingToken); + } + return spec; + } + @Override public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) { ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) { @@ -114,15 +122,10 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { partHeaders.setContentType(MediaType.parseMediaType("application/zip")); body.add("file", new HttpEntity<>(zipResource, partHeaders)); - var spec = trainingRestClient.post() - .uri("/train") - .contentType(MediaType.MULTIPART_FORM_DATA); - - if (trainingToken != null && !trainingToken.isBlank()) { - spec = spec.header("X-Training-Token", trainingToken); - } - - TrainingResultJson result = spec + TrainingResultJson result = addTrainingAuth( + trainingRestClient.post() + .uri("/train") + .contentType(MediaType.MULTIPART_FORM_DATA)) .body(body) .retrieve() .body(TrainingResultJson.class); @@ -143,15 +146,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { partHeaders.setContentType(MediaType.parseMediaType("application/zip")); body.add("file", new HttpEntity<>(zipResource, partHeaders)); - var spec = trainingRestClient.post() - .uri("/segtrain") - .contentType(MediaType.MULTIPART_FORM_DATA); + TrainingResultJson result = addTrainingAuth( + trainingRestClient.post() + .uri("/segtrain") + .contentType(MediaType.MULTIPART_FORM_DATA)) + .body(body) + .retrieve() + .body(TrainingResultJson.class); - if (trainingToken != null && !trainingToken.isBlank()) { - spec = spec.header("X-Training-Token", trainingToken); - } + if (result == null) return new OcrClient.TrainingResult(null, null, null, null); + return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs()); + } - TrainingResultJson result = spec + @Override + public OcrClient.TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath) { + ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) { + @Override + public String getFilename() { return "sender-training-data.zip"; } + }; + + MultiValueMap body = new LinkedMultiValueMap<>(); + HttpHeaders partHeaders = new HttpHeaders(); + partHeaders.setContentType(MediaType.parseMediaType("application/zip")); + body.add("file", new HttpEntity<>(zipResource, partHeaders)); + body.add("output_model_path", outputModelPath); + + TrainingResultJson result = addTrainingAuth( + trainingRestClient.post() + .uri("/train") + .contentType(MediaType.MULTIPART_FORM_DATA)) .body(body) .retrieve() .body(TrainingResultJson.class); @@ -176,7 +199,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { @Override public void streamBlocks(String pdfUrl, ScriptType scriptType, - List regions, Consumer handler) { + List regions, @Nullable String senderModelPath, + Consumer handler) { String body; try { var requestMap = new java.util.LinkedHashMap(); @@ -186,6 +210,9 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { if (regions != null && !regions.isEmpty()) { requestMap.put("regions", regions); } + if (senderModelPath != null) { + requestMap.put("senderModelPath", senderModelPath); + } body = NDJSON_MAPPER.writeValueAsString(requestMap); } catch (IOException e) { throw new RuntimeException("Failed to serialize OCR request", e); @@ -204,7 +231,12 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { if (response.statusCode() == 404) { log.info("OCR service does not support /ocr/stream (404), falling back to /ocr"); - OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler); + List allBlocks = extractBlocks(pdfUrl, scriptType); + handler.accept(new OcrStreamEvent.Start(0)); + for (OcrBlockResult block : allBlocks) { + handler.accept(new OcrStreamEvent.Page(block.pageNumber(), List.of(block))); + } + handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0)); return; } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java index cf0b10e3..86c81053 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java @@ -38,10 +38,20 @@ public class TrainingDataExportService { } public StreamingResponseBody exportToZip() { - // Collect all data before entering the lambda — no open DB txn during streaming - List blocks = queryEligibleBlocks(); + return exportBlocksToZip(queryEligibleBlocks()); + } + + public List queryBlocksForSender(UUID personId) { + return blockRepository.findManualKurrentBlocksByPerson(personId); + } + + public StreamingResponseBody exportForSender(UUID personId) { + return exportBlocksToZip(queryBlocksForSender(personId)); + } + + private StreamingResponseBody exportBlocksToZip(List blocks) { if (blocks.isEmpty()) { - return out -> {}; // caller checks isEmpty() for 204 response + return out -> {}; } // Group blocks by documentId so we only download each PDF once diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java index bfb02253..081f588d 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/TranscriptionService.java @@ -11,6 +11,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.DocumentAnnotation; +import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; import org.raddatz.familienarchiv.repository.AnnotationRepository; @@ -35,6 +36,7 @@ public class TranscriptionService { private final AnnotationRepository annotationRepository; private final AnnotationService annotationService; private final DocumentService documentService; + private final SenderModelService senderModelService; public List listBlocks(UUID documentId) { return blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId); @@ -122,6 +124,7 @@ public class TranscriptionService { String text = sanitizeText(dto.getText()); block.setText(text); + block.setSource(BlockSource.MANUAL); if (dto.getLabel() != null) { block.setLabel(dto.getLabel()); } @@ -129,6 +132,12 @@ public class TranscriptionService { TranscriptionBlock saved = blockRepository.save(block); saveVersion(saved, userId); + + Document doc = documentService.getDocumentById(documentId); + if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) { + senderModelService.checkAndTriggerTraining(doc.getSender().getId()); + } + return saved; } diff --git a/backend/src/main/resources/application.yaml b/backend/src/main/resources/application.yaml index aa4c8ceb..d9bbe9d0 100644 --- a/backend/src/main/resources/application.yaml +++ b/backend/src/main/resources/application.yaml @@ -77,3 +77,8 @@ app: tags: 10 summary: 11 transcription: 13 + +ocr: + sender-model: + activation-threshold: 100 + retrain-delta: 50 diff --git a/backend/src/main/resources/db/migration/V40__add_sender_models.sql b/backend/src/main/resources/db/migration/V40__add_sender_models.sql new file mode 100644 index 00000000..2386bbe2 --- /dev/null +++ b/backend/src/main/resources/db/migration/V40__add_sender_models.sql @@ -0,0 +1,10 @@ +CREATE TABLE sender_models ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + person_id UUID NOT NULL UNIQUE REFERENCES persons(id) ON DELETE CASCADE, + model_path TEXT NOT NULL, + accuracy DOUBLE PRECISION, + cer DOUBLE PRECISION, + corrected_lines_at_training INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/backend/src/main/resources/db/migration/V41__add_person_to_training_runs.sql b/backend/src/main/resources/db/migration/V41__add_person_to_training_runs.sql new file mode 100644 index 00000000..b95a70b6 --- /dev/null +++ b/backend/src/main/resources/db/migration/V41__add_person_to_training_runs.sql @@ -0,0 +1,2 @@ +ALTER TABLE ocr_training_runs + ADD COLUMN person_id UUID REFERENCES persons(id) ON DELETE SET NULL; diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java index 7be27c40..f865f5d6 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrClientDefaultStreamTest.java @@ -21,6 +21,7 @@ class OcrClientDefaultStreamTest { } @Override public TrainingResult trainModel(byte[] zip) { return null; } @Override public TrainingResult segtrainModel(byte[] zip) { return null; } + @Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; } }; List events = new ArrayList<>(); @@ -52,6 +53,7 @@ class OcrClientDefaultStreamTest { @Override public List extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); } @Override public TrainingResult trainModel(byte[] zip) { return null; } @Override public TrainingResult segtrainModel(byte[] zip) { return null; } + @Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; } }; List events = new ArrayList<>(); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java index e8d53cad..59eed034 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceGuidedTest.java @@ -22,6 +22,7 @@ class TranscriptionServiceGuidedTest { AnnotationRepository annotationRepository; AnnotationService annotationService; DocumentService documentService; + SenderModelService senderModelService; TranscriptionService service; UUID docId = UUID.randomUUID(); @@ -35,9 +36,10 @@ class TranscriptionServiceGuidedTest { annotationRepository = mock(AnnotationRepository.class); annotationService = mock(AnnotationService.class); documentService = mock(DocumentService.class); + senderModelService = mock(SenderModelService.class); service = new TranscriptionService(blockRepository, versionRepository, - annotationRepository, annotationService, documentService); + annotationRepository, annotationService, documentService, senderModelService); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); diff --git a/frontend/messages/de.json b/frontend/messages/de.json index 72ae0b6b..51b10005 100644 --- a/frontend/messages/de.json +++ b/frontend/messages/de.json @@ -561,6 +561,11 @@ "transcription_block_segmentation_only": "Nur Segmentierung", "training_chip_kurrent": "Kurrent-Erkennung", "training_chip_segmentation": "Segmentierung", + "training_col_type": "Typ", + "training_type_base": "Basis", + "training_type_personalized": "Personalisiert", + "training_col_person": "Absender", + "training_status_queued": "Warteschlange", "mission_control_heading": "Was braucht Aufmerksamkeit?", "mission_control_segmentation_heading": "Text markieren", "mission_control_segmentation_description": "Textbereiche markieren — keine Vorkenntnisse nötig", diff --git a/frontend/messages/en.json b/frontend/messages/en.json index c3711b3f..b6016949 100644 --- a/frontend/messages/en.json +++ b/frontend/messages/en.json @@ -561,6 +561,11 @@ "transcription_block_segmentation_only": "Segmentation only", "training_chip_kurrent": "Kurrent recognition", "training_chip_segmentation": "Segmentation", + "training_col_type": "Type", + "training_type_base": "Base", + "training_type_personalized": "Personalized", + "training_col_person": "Sender", + "training_status_queued": "Queued", "mission_control_heading": "What needs attention?", "mission_control_segmentation_heading": "Mark text", "mission_control_segmentation_description": "Mark text areas — no prior knowledge needed", diff --git a/frontend/messages/es.json b/frontend/messages/es.json index c87a2d25..19487b65 100644 --- a/frontend/messages/es.json +++ b/frontend/messages/es.json @@ -561,6 +561,11 @@ "transcription_block_segmentation_only": "Solo segmentación", "training_chip_kurrent": "Reconocimiento Kurrent", "training_chip_segmentation": "Segmentación", + "training_col_type": "Tipo", + "training_type_base": "Base", + "training_type_personalized": "Personalizado", + "training_col_person": "Remitente", + "training_status_queued": "En cola", "mission_control_heading": "¿Qué necesita atención?", "mission_control_segmentation_heading": "Marcar texto", "mission_control_segmentation_description": "Marcar áreas de texto — sin conocimientos previos", diff --git a/frontend/src/lib/components/TrainingHistory.svelte.spec.ts b/frontend/src/lib/components/TrainingHistory.svelte.spec.ts index 728b167a..95090ba6 100644 --- a/frontend/src/lib/components/TrainingHistory.svelte.spec.ts +++ b/frontend/src/lib/components/TrainingHistory.svelte.spec.ts @@ -50,3 +50,26 @@ describe('TrainingHistory — expand/collapse', () => { .not.toBeInTheDocument(); }); }); + +describe('TrainingHistory — type and person columns', () => { + it('shows "Basis" for runs without personId', async () => { + render(TrainingHistory, { runs: [makeRun(0)] }); + + await expect.element(page.getByText(/Basis/i)).toBeInTheDocument(); + }); + + it('shows "Personalisiert" for runs with personId', async () => { + const run = { ...makeRun(0), personId: 'person-1' }; + render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } }); + + await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument(); + }); + + it('shows person name from personNames for sender runs', async () => { + const run = { ...makeRun(0), personId: 'person-1' }; + render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } }); + + await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument(); + await expect.element(page.getByText('Karl Müller')).toBeInTheDocument(); + }); +}); -- 2.49.1 From 8d041a377d5c9339f67b0c1771248d19bbc79829 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:08:18 +0200 Subject: [PATCH 10/32] fix(ocr): correct trainSenderModel URI from /train to /train-sender Co-Authored-By: Claude Sonnet 4.6 --- .../service/RestClientOcrClient.java | 2 +- .../RestClientOcrClientTrainingTest.java | 41 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java index 2757ee19..2bda9892 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java @@ -173,7 +173,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { TrainingResultJson result = addTrainingAuth( trainingRestClient.post() - .uri("/train") + .uri("/train-sender") .contentType(MediaType.MULTIPART_FORM_DATA)) .body(body) .retrieve() diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java new file mode 100644 index 00000000..17f084fa --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/RestClientOcrClientTrainingTest.java @@ -0,0 +1,41 @@ +package org.raddatz.familienarchiv.service; + +import com.sun.net.httpserver.HttpServer; +import org.junit.jupiter.api.Test; + +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import static org.assertj.core.api.Assertions.assertThat; + +class RestClientOcrClientTrainingTest { + + @Test + void trainSenderModel_postsToTrainSenderEndpoint() throws Exception { + List capturedPaths = new CopyOnWriteArrayList<>(); + HttpServer server = HttpServer.create(new InetSocketAddress(0), 0); + server.createContext("/", exchange -> { + capturedPaths.add(exchange.getRequestURI().getPath()); + byte[] response = "{\"loss\":0.1,\"accuracy\":0.95,\"cer\":0.05,\"epochs\":3}" + .getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, response.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(response); + } + }); + server.start(); + + try { + int port = server.getAddress().getPort(); + RestClientOcrClient client = new RestClientOcrClient("http://localhost:" + port, ""); + client.trainSenderModel(new byte[]{1, 2, 3}, "/app/models/sender_test.mlmodel"); + assertThat(capturedPaths).contains("/train-sender"); + } finally { + server.stop(0); + } + } +} -- 2.49.1 From 1b34a36a77fe5f9c773a632681cf042331213e5b Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:11:56 +0200 Subject: [PATCH 11/32] fix(ocr): eliminate race window in runOrQueueSenderTraining by creating RUNNING row atomically Co-Authored-By: Claude Sonnet 4.6 --- .../repository/OcrTrainingRunRepository.java | 2 ++ .../service/SenderModelService.java | 35 +++++++++++-------- .../service/SenderModelServiceTest.java | 26 ++++++++++++++ 3 files changed, 49 insertions(+), 14 deletions(-) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java index 99ff81da..b3878164 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java @@ -14,6 +14,8 @@ public interface OcrTrainingRunRepository extends JpaRepository findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status); + Optional findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status); + boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status); List findTop20ByOrderByCreatedAtDesc(); diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java index 465d6a8e..19a796db 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -72,9 +72,10 @@ public class SenderModelService { } /** - * Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per - * person_id). Runs immediately if idle. Returns true if training should start now, false if - * queued or skipped. + * Atomically checks the queue state and either creates a RUNNING row (returns true) or a + * QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and + * RUNNING row creation — happen in one transaction, eliminating the race window that would + * otherwise exist between the check and a separate RUNNING row creation. */ @Transactional public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { @@ -96,23 +97,29 @@ public class SenderModelService { return false; } + long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); + trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .personId(personId) + .blockCount((int) blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); return true; } - /** Executes sender training synchronously. Caller must run this on a background thread. */ + /** + * Executes sender training synchronously. Caller must run this on a background thread. + * The RUNNING row is expected to already exist — created atomically by + * runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs). + */ public void triggerSenderTraining(UUID personId, int correctedLines) { String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; - OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { - long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); - return trainingRunRepository.save(OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .personId(personId) - .blockCount((int) blockCount) - .documentCount(0) - .modelName("sender_" + personId) - .build()); - })); + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> + trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING) + .orElseThrow(() -> new IllegalStateException( + "Expected RUNNING row for person " + personId + " but none found")))); String runId = run.getId().toString(); MDC.put("trainingRunId", runId); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index f46c5905..e08aa7b5 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -176,12 +176,38 @@ class SenderModelServiceTest { verify(trainingRunRepository, never()).save(any()); } + @Test + void runOrQueueSenderTraining_createsRunningRow_atomically() { + // RUNNING row must be created inside the same @Transactional boundary as the idle check, + // eliminating the race window between the check and a separate triggerSenderTraining call. + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> { + OcrTrainingRun r = inv.getArgument(0); + if (r.getId() == null) r.setId(UUID.randomUUID()); + return r; + }); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(trainingRunRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId()))); + } + // ─── triggerSenderTraining — queue promotion ────────────────────────────── @Test void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception { UUID nextPersonId = UUID.randomUUID(); + OcrTrainingRun runningRun = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(10).documentCount(0) + .modelName("sender_" + personId).build(); + when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)) + .thenReturn(Optional.of(runningRun)); when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L); when(trainingRunRepository.save(any())).thenAnswer(inv -> { OcrTrainingRun r = inv.getArgument(0); -- 2.49.1 From 09f4601d15e530ad19c02af75983ca612ca9a7d9 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:13:21 +0200 Subject: [PATCH 12/32] test(ocr): verify triggerSenderTraining upserts SenderModel with correct path and cer Co-Authored-By: Claude Sonnet 4.6 --- .../service/SenderModelServiceTest.java | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index e08aa7b5..ebafa53f 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -196,6 +196,33 @@ class SenderModelServiceTest { r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId()))); } + // ─── triggerSenderTraining — happy path ────────────────────────────────── + + @Test + void triggerSenderTraining_savesModelRecord_onSuccess() throws Exception { + OcrTrainingRun runningRun = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(10).documentCount(0) + .modelName("sender_" + personId).build(); + when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)) + .thenReturn(Optional.of(runningRun)); + when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {}); + when(ocrClient.trainSenderModel(any(), any())) + .thenReturn(new OcrClient.TrainingResult(0.08, 0.92, 0.08, 5)); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(senderModelRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)) + .thenReturn(Optional.empty()); + + service.triggerSenderTraining(personId, 120); + + verify(senderModelRepository).save(argThat(m -> + ("/app/models/sender_" + personId + ".mlmodel").equals(m.getModelPath()) + && Double.valueOf(0.08).equals(m.getCer()) + && personId.equals(m.getPersonId()))); + } + // ─── triggerSenderTraining — queue promotion ────────────────────────────── @Test -- 2.49.1 From 2459408930250dd7533f89717229ea688195c881 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:18:21 +0200 Subject: [PATCH 13/32] refactor(ocr): move person-name enrichment from OcrController into OcrTrainingService Co-Authored-By: Claude Sonnet 4.6 --- .../controller/OcrController.java | 16 +---------- .../service/OcrTrainingService.java | 21 ++++++++++++-- .../controller/OcrControllerTest.java | 28 +++++++++++++++---- .../service/OcrTrainingServiceTest.java | 5 +++- 4 files changed, 47 insertions(+), 23 deletions(-) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java index 6cf5e1cf..cb48caca 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -14,7 +14,6 @@ import org.raddatz.familienarchiv.service.OcrBatchService; import org.raddatz.familienarchiv.service.OcrProgressService; import org.raddatz.familienarchiv.service.OcrService; import org.raddatz.familienarchiv.service.OcrTrainingService; -import org.raddatz.familienarchiv.service.PersonService; import org.raddatz.familienarchiv.service.SegmentationTrainingExportService; import org.raddatz.familienarchiv.service.TrainingDataExportService; import org.raddatz.familienarchiv.service.UserService; @@ -41,7 +40,6 @@ public class OcrController { private final OcrBatchService ocrBatchService; private final OcrProgressService ocrProgressService; private final UserService userService; - private final PersonService personService; private final TrainingDataExportService trainingDataExportService; private final SegmentationTrainingExportService segmentationTrainingExportService; private final OcrTrainingService ocrTrainingService; @@ -136,18 +134,6 @@ public class OcrController { public Map getTrainingInfo() { OcrTrainingService.TrainingInfoResponse info = ocrTrainingService.getTrainingInfo(); - Map personNames = new HashMap<>(); - for (OcrTrainingRun run : info.runs()) { - if (run.getPersonId() != null && !personNames.containsKey(run.getPersonId().toString())) { - try { - personNames.put(run.getPersonId().toString(), - personService.getById(run.getPersonId()).getDisplayName()); - } catch (Exception e) { - log.debug("Could not resolve display name for person {}: {}", run.getPersonId(), e.getMessage()); - } - } - } - Map result = new HashMap<>(); result.put("availableBlocks", info.availableBlocks()); result.put("totalOcrBlocks", info.totalOcrBlocks()); @@ -156,7 +142,7 @@ public class OcrController { result.put("ocrServiceAvailable", info.ocrServiceAvailable()); result.put("lastRun", info.lastRun() != null ? info.lastRun() : Map.of()); result.put("runs", info.runs()); - result.put("personNames", personNames); + result.put("personNames", info.personNames()); return result; } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index ff80e1e6..f61fcb46 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -17,7 +17,9 @@ import org.springframework.transaction.support.TransactionTemplate; import java.io.ByteArrayOutputStream; import java.time.Instant; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -34,6 +36,7 @@ public class OcrTrainingService { private final OcrHealthClient ocrHealthClient; private final TranscriptionBlockRepository blockRepository; private final TransactionTemplate txTemplate; + private final PersonService personService; public record TrainingInfoResponse( int availableBlocks, @@ -42,7 +45,8 @@ public class OcrTrainingService { int availableSegBlocks, boolean ocrServiceAvailable, OcrTrainingRun lastRun, - List runs + List runs, + Map personNames ) {} private void assertNoRunningTraining() { @@ -198,6 +202,18 @@ public class OcrTrainingService { List recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc(); OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0); + Map personNames = new HashMap<>(); + for (OcrTrainingRun run : recentRuns) { + if (run.getPersonId() != null && !personNames.containsKey(run.getPersonId().toString())) { + try { + personNames.put(run.getPersonId().toString(), + personService.getById(run.getPersonId()).getDisplayName()); + } catch (Exception e) { + log.debug("Could not resolve display name for person {}: {}", run.getPersonId(), e.getMessage()); + } + } + } + return new TrainingInfoResponse( eligibleBlocks.size(), totalOcrBlocks, @@ -205,7 +221,8 @@ public class OcrTrainingService { availableSegBlocks, ocrHealthClient.isHealthy(), lastRun, - recentRuns + recentRuns, + personNames ); } diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java index 4d087dea..5313b267 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -23,6 +23,8 @@ import org.springframework.test.web.servlet.MockMvc; import java.util.List; import java.util.UUID; +import java.util.Map; + import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; @@ -43,7 +45,6 @@ class OcrControllerTest { @MockitoBean OcrBatchService ocrBatchService; @MockitoBean OcrProgressService ocrProgressService; @MockitoBean UserService userService; - @MockitoBean PersonService personService; @MockitoBean CustomUserDetailsService customUserDetailsService; @MockitoBean TrainingDataExportService trainingDataExportService; @MockitoBean SegmentationTrainingExportService segmentationTrainingExportService; @@ -220,7 +221,7 @@ class OcrControllerTest { @WithMockUser(authorities = "ADMIN") void getTrainingInfo_returns200_withInfo() throws Exception { OcrTrainingService.TrainingInfoResponse info = - new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of()); + new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of()); when(ocrTrainingService.getTrainingInfo()).thenReturn(info); mockMvc.perform(get("/api/ocr/training-info")) @@ -231,21 +232,38 @@ class OcrControllerTest { @Test @WithMockUser(authorities = "ADMIN") - void getTrainingInfo_returns200_and_omits_personName_when_resolution_throws() throws Exception { + void getTrainingInfo_returns200_and_omits_personName_when_service_provides_empty_map() throws Exception { UUID personId = UUID.randomUUID(); OcrTrainingRun runWithPerson = OcrTrainingRun.builder() .id(UUID.randomUUID()).status(TrainingStatus.DONE) .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); OcrTrainingService.TrainingInfoResponse info = - new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(runWithPerson)); + new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, + List.of(runWithPerson), Map.of()); when(ocrTrainingService.getTrainingInfo()).thenReturn(info); - when(personService.getById(personId)).thenThrow(new RuntimeException("DB error")); mockMvc.perform(get("/api/ocr/training-info")) .andExpect(status().isOk()) .andExpect(jsonPath("$.personNames").isEmpty()); } + @Test + @WithMockUser(authorities = "ADMIN") + void getTrainingInfo_includesPersonName_whenPersonIdResolves() throws Exception { + UUID personId = UUID.randomUUID(); + OcrTrainingRun runWithPerson = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.DONE) + .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); + OcrTrainingService.TrainingInfoResponse info = + new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, + List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann")); + when(ocrTrainingService.getTrainingInfo()).thenReturn(info); + + mockMvc.perform(get("/api/ocr/training-info")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.personNames." + personId).value("Max Mustermann")); + } + @Test @WithMockUser(authorities = "READ_ALL") void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index e6e43f00..a887a490 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.raddatz.familienarchiv.service.PersonService; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionTemplate; @@ -31,6 +32,7 @@ class OcrTrainingServiceTest { OcrHealthClient healthClient; TranscriptionBlockRepository blockRepository; TransactionTemplate txTemplate; + PersonService personService; OcrTrainingService service; @BeforeEach @@ -42,6 +44,7 @@ class OcrTrainingServiceTest { healthClient = mock(OcrHealthClient.class); blockRepository = mock(TranscriptionBlockRepository.class); txTemplate = mock(TransactionTemplate.class); + personService = mock(PersonService.class); // Execute transaction callbacks inline so unit tests run without a real DataSource when(txTemplate.execute(any())).thenAnswer(inv -> { @@ -49,7 +52,7 @@ class OcrTrainingServiceTest { return callback.doInTransaction(null); }); - service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService); when(blockRepository.count()).thenReturn(0L); when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); -- 2.49.1 From eab37b9ac9450bb5d6c132a53e7d6178b5d157b5 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:19:40 +0200 Subject: [PATCH 14/32] test(ocr): verify load failure does not cache broken entry in SenderModelRegistry Co-Authored-By: Claude Sonnet 4.6 --- ocr-service/test_sender_registry.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ocr-service/test_sender_registry.py b/ocr-service/test_sender_registry.py index e5a7c97d..7cf7cd24 100644 --- a/ocr-service/test_sender_registry.py +++ b/ocr-service/test_sender_registry.py @@ -63,3 +63,13 @@ def test_path_outside_models_dir_raises(): registry = _make_registry() with pytest.raises(ValueError, match="not allowed"): registry.get_model("/etc/passwd") + + +def test_load_failure_does_not_cache_broken_entry(): + """A failed load must not leave a broken entry in the cache.""" + registry = _make_registry() + with patch("engines.kraken._load_sender_model", side_effect=RuntimeError("corrupt model")): + with pytest.raises(RuntimeError, match="corrupt model"): + registry.get_model("/app/models/sender_bad.mlmodel") + + assert registry.size() == 0 -- 2.49.1 From 57ffb7d75169596abd5f8678848de2cfd732fe22 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:20:53 +0200 Subject: [PATCH 15/32] chore(ocr): lower OCR_MAX_CACHED_MODELS to 2 with memory budget comment Co-Authored-By: Claude Sonnet 4.6 --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index e914cc91..5593a105 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -97,7 +97,7 @@ services: DETECTOR_BATCH_SIZE: "8" OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency) OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page) - OCR_MAX_CACHED_MODELS: "5" # LRU cache size for per-sender Kraken models + OCR_MAX_CACHED_MODELS: "2" # LRU cache; each model ~500 MB, so 2 = ~1 GB resident networks: - archive-net healthcheck: -- 2.49.1 From 07035b9fa959009d76099b9ad0b16079b7e90e3e Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:22:34 +0200 Subject: [PATCH 16/32] style(ocr): add Image type hints to extract_page_blocks and extract_region_text Co-Authored-By: Claude Sonnet 4.6 --- ocr-service/engines/kraken.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ocr-service/engines/kraken.py b/ocr-service/engines/kraken.py index d33567bc..60dfd47e 100644 --- a/ocr-service/engines/kraken.py +++ b/ocr-service/engines/kraken.py @@ -1,9 +1,15 @@ """Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin.""" +from __future__ import annotations + import collections import logging import os import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from PIL.Image import Image logger = logging.getLogger(__name__) @@ -94,7 +100,7 @@ def is_available() -> bool: return _model is not None -def extract_page_blocks(image, page_idx: int, language: str = "de", +def extract_page_blocks(image: Image, page_idx: int, language: str = "de", sender_model_path: str | None = None) -> list[dict]: """Run Kraken segmentation + recognition on a single PIL image. @@ -148,7 +154,7 @@ def extract_page_blocks(image, page_idx: int, language: str = "de", return blocks -def extract_region_text(image, x: float, y: float, w: float, h: float, +def extract_region_text(image: Image, x: float, y: float, w: float, h: float, sender_model_path: str | None = None) -> str: """Crop image to a normalized region and run Kraken recognition on the crop. -- 2.49.1 From e2081b57e78f2853dec87777d97987ff789f0017 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:24:38 +0200 Subject: [PATCH 17/32] refactor(ocr): extract exportSenderData helper in triggerSenderTraining Co-Authored-By: Claude Sonnet 4.6 --- .../familienarchiv/service/SenderModelService.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java index 19a796db..653203d4 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -126,10 +126,7 @@ public class SenderModelService { log.info("Started sender training run {} for person {}", runId, personId); try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - trainingDataExportService.exportForSender(personId).writeTo(baos); - byte[] zipBytes = baos.toByteArray(); - + byte[] zipBytes = exportSenderData(personId); log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length); OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath); @@ -166,6 +163,12 @@ public class SenderModelService { } } + private byte[] exportSenderData(UUID personId) throws java.io.IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + trainingDataExportService.exportForSender(personId).writeTo(baos); + return baos.toByteArray(); + } + /** * Promotes the oldest QUEUED sender run to RUNNING and triggers its training. * Called in the finally block of triggerSenderTraining, creating a sequential chain: -- 2.49.1 From f76a9cce1f35c470ce16b87b5bf7ace9c7a17a95 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:38:43 +0200 Subject: [PATCH 18/32] test(ocr): add failure path and DONE status assertions to SenderModelServiceTest Co-Authored-By: Claude Sonnet 4.6 --- .../service/SenderModelServiceTest.java | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index ebafa53f..be41de41 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -221,6 +221,36 @@ class SenderModelServiceTest { ("/app/models/sender_" + personId + ".mlmodel").equals(m.getModelPath()) && Double.valueOf(0.08).equals(m.getCer()) && personId.equals(m.getPersonId()))); + verify(trainingRunRepository, atLeastOnce()).save(argThat(r -> + r.getStatus() == TrainingStatus.DONE + && Double.valueOf(0.08).equals(r.getCer()) + && r.getCompletedAt() != null)); + } + + // ─── triggerSenderTraining — failure path ──────────────────────────────── + + @Test + void triggerSenderTraining_marksRunFailed_whenOcrClientThrows() throws Exception { + OcrTrainingRun runningRun = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(10).documentCount(0) + .modelName("sender_" + personId).build(); + when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)) + .thenReturn(Optional.of(runningRun)); + when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {}); + when(ocrClient.trainSenderModel(any(), any())) + .thenThrow(new RuntimeException("OCR service timeout")); + when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)) + .thenReturn(Optional.empty()); + + service.triggerSenderTraining(personId, 120); + + verify(trainingRunRepository, atLeastOnce()).save(argThat(r -> + r.getStatus() == TrainingStatus.FAILED + && r.getErrorMessage() != null + && r.getCompletedAt() != null)); + verify(senderModelRepository, never()).save(any()); } // ─── triggerSenderTraining — queue promotion ────────────────────────────── -- 2.49.1 From 0c2175aa07a300ac2c8a1ac05e71722a4ec44e45 Mon Sep 17 00:00:00 2001 From: Marcel Date: Fri, 17 Apr 2026 20:42:06 +0200 Subject: [PATCH 19/32] refactor(frontend): extract shared TrainingRun type to $lib/types/training.ts Co-Authored-By: Claude Sonnet 4.6 --- .../src/lib/components/OcrTrainingCard.svelte | 18 +++--------------- .../src/lib/components/TrainingHistory.svelte | 17 ++--------------- frontend/src/lib/types/training.ts | 13 +++++++++++++ 3 files changed, 18 insertions(+), 30 deletions(-) create mode 100644 frontend/src/lib/types/training.ts diff --git a/frontend/src/lib/components/OcrTrainingCard.svelte b/frontend/src/lib/components/OcrTrainingCard.svelte index cab026a9..d6483f1e 100644 --- a/frontend/src/lib/components/OcrTrainingCard.svelte +++ b/frontend/src/lib/components/OcrTrainingCard.svelte @@ -1,27 +1,15 @@