feat(ocr): add training history + POST /train + GET /training-info endpoints

- OcrTrainingRun entity + V30 migration (partial unique index prevents
  concurrent runs at DB level)
- OcrTrainingService: concurrent-run guard, 5-block threshold, MDC log
  correlation, orphan recovery on ApplicationReadyEvent
- POST /api/ocr/train (ADMIN) + GET /api/ocr/training-info (ADMIN)
- TRAINING_ALREADY_RUNNING ErrorCode
- 6 OcrTrainingServiceTest + 6 OcrControllerTest tests for the new endpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 14:47:56 +02:00
parent bc97a2dade
commit 88e005eb49
9 changed files with 493 additions and 0 deletions

View File

@@ -44,6 +44,7 @@ class OcrControllerTest {
@MockitoBean UserService userService;
@MockitoBean CustomUserDetailsService customUserDetailsService;
@MockitoBean TrainingDataExportService trainingDataExportService;
@MockitoBean OcrTrainingService ocrTrainingService;
@Test
@WithMockUser(authorities = "WRITE_ALL")
@@ -163,6 +164,68 @@ class OcrControllerTest {
result.getResponse().getContentType()).contains("application/zip"));
}
// ─── POST /api/ocr/train ───────────────────────────────────────────────────
@Test
void triggerTraining_returns401_whenUnauthenticated() throws Exception {
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isUnauthorized());
}
@Test
@WithMockUser(authorities = "READ_ALL")
void triggerTraining_returns403_whenNotAdmin() throws Exception {
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(authorities = "ADMIN")
void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
when(ocrTrainingService.triggerTraining(any()))
.thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isConflict());
}
@Test
@WithMockUser(authorities = "ADMIN")
void triggerTraining_returns201_withRunInfo() throws Exception {
UUID runId = UUID.randomUUID();
OcrTrainingRun run = OcrTrainingRun.builder()
.id(runId).status(TrainingStatus.DONE)
.blockCount(10).documentCount(3).modelName("german_kurrent").build();
when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isCreated())
.andExpect(jsonPath("$.status").value("DONE"))
.andExpect(jsonPath("$.blockCount").value(10));
}
// ─── GET /api/ocr/training-info ───────────────────────────────────────────
@Test
@WithMockUser(authorities = "READ_ALL")
void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_withInfo() throws Exception {
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, true, null, List.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.availableBlocks").value(5))
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
}
@Test
@WithMockUser(authorities = "READ_ALL")
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {

View File

@@ -0,0 +1,168 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
class OcrTrainingServiceTest {
OcrTrainingRunRepository runRepository;
TrainingDataExportService exportService;
OcrClient ocrClient;
OcrHealthClient healthClient;
TranscriptionBlockRepository blockRepository;
OcrTrainingService service;
@BeforeEach
void setUp() {
runRepository = mock(OcrTrainingRunRepository.class);
exportService = mock(TrainingDataExportService.class);
ocrClient = mock(OcrClient.class);
healthClient = mock(OcrHealthClient.class);
blockRepository = mock(TranscriptionBlockRepository.class);
service = new OcrTrainingService(runRepository, exportService, ocrClient, healthClient, blockRepository);
when(blockRepository.findAll()).thenReturn(List.of());
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
}
// ─── Concurrent guard ─────────────────────────────────────────────────────
@Test
void triggerTraining_throws409_whenRunningRunExists() {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
.thenReturn(Optional.of(OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(2).modelName("german_kurrent").build()));
assertThatThrownBy(() -> service.triggerTraining(null))
.isInstanceOf(DomainException.class)
.extracting("status")
.satisfies(s -> assertThat(s.toString()).contains("409"));
}
// ─── Threshold guard ──────────────────────────────────────────────────────
@Test
void triggerTraining_throws422_whenFewerThan5Blocks() {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
when(exportService.queryEligibleBlocks()).thenReturn(List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
));
assertThatThrownBy(() -> service.triggerTraining(null))
.isInstanceOf(DomainException.class);
}
// ─── Happy path ───────────────────────────────────────────────────────────
@Test
void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
UUID docA = UUID.randomUUID();
UUID docB = UUID.randomUUID();
List<TranscriptionBlock> blocks = List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
);
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(2).modelName("german_kurrent").build();
when(runRepository.save(any())).thenReturn(saved);
service.triggerTraining(null);
// Verify run created with correct counts and then updated to DONE
verify(runRepository, times(2)).save(argThat(run ->
run.getBlockCount() == 5 || run.getStatus() == TrainingStatus.DONE));
}
// ─── Failure path ─────────────────────────────────────────────────────────
@Test
void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
UUID docA = UUID.randomUUID();
List<TranscriptionBlock> blocks = List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
);
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build();
when(runRepository.save(any())).thenReturn(saved);
service.triggerTraining(null);
verify(runRepository, atLeastOnce()).save(argThat(run ->
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test
void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
OcrTrainingRun orphan = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent")
.createdAt(Instant.now().minusSeconds(7200))
.build();
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
when(runRepository.save(any())).thenReturn(orphan);
service.recoverOrphanedRuns();
verify(runRepository).save(argThat(run ->
run.getStatus() == TrainingStatus.FAILED
&& run.getErrorMessage().contains("Abgebrochen")));
}
@Test
void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
OcrTrainingRun recent = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent")
.createdAt(Instant.now().minusSeconds(60))
.build();
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
service.recoverOrphanedRuns();
verify(runRepository, never()).save(any());
}
}