diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java
index 4c0d1d4a..23636180 100644
--- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java
+++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java
@@ -7,11 +7,13 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
 import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
 import org.raddatz.familienarchiv.model.AppUser;
 import org.raddatz.familienarchiv.model.OcrJob;
+import org.raddatz.familienarchiv.model.OcrTrainingRun;
 import org.raddatz.familienarchiv.security.Permission;
 import org.raddatz.familienarchiv.security.RequirePermission;
 import org.raddatz.familienarchiv.service.OcrBatchService;
 import org.raddatz.familienarchiv.service.OcrProgressService;
 import org.raddatz.familienarchiv.service.OcrService;
+import org.raddatz.familienarchiv.service.OcrTrainingService;
 import org.raddatz.familienarchiv.service.TrainingDataExportService;
 import org.raddatz.familienarchiv.service.UserService;
 import org.springframework.http.HttpHeaders;
@@ -37,6 +39,7 @@ public class OcrController {
     private final OcrProgressService ocrProgressService;
     private final UserService userService;
     private final TrainingDataExportService trainingDataExportService;
+    private final OcrTrainingService ocrTrainingService;
 
     @PostMapping("/api/documents/{documentId}/ocr")
     @ResponseStatus(HttpStatus.ACCEPTED)
@@ -93,6 +96,20 @@ public class OcrController {
                 .body(body);
     }
 
+    @PostMapping("/api/ocr/train")
+    @ResponseStatus(HttpStatus.CREATED)
+    @RequirePermission(Permission.ADMIN)
+    public OcrTrainingRun triggerTraining(Authentication authentication) {
+        UUID userId = resolveUserId(authentication);
+        return ocrTrainingService.triggerTraining(userId);
+    }
+
+    @GetMapping("/api/ocr/training-info")
+    @RequirePermission(Permission.ADMIN)
+    public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
+        return ocrTrainingService.getTrainingInfo();
+    }
+
     private UUID resolveUserId(Authentication authentication) {
         if (authentication == null || !authentication.isAuthenticated()) return null;
         try {
diff --git a/backend/src/main/java/org/raddatz/familienarchiv/exception/ErrorCode.java b/backend/src/main/java/org/raddatz/familienarchiv/exception/ErrorCode.java
index e3b0c99c..d8a03b83 100644
--- a/backend/src/main/java/org/raddatz/familienarchiv/exception/ErrorCode.java
+++ b/backend/src/main/java/org/raddatz/familienarchiv/exception/ErrorCode.java
@@ -75,6 +75,8 @@ public enum ErrorCode {
     OCR_DOCUMENT_NOT_UPLOADED,
     /** OCR processing failed for the document. 500 */
     OCR_PROCESSING_FAILED,
+    /** A training run is already in progress. 409 */
+    TRAINING_ALREADY_RUNNING,
 
     // --- Generic ---
     /** Request validation failed (missing or malformed fields). 400 */
diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java
new file mode 100644
index 00000000..add2a2a7
--- /dev/null
+++ b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java
@@ -0,0 +1,57 @@
+package org.raddatz.familienarchiv.model;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.persistence.*;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.hibernate.annotations.CreationTimestamp;
+
+import java.time.Instant;
+import java.util.UUID;
+
+@Entity
+@Table(name = "ocr_training_runs")
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+@Builder
+public class OcrTrainingRun {
+
+    @Id
+    @GeneratedValue(strategy = GenerationType.UUID)
+    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
+    private UUID id;
+
+    @Enumerated(EnumType.STRING)
+    @Column(nullable = false)
+    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
+    private TrainingStatus status;
+
+    @Column(name = "block_count", nullable = false)
+    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
+    private int blockCount;
+
+    @Column(name = "document_count", nullable = false)
+    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
+    private int documentCount;
+
+    @Column(name = "model_name", nullable = false)
+    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
+    private String modelName;
+
+    @Column(name = "error_message")
+    private String errorMessage;
+
+    @Column(name = "triggered_by")
+    private UUID triggeredBy;
+
+    @CreationTimestamp
+    @Column(name = "created_at", nullable = false, updatable = false)
+    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
+    private Instant createdAt;
+
+    @Column(name = "completed_at")
+    private Instant completedAt;
+}
diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java b/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java
new file mode 100644
index 00000000..7e99dd2f
--- /dev/null
+++ b/backend/src/main/java/org/raddatz/familienarchiv/model/TrainingStatus.java
@@ -0,0 +1,7 @@
+package org.raddatz.familienarchiv.model;
+
+public enum TrainingStatus {
+    RUNNING,
+    DONE,
+    FAILED
+}
diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java
new file mode 100644
index 00000000..0bab0e99
--- /dev/null
+++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java
@@ -0,0 +1,16 @@
+package org.raddatz.familienarchiv.repository;
+
+import org.raddatz.familienarchiv.model.OcrTrainingRun;
+import org.raddatz.familienarchiv.model.TrainingStatus;
+import org.springframework.data.jpa.repository.JpaRepository;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.UUID;
+
+public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {
+
+    Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
+
+    List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
+}
diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java
new file mode 100644
index 00000000..1315a2de
--- /dev/null
+++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java
@@ -0,0 +1,159 @@
+package org.raddatz.familienarchiv.service;
+
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.raddatz.familienarchiv.exception.DomainException;
+import org.raddatz.familienarchiv.exception.ErrorCode;
+import org.raddatz.familienarchiv.model.OcrTrainingRun;
+import org.raddatz.familienarchiv.model.TrainingStatus;
+import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
+import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
+import org.slf4j.MDC;
+import org.springframework.boot.context.event.ApplicationReadyEvent;
+import org.springframework.context.event.EventListener;
+import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
+
+import java.io.ByteArrayOutputStream;
+import java.time.Instant;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+@Service
+@RequiredArgsConstructor
+@Slf4j
+public class OcrTrainingService {
+
+    private final OcrTrainingRunRepository trainingRunRepository;
+    private final TrainingDataExportService trainingDataExportService;
+    private final OcrClient ocrClient;
+    private final OcrHealthClient ocrHealthClient;
+    private final TranscriptionBlockRepository blockRepository;
+
+    public record TrainingInfoResponse(
+            int availableBlocks,
+            int totalOcrBlocks,
+            int availableDocuments,
+            boolean ocrServiceAvailable,
+            OcrTrainingRun lastRun,
+            List<OcrTrainingRun> runs
+    ) {}
+
+    /**
+     * Starts a training run synchronously and returns it in its final state (DONE or FAILED).
+     *
+     * <p>Deliberately NOT {@code @Transactional}: each repository {@code save()} then commits on
+     * its own, so the RUNNING row is immediately visible to concurrent triggers and to
+     * {@link #recoverOrphanedRuns()}, and no database transaction is held open during the
+     * long-running remote training call.
+     *
+     * @throws DomainException conflict if a run is already RUNNING; bad request if fewer than 5 blocks are eligible
+     */
+    public OcrTrainingRun triggerTraining(UUID triggeredBy) {
+        if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
+            throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
+                    "A training run is already in progress");
+        }
+
+        var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
+        if (eligibleBlocks.size() < 5) {
+            throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
+                    "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
+        }
+
+        long documentCount = eligibleBlocks.stream()
+                .map(b -> b.getDocumentId())
+                .distinct()
+                .count();
+
+        OcrTrainingRun run = OcrTrainingRun.builder()
+                .status(TrainingStatus.RUNNING)
+                .blockCount(eligibleBlocks.size())
+                .documentCount((int) documentCount)
+                .modelName("german_kurrent")
+                .triggeredBy(triggeredBy)
+                .build();
+        run = trainingRunRepository.save(run);
+
+        String runId = run.getId().toString();
+        MDC.put("trainingRunId", runId);
+        log.info("Started training run {} with {} blocks from {} documents",
+                runId, eligibleBlocks.size(), documentCount);
+
+        try {
+            ByteArrayOutputStream baos = new ByteArrayOutputStream();
+            trainingDataExportService.exportToZip().writeTo(baos);
+            byte[] zipBytes = baos.toByteArray();
+
+            log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
+            ocrClient.trainModel(zipBytes);
+
+            run.setStatus(TrainingStatus.DONE);
+            run.setCompletedAt(Instant.now());
+            run = trainingRunRepository.save(run);
+            log.info("[trainingRun={}] Training completed successfully", runId);
+        } catch (Exception e) {
+            run.setStatus(TrainingStatus.FAILED);
+            run.setErrorMessage(e.getMessage());
+            run.setCompletedAt(Instant.now());
+            run = trainingRunRepository.save(run);
+            log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
+        } finally {
+            MDC.remove("trainingRunId");
+        }
+
+        return run;
+    }
+
+    /** Returns eligible-block/document counts, OCR service health and the five most recent runs. */
+    public TrainingInfoResponse getTrainingInfo() {
+        var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
+        int availableDocuments = (int) eligibleBlocks.stream()
+                .map(b -> b.getDocumentId())
+                .distinct()
+                .count();
+
+        int totalOcrBlocks = Math.toIntExact(blockRepository.count());
+
+        List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
+        OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
+
+        return new TrainingInfoResponse(
+                eligibleBlocks.size(),
+                totalOcrBlocks,
+                availableDocuments,
+                ocrHealthClient.isHealthy(),
+                lastRun,
+                recentRuns
+        );
+    }
+
+    /** On application start, marks a RUNNING run created more than one hour ago as FAILED. */
+    @EventListener(ApplicationReadyEvent.class)
+    @Transactional
+    public void recoverOrphanedRuns() {
+        var cutoff = Instant.now().minusSeconds(3600);
+        trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
+            if (run.getCreatedAt().isBefore(cutoff)) {
+                run.setStatus(TrainingStatus.FAILED);
+                run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
+                run.setCompletedAt(Instant.now());
+                trainingRunRepository.save(run);
+                log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
+            }
+        });
+    }
+
+    /** Flattens a {@link TrainingInfoResponse} into a map; a missing last run becomes an empty map. */
+    public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
+        return Map.of(
+                "availableBlocks", info.availableBlocks(),
+                "totalOcrBlocks", info.totalOcrBlocks(),
+                "availableDocuments", info.availableDocuments(),
+                "ocrServiceAvailable", info.ocrServiceAvailable(),
+                "lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
+                "runs", info.runs()
+        );
+    }
+}
diff --git a/backend/src/main/resources/db/migration/V30__add_ocr_training_runs.sql b/backend/src/main/resources/db/migration/V30__add_ocr_training_runs.sql
new file mode 100644
index 00000000..e1420d72
--- /dev/null
+++ b/backend/src/main/resources/db/migration/V30__add_ocr_training_runs.sql
@@ -0,0 +1,16 @@
+CREATE TABLE ocr_training_runs (
+    id             UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+    status         VARCHAR(20) NOT NULL DEFAULT 'RUNNING',
+    block_count    INT NOT NULL,
+    document_count INT NOT NULL,
+    model_name     VARCHAR(100) NOT NULL,
+    error_message  TEXT,
+    triggered_by   UUID REFERENCES users(id) ON DELETE SET NULL,
+    created_at     TIMESTAMPTZ NOT NULL DEFAULT now(),
+    completed_at   TIMESTAMPTZ
+);
+
+-- Enforce single active run at the DB layer (application check is the UX layer)
+CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
+    ON ocr_training_runs (status)
+    WHERE status = 'RUNNING';
diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java
index bdb3a346..075c18be 100644
--- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java
+++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java
@@ -44,6 +44,7 @@ class OcrControllerTest {
     @MockitoBean UserService userService;
     @MockitoBean CustomUserDetailsService customUserDetailsService;
     @MockitoBean TrainingDataExportService trainingDataExportService;
+    @MockitoBean OcrTrainingService ocrTrainingService;
 
     @Test
     @WithMockUser(authorities = "WRITE_ALL")
@@ -163,6 +164,68 @@ class OcrControllerTest {
                         result.getResponse().getContentType()).contains("application/zip"));
     }
 
