feat(ocr): add training history + POST /train + GET /training-info endpoints
- OcrTrainingRun entity + V30 migration (partial unique index prevents concurrent runs at DB level) - OcrTrainingService: concurrent-run guard, 5-block threshold, MDC log correlation, orphan recovery on ApplicationReadyEvent - POST /api/ocr/train (ADMIN) + GET /api/ocr/training-info (ADMIN) - TRAINING_ALREADY_RUNNING ErrorCode - 6 OcrTrainingServiceTest + 6 OcrControllerTest tests for the new endpoints Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,7 @@ class OcrControllerTest {
|
||||
@MockitoBean UserService userService;
|
||||
@MockitoBean CustomUserDetailsService customUserDetailsService;
|
||||
@MockitoBean TrainingDataExportService trainingDataExportService;
|
||||
@MockitoBean OcrTrainingService ocrTrainingService;
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
@@ -163,6 +164,68 @@ class OcrControllerTest {
|
||||
result.getResponse().getContentType()).contains("application/zip"));
|
||||
}
|
||||
|
||||
// ─── POST /api/ocr/train ───────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_returns401_whenUnauthenticated() throws Exception {
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isUnauthorized());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void triggerTraining_returns403_whenNotAdmin() throws Exception {
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
|
||||
when(ocrTrainingService.triggerTraining(any()))
|
||||
.thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
|
||||
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isConflict());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void triggerTraining_returns201_withRunInfo() throws Exception {
|
||||
UUID runId = UUID.randomUUID();
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.id(runId).status(TrainingStatus.DONE)
|
||||
.blockCount(10).documentCount(3).modelName("german_kurrent").build();
|
||||
when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
|
||||
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isCreated())
|
||||
.andExpect(jsonPath("$.status").value("DONE"))
|
||||
.andExpect(jsonPath("$.blockCount").value(10));
|
||||
}
|
||||
|
||||
// ─── GET /api/ocr/training-info ───────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void getTrainingInfo_returns200_withInfo() throws Exception {
|
||||
OcrTrainingService.TrainingInfoResponse info =
|
||||
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, true, null, List.of());
|
||||
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(jsonPath("$.availableBlocks").value(5))
|
||||
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.argThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class OcrTrainingServiceTest {
|
||||
|
||||
OcrTrainingRunRepository runRepository;
|
||||
TrainingDataExportService exportService;
|
||||
OcrClient ocrClient;
|
||||
OcrHealthClient healthClient;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
OcrTrainingService service;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
runRepository = mock(OcrTrainingRunRepository.class);
|
||||
exportService = mock(TrainingDataExportService.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
healthClient = mock(OcrHealthClient.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
|
||||
service = new OcrTrainingService(runRepository, exportService, ocrClient, healthClient, blockRepository);
|
||||
|
||||
when(blockRepository.findAll()).thenReturn(List.of());
|
||||
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
}
|
||||
|
||||
// ─── Concurrent guard ─────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_throws409_whenRunningRunExists() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
|
||||
.thenReturn(Optional.of(OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("german_kurrent").build()));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.extracting("status")
|
||||
.satisfies(s -> assertThat(s.toString()).contains("409"));
|
||||
}
|
||||
|
||||
// ─── Threshold guard ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_throws422_whenFewerThan5Blocks() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
|
||||
));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||
.isInstanceOf(DomainException.class);
|
||||
}
|
||||
|
||||
// ─── Happy path ───────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
UUID docB = UUID.randomUUID();
|
||||
List<TranscriptionBlock> blocks = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("german_kurrent").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerTraining(null);
|
||||
|
||||
// Verify run created with correct counts and then updated to DONE
|
||||
verify(runRepository, times(2)).save(argThat(run ->
|
||||
run.getBlockCount() == 5 || run.getStatus() == TrainingStatus.DONE));
|
||||
}
|
||||
|
||||
// ─── Failure path ─────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
List<TranscriptionBlock> blocks = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerTraining(null);
|
||||
|
||||
verify(runRepository, atLeastOnce()).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
|
||||
OcrTrainingRun orphan = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||
.createdAt(Instant.now().minusSeconds(7200))
|
||||
.build();
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
|
||||
when(runRepository.save(any())).thenReturn(orphan);
|
||||
|
||||
service.recoverOrphanedRuns();
|
||||
|
||||
verify(runRepository).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.FAILED
|
||||
&& run.getErrorMessage().contains("Abgebrochen")));
|
||||
}
|
||||
|
||||
@Test
|
||||
void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
|
||||
OcrTrainingRun recent = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||
.createdAt(Instant.now().minusSeconds(60))
|
||||
.build();
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
|
||||
|
||||
service.recoverOrphanedRuns();
|
||||
|
||||
verify(runRepository, never()).save(any());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user