diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index 7d83188f..1d1e47e9 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -146,6 +146,90 @@ class OcrTrainingServiceTest { run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); } + // ─── triggerSegTraining ─────────────────────────────────────────────────── + + @Test + void triggerSegTraining_throws409_whenRunningRunExists() { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)) + .thenReturn(Optional.of(OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(2).modelName("blla").build())); + + assertThatThrownBy(() -> service.triggerSegTraining(null)) + .isInstanceOf(DomainException.class) + .extracting("status") + .satisfies(s -> assertThat(s.toString()).contains("409")); + } + + @Test + void triggerSegTraining_throws422_whenFewerThan5Segments() { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + when(segExportService.querySegmentationBlocks()).thenReturn(List.of( + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build() + )); + + assertThatThrownBy(() -> service.triggerSegTraining(null)) + .isInstanceOf(DomainException.class); + } + + @Test + void triggerSegTraining_createsRunWithBlla_andMarksDoneWithCer() throws Exception { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + UUID docA = UUID.randomUUID(); + UUID docB = UUID.randomUUID(); + List segs = List.of( + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build() + ); + when(segExportService.querySegmentationBlocks()).thenReturn(segs); + when(segExportService.exportToZip()).thenReturn(out -> {}); + when(ocrClient.segtrainModel(any())).thenReturn(new OcrClient.TrainingResult(null, 0.92, 0.08, 5)); + + OcrTrainingRun saved = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(2).modelName("blla").build(); + when(runRepository.save(any())).thenReturn(saved); + + service.triggerSegTraining(null); + + verify(runRepository, atLeastOnce()).save(argThat(run -> + run.getStatus() == TrainingStatus.DONE + && "blla".equals(run.getModelName()) + && run.getCer() != null)); + } + + @Test + void triggerSegTraining_marksRunFailed_whenOcrClientThrows() throws Exception { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + UUID docA = UUID.randomUUID(); + List segs = List.of( + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build() + ); + when(segExportService.querySegmentationBlocks()).thenReturn(segs); + when(segExportService.exportToZip()).thenReturn(out -> {}); + when(ocrClient.segtrainModel(any())).thenThrow(new RuntimeException("seg timeout")); + + OcrTrainingRun saved = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(1).modelName("blla").build(); + when(runRepository.save(any())).thenReturn(saved); + + service.triggerSegTraining(null); + + verify(runRepository, atLeastOnce()).save(argThat(run -> + run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); + } + // ─── Orphan recovery ────────────────────────────────────────────────────── @Test