+    // ─── POST /api/ocr/train ───────────────────────────────────────────────────
+
+    @Test
+    void triggerTraining_returns401_whenUnauthenticated() throws Exception {
+        mockMvc.perform(post("/api/ocr/train"))
+                .andExpect(status().isUnauthorized());
+    }
+
+    @Test
+    @WithMockUser(authorities = "READ_ALL")
+    void triggerTraining_returns403_whenNotAdmin() throws Exception {
+        mockMvc.perform(post("/api/ocr/train"))
+                .andExpect(status().isForbidden());
+    }
+
+    @Test
+    @WithMockUser(authorities = "ADMIN")
+    void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
+        when(ocrTrainingService.triggerTraining(any()))
+                .thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
+
+        mockMvc.perform(post("/api/ocr/train"))
+                .andExpect(status().isConflict());
+    }
+
+    @Test
+    @WithMockUser(authorities = "ADMIN")
+    void triggerTraining_returns201_withRunInfo() throws Exception {
+        UUID runId = UUID.randomUUID();
+        OcrTrainingRun run = OcrTrainingRun.builder()
+                .id(runId).status(TrainingStatus.DONE)
+                .blockCount(10).documentCount(3).modelName("german_kurrent").build();
+        when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
+
+        mockMvc.perform(post("/api/ocr/train"))
+                .andExpect(status().isCreated())
+                .andExpect(jsonPath("$.status").value("DONE"))
+                .andExpect(jsonPath("$.blockCount").value(10));
+    }
+
+    // ─── GET /api/ocr/training-info ───────────────────────────────────────────
+
+    @Test
+    @WithMockUser(authorities = "READ_ALL")
+    void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
+        mockMvc.perform(get("/api/ocr/training-info"))
+                .andExpect(status().isForbidden());
+    }
+
+    @Test
+    @WithMockUser(authorities = "ADMIN")
+    void getTrainingInfo_returns200_withInfo() throws Exception {
+        OcrTrainingService.TrainingInfoResponse info =
+                new OcrTrainingService.TrainingInfoResponse(5, 20, 2, true, null, List.of());
+        when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
+
+        mockMvc.perform(get("/api/ocr/training-info"))
+                .andExpect(status().isOk())
+                .andExpect(jsonPath("$.availableBlocks").value(5))
+                .andExpect(jsonPath("$.ocrServiceAvailable").value(true));
+    }
+
     @Test
     @WithMockUser(authorities = "READ_ALL")
     void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java
