diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index 62ef00f6..b89e1ae6 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.model.TrainingLabel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; +import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionTemplate; @@ -26,6 +27,7 @@ import static org.mockito.Mockito.*; class OcrTrainingServiceTest { OcrTrainingRunRepository runRepository; + SenderModelRepository senderModelRepository; TrainingDataExportService exportService; SegmentationTrainingExportService segExportService; OcrClient ocrClient; @@ -37,6 +39,7 @@ class OcrTrainingServiceTest { @BeforeEach void setUp() { runRepository = mock(OcrTrainingRunRepository.class); + senderModelRepository = mock(SenderModelRepository.class); exportService = mock(TrainingDataExportService.class); segExportService = mock(SegmentationTrainingExportService.class); ocrClient = mock(OcrClient.class); @@ -50,10 +53,10 @@ class OcrTrainingServiceTest { return callback.doInTransaction(null); }); - service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); when(blockRepository.count()).thenReturn(0L); - when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of()); + when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); } @@ -230,6 +233,48 @@ class OcrTrainingServiceTest { run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); } + // ─── runOrQueueSenderTraining ──────────────────────────────────────────── + + @Test + void runOrQueueSenderTraining_returnsTrue_whenIdle() { + UUID personId = UUID.randomUUID(); + when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED)); + } + + @Test + void runOrQueueSenderTraining_savesQueued_whenRunning() { + UUID personId = UUID.randomUUID(); + when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn( + Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(1).modelName("german_kurrent").build())); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(runRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId()))); + } + + @Test + void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() { + UUID personId = UUID.randomUUID(); + when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(runRepository, never()).save(any()); + } + // ─── Orphan recovery ────────────────────────────────────────────────────── @Test diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java new file mode 100644 index 00000000..b3a3a5d9 --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -0,0 +1,109 @@ +package org.raddatz.familienarchiv.service; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.repository.SenderModelRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.test.util.ReflectionTestUtils; + +import java.util.Optional; +import java.util.UUID; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.*; + +class SenderModelServiceTest { + + SenderModelRepository senderModelRepository; + TranscriptionBlockRepository blockRepository; + OcrTrainingService ocrTrainingService; + SenderModelService service; + + UUID personId = UUID.randomUUID(); + + @BeforeEach + void setUp() { + senderModelRepository = mock(SenderModelRepository.class); + blockRepository = mock(TranscriptionBlockRepository.class); + ocrTrainingService = mock(OcrTrainingService.class); + service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService); + ReflectionTestUtils.setField(service, "activationThreshold", 100); + ReflectionTestUtils.setField(service, "retrainDelta", 50); + } + + // ─── Activation threshold ───────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_doesNothing_belowActivationThreshold() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + } + + @Test + void checkAndTriggerTraining_triggersTraining_atActivationThreshold() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100); + } + + // ─── Retrain delta ──────────────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_doesNothing_belowRetrainDelta() { + SenderModel existing = SenderModel.builder().personId(personId) + .correctedLinesAtTraining(100).build(); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + } + + @Test + void checkAndTriggerTraining_triggersRetrain_atRetrainDelta() { + SenderModel existing = SenderModel.builder().personId(personId) + .correctedLinesAtTraining(100).build(); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150); + } + + // ─── runNow flag ────────────────────────────────────────────────────────── + + @Test + void checkAndTriggerTraining_callsTrigger_whenRunNow() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService).triggerSenderTraining(personId, 100); + } + + @Test + void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() { + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); + + service.checkAndTriggerTraining(personId); + + verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt()); + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java index ba685fe8..7fd2aee7 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/TranscriptionServiceTest.java @@ -13,6 +13,7 @@ import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.DocumentAnnotation; +import org.raddatz.familienarchiv.model.Person; import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; @@ -189,6 +190,57 @@ class TranscriptionServiceTest { assertThat(result.getLabel()).isEqualTo("Anrede"); } + @Test + void updateBlock_setsSourceToManual() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").source(BlockSource.OCR).build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.TYPEWRITER).build()); + + TranscriptionBlock result = transcriptionService.updateBlock( + docId, blockId, new UpdateTranscriptionBlockDTO("new", null), UUID.randomUUID()); + + assertThat(result.getSource()).isEqualTo(BlockSource.MANUAL); + } + + @Test + void updateBlock_triggersTraining_whenKurrentSenderPresent() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + UUID senderId = UUID.randomUUID(); + Person sender = Person.builder().id(senderId).build(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).sender(sender).build()); + + transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID()); + + verify(senderModelService).checkAndTriggerTraining(senderId); + } + + @Test + void updateBlock_doesNotTriggerTraining_whenNoSender() { + UUID docId = UUID.randomUUID(); + UUID blockId = UUID.randomUUID(); + TranscriptionBlock block = TranscriptionBlock.builder() + .id(blockId).documentId(docId).text("old").build(); + when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); + when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(documentService.getDocumentById(any())).thenReturn( + Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).build()); + + transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID()); + + verify(senderModelService, never()).checkAndTriggerTraining(any()); + } + // ─── deleteBlock ───────────────────────────────────────────────────────────── @Test