new file mode 100644
index 00000000..d4aea29c
--- /dev/null
+++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java
@@ -0,0 +1,168 @@
+package org.raddatz.familienarchiv.service;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.raddatz.familienarchiv.exception.DomainException;
+import org.raddatz.familienarchiv.model.OcrTrainingRun;
+import org.raddatz.familienarchiv.model.TrainingLabel;
+import org.raddatz.familienarchiv.model.TrainingStatus;
+import org.raddatz.familienarchiv.model.TranscriptionBlock;
+import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
+import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
+
+import java.time.Instant;
+import java.util.List;
+import java.util.Optional;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.Mockito.*;
+
+class OcrTrainingServiceTest {
+
+    OcrTrainingRunRepository runRepository;
+    TrainingDataExportService exportService;
+    OcrClient ocrClient;
+    OcrHealthClient healthClient;
+    TranscriptionBlockRepository blockRepository;
+    OcrTrainingService service;
+
+    @BeforeEach
+    void setUp() {
+        runRepository = mock(OcrTrainingRunRepository.class);
+        exportService = mock(TrainingDataExportService.class);
+        ocrClient = mock(OcrClient.class);
+        healthClient = mock(OcrHealthClient.class);
+        blockRepository = mock(TranscriptionBlockRepository.class);
+
+        service = new OcrTrainingService(runRepository, exportService, ocrClient, healthClient, blockRepository);
+
+        when(blockRepository.count()).thenReturn(0L);
+        when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
+    }
+
+    // ─── Concurrent guard ─────────────────────────────────────────────────────
+
+    @Test
+    void triggerTraining_throws409_whenRunningRunExists() {
+        when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
+                .thenReturn(Optional.of(OcrTrainingRun.builder()
+                        .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
+                        .blockCount(5).documentCount(2).modelName("german_kurrent").build()));
+
+        assertThatThrownBy(() -> service.triggerTraining(null))
+                .isInstanceOf(DomainException.class)
+                .extracting("status")
+                .satisfies(s -> assertThat(s.toString()).contains("409"));
+    }
+
+    // ─── Threshold guard ──────────────────────────────────────────────────────
+
+    @Test
+    void triggerTraining_throws400_whenFewerThan5Blocks() {
+        when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
+        when(exportService.queryEligibleBlocks()).thenReturn(List.of(
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
+        ));
+
+        assertThatThrownBy(() -> service.triggerTraining(null))
+                .isInstanceOf(DomainException.class);
+    }
+
+    // ─── Happy path ───────────────────────────────────────────────────────────
+
+    @Test
+    void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
+        when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
+
+        UUID docA = UUID.randomUUID();
+        UUID docB = UUID.randomUUID();
+        List<TranscriptionBlock> blocks = List.of(
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
+        );
+        when(exportService.queryEligibleBlocks()).thenReturn(blocks);
+        when(exportService.exportToZip()).thenReturn(out -> {});
+        when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
+
+        OcrTrainingRun saved = OcrTrainingRun.builder()
+                .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
+                .blockCount(5).documentCount(2).modelName("german_kurrent").build();
+        when(runRepository.save(any())).thenReturn(saved);
+
+        OcrTrainingRun result = service.triggerTraining(null);
+
+        // Both saves carry the computed counts; the returned run must be DONE
+        verify(runRepository, times(2)).save(argThat(run -> run.getBlockCount() == 5 && run.getDocumentCount() == 2));
+        assertThat(result.getStatus()).isEqualTo(TrainingStatus.DONE);
+    }
+
+    // ─── Failure path ─────────────────────────────────────────────────────────
+
+    @Test
+    void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
+        when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
+
+        UUID docA = UUID.randomUUID();
+        List<TranscriptionBlock> blocks = List.of(
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
+                TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
+        );
+        when(exportService.queryEligibleBlocks()).thenReturn(blocks);
+        when(exportService.exportToZip()).thenReturn(out -> {});
+        when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
+
+        OcrTrainingRun saved = OcrTrainingRun.builder()
+                .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
+                .blockCount(5).documentCount(1).modelName("german_kurrent").build();
+        when(runRepository.save(any())).thenReturn(saved);
+
+        service.triggerTraining(null);
+
+        verify(runRepository, atLeastOnce()).save(argThat(run ->
+                run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
+    }
+
+    // ─── Orphan recovery ──────────────────────────────────────────────────────
+
+    @Test
+    void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
+        OcrTrainingRun orphan = OcrTrainingRun.builder()
+                .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
+                .blockCount(5).documentCount(1).modelName("german_kurrent")
+                .createdAt(Instant.now().minusSeconds(7200))
+                .build();
+        when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
+        when(runRepository.save(any())).thenReturn(orphan);
+
+        service.recoverOrphanedRuns();
+
+        verify(runRepository).save(argThat(run ->
+                run.getStatus() == TrainingStatus.FAILED
+                        && run.getErrorMessage().contains("Abgebrochen")));
+    }
+
+    @Test
+    void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
+        OcrTrainingRun recent = OcrTrainingRun.builder()
+                .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
+                .blockCount(5).documentCount(1).modelName("german_kurrent")
+                .createdAt(Instant.now().minusSeconds(60))
+                .build();
+        when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
+
+        service.recoverOrphanedRuns();
+
+        verify(runRepository, never()).save(any());
+    }
+}