feat(ocr): add training history + POST /train + GET /training-info endpoints
- OcrTrainingRun entity + V30 migration (partial unique index prevents concurrent runs at DB level) - OcrTrainingService: concurrent-run guard, 5-block threshold, MDC log correlation, orphan recovery on ApplicationReadyEvent - POST /api/ocr/train (ADMIN) + GET /api/ocr/training-info (ADMIN) - TRAINING_ALREADY_RUNNING ErrorCode - 6 OcrTrainingServiceTest + 6 OcrControllerTest tests for the new endpoints Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,11 +7,13 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
|||||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||||
import org.raddatz.familienarchiv.model.AppUser;
|
import org.raddatz.familienarchiv.model.AppUser;
|
||||||
import org.raddatz.familienarchiv.model.OcrJob;
|
import org.raddatz.familienarchiv.model.OcrJob;
|
||||||
|
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||||
import org.raddatz.familienarchiv.security.Permission;
|
import org.raddatz.familienarchiv.security.Permission;
|
||||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||||
import org.raddatz.familienarchiv.service.OcrBatchService;
|
import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||||
import org.raddatz.familienarchiv.service.OcrProgressService;
|
import org.raddatz.familienarchiv.service.OcrProgressService;
|
||||||
import org.raddatz.familienarchiv.service.OcrService;
|
import org.raddatz.familienarchiv.service.OcrService;
|
||||||
|
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
||||||
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
||||||
import org.raddatz.familienarchiv.service.UserService;
|
import org.raddatz.familienarchiv.service.UserService;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
@@ -37,6 +39,7 @@ public class OcrController {
|
|||||||
private final OcrProgressService ocrProgressService;
|
private final OcrProgressService ocrProgressService;
|
||||||
private final UserService userService;
|
private final UserService userService;
|
||||||
private final TrainingDataExportService trainingDataExportService;
|
private final TrainingDataExportService trainingDataExportService;
|
||||||
|
private final OcrTrainingService ocrTrainingService;
|
||||||
|
|
||||||
@PostMapping("/api/documents/{documentId}/ocr")
|
@PostMapping("/api/documents/{documentId}/ocr")
|
||||||
@ResponseStatus(HttpStatus.ACCEPTED)
|
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||||
@@ -93,6 +96,20 @@ public class OcrController {
|
|||||||
.body(body);
|
.body(body);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@PostMapping("/api/ocr/train")
|
||||||
|
@ResponseStatus(HttpStatus.CREATED)
|
||||||
|
@RequirePermission(Permission.ADMIN)
|
||||||
|
public OcrTrainingRun triggerTraining(Authentication authentication) {
|
||||||
|
UUID userId = resolveUserId(authentication);
|
||||||
|
return ocrTrainingService.triggerTraining(userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping("/api/ocr/training-info")
|
||||||
|
@RequirePermission(Permission.ADMIN)
|
||||||
|
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
|
||||||
|
return ocrTrainingService.getTrainingInfo();
|
||||||
|
}
|
||||||
|
|
||||||
private UUID resolveUserId(Authentication authentication) {
|
private UUID resolveUserId(Authentication authentication) {
|
||||||
if (authentication == null || !authentication.isAuthenticated()) return null;
|
if (authentication == null || !authentication.isAuthenticated()) return null;
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -75,6 +75,8 @@ public enum ErrorCode {
|
|||||||
OCR_DOCUMENT_NOT_UPLOADED,
|
OCR_DOCUMENT_NOT_UPLOADED,
|
||||||
/** OCR processing failed for the document. 500 */
|
/** OCR processing failed for the document. 500 */
|
||||||
OCR_PROCESSING_FAILED,
|
OCR_PROCESSING_FAILED,
|
||||||
|
/** A training run is already in progress. 409 */
|
||||||
|
TRAINING_ALREADY_RUNNING,
|
||||||
|
|
||||||
// --- Generic ---
|
// --- Generic ---
|
||||||
/** Request validation failed (missing or malformed fields). 400 */
|
/** Request validation failed (missing or malformed fields). 400 */
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package org.raddatz.familienarchiv.model;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import jakarta.persistence.*;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.hibernate.annotations.CreationTimestamp;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
@Entity
|
||||||
|
@Table(name = "ocr_training_runs")
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Builder
|
||||||
|
public class OcrTrainingRun {
|
||||||
|
|
||||||
|
@Id
|
||||||
|
@GeneratedValue(strategy = GenerationType.UUID)
|
||||||
|
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private UUID id;
|
||||||
|
|
||||||
|
@Enumerated(EnumType.STRING)
|
||||||
|
@Column(nullable = false)
|
||||||
|
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private TrainingStatus status;
|
||||||
|
|
||||||
|
@Column(name = "block_count", nullable = false)
|
||||||
|
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private int blockCount;
|
||||||
|
|
||||||
|
@Column(name = "document_count", nullable = false)
|
||||||
|
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private int documentCount;
|
||||||
|
|
||||||
|
@Column(name = "model_name", nullable = false)
|
||||||
|
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private String modelName;
|
||||||
|
|
||||||
|
@Column(name = "error_message")
|
||||||
|
private String errorMessage;
|
||||||
|
|
||||||
|
@Column(name = "triggered_by")
|
||||||
|
private UUID triggeredBy;
|
||||||
|
|
||||||
|
@CreationTimestamp
|
||||||
|
@Column(name = "created_at", nullable = false, updatable = false)
|
||||||
|
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
|
private Instant createdAt;
|
||||||
|
|
||||||
|
@Column(name = "completed_at")
|
||||||
|
private Instant completedAt;
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package org.raddatz.familienarchiv.model;
|
||||||
|
|
||||||
|
public enum TrainingStatus {
|
||||||
|
RUNNING,
|
||||||
|
DONE,
|
||||||
|
FAILED
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package org.raddatz.familienarchiv.repository;
|
||||||
|
|
||||||
|
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||||
|
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||||
|
import org.springframework.data.jpa.repository.JpaRepository;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {
|
||||||
|
|
||||||
|
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
|
||||||
|
|
||||||
|
List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
|
||||||
|
}
|
||||||
@@ -0,0 +1,147 @@
|
|||||||
|
package org.raddatz.familienarchiv.service;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.raddatz.familienarchiv.exception.DomainException;
|
||||||
|
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||||
|
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||||
|
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||||
|
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||||
|
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||||
|
import org.slf4j.MDC;
|
||||||
|
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||||
|
import org.springframework.context.event.EventListener;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
@Slf4j
|
||||||
|
public class OcrTrainingService {
|
||||||
|
|
||||||
|
private final OcrTrainingRunRepository trainingRunRepository;
|
||||||
|
private final TrainingDataExportService trainingDataExportService;
|
||||||
|
private final OcrClient ocrClient;
|
||||||
|
private final OcrHealthClient ocrHealthClient;
|
||||||
|
private final TranscriptionBlockRepository blockRepository;
|
||||||
|
|
||||||
|
public record TrainingInfoResponse(
|
||||||
|
int availableBlocks,
|
||||||
|
int totalOcrBlocks,
|
||||||
|
int availableDocuments,
|
||||||
|
boolean ocrServiceAvailable,
|
||||||
|
OcrTrainingRun lastRun,
|
||||||
|
List<OcrTrainingRun> runs
|
||||||
|
) {}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
|
||||||
|
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||||
|
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||||
|
"A training run is already in progress");
|
||||||
|
}
|
||||||
|
|
||||||
|
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||||
|
if (eligibleBlocks.size() < 5) {
|
||||||
|
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||||
|
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
long documentCount = eligibleBlocks.stream()
|
||||||
|
.map(b -> b.getDocumentId())
|
||||||
|
.distinct()
|
||||||
|
.count();
|
||||||
|
|
||||||
|
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||||
|
.status(TrainingStatus.RUNNING)
|
||||||
|
.blockCount(eligibleBlocks.size())
|
||||||
|
.documentCount((int) documentCount)
|
||||||
|
.modelName("german_kurrent")
|
||||||
|
.triggeredBy(triggeredBy)
|
||||||
|
.build();
|
||||||
|
run = trainingRunRepository.save(run);
|
||||||
|
|
||||||
|
String runId = run.getId().toString();
|
||||||
|
MDC.put("trainingRunId", runId);
|
||||||
|
log.info("Started training run {} with {} blocks from {} documents",
|
||||||
|
runId, eligibleBlocks.size(), documentCount);
|
||||||
|
|
||||||
|
try {
|
||||||
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
|
trainingDataExportService.exportToZip().writeTo(baos);
|
||||||
|
byte[] zipBytes = baos.toByteArray();
|
||||||
|
|
||||||
|
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||||
|
ocrClient.trainModel(zipBytes);
|
||||||
|
|
||||||
|
run.setStatus(TrainingStatus.DONE);
|
||||||
|
run.setCompletedAt(Instant.now());
|
||||||
|
run = trainingRunRepository.save(run);
|
||||||
|
log.info("[trainingRun={}] Training completed successfully", runId);
|
||||||
|
} catch (Exception e) {
|
||||||
|
run.setStatus(TrainingStatus.FAILED);
|
||||||
|
run.setErrorMessage(e.getMessage());
|
||||||
|
run.setCompletedAt(Instant.now());
|
||||||
|
run = trainingRunRepository.save(run);
|
||||||
|
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||||
|
} finally {
|
||||||
|
MDC.remove("trainingRunId");
|
||||||
|
}
|
||||||
|
|
||||||
|
return run;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainingInfoResponse getTrainingInfo() {
|
||||||
|
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||||
|
int availableDocuments = (int) eligibleBlocks.stream()
|
||||||
|
.map(b -> b.getDocumentId())
|
||||||
|
.distinct()
|
||||||
|
.count();
|
||||||
|
|
||||||
|
int totalOcrBlocks = blockRepository.findAll().size();
|
||||||
|
|
||||||
|
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||||
|
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
|
||||||
|
|
||||||
|
return new TrainingInfoResponse(
|
||||||
|
eligibleBlocks.size(),
|
||||||
|
totalOcrBlocks,
|
||||||
|
availableDocuments,
|
||||||
|
ocrHealthClient.isHealthy(),
|
||||||
|
lastRun,
|
||||||
|
recentRuns
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@EventListener(ApplicationReadyEvent.class)
|
||||||
|
@Transactional
|
||||||
|
public void recoverOrphanedRuns() {
|
||||||
|
var cutoff = Instant.now().minusSeconds(3600);
|
||||||
|
trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
|
||||||
|
if (run.getCreatedAt().isBefore(cutoff)) {
|
||||||
|
run.setStatus(TrainingStatus.FAILED);
|
||||||
|
run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
|
||||||
|
run.setCompletedAt(Instant.now());
|
||||||
|
trainingRunRepository.save(run);
|
||||||
|
log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
|
||||||
|
return Map.of(
|
||||||
|
"availableBlocks", info.availableBlocks(),
|
||||||
|
"totalOcrBlocks", info.totalOcrBlocks(),
|
||||||
|
"availableDocuments", info.availableDocuments(),
|
||||||
|
"ocrServiceAvailable", info.ocrServiceAvailable(),
|
||||||
|
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
|
||||||
|
"runs", info.runs()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
CREATE TABLE ocr_training_runs (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
status VARCHAR(20) NOT NULL DEFAULT 'RUNNING',
|
||||||
|
block_count INT NOT NULL,
|
||||||
|
document_count INT NOT NULL,
|
||||||
|
model_name VARCHAR(100) NOT NULL,
|
||||||
|
error_message TEXT,
|
||||||
|
triggered_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
completed_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Enforce single active run at the DB layer (application check is the UX layer)
|
||||||
|
CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
|
||||||
|
ON ocr_training_runs (status)
|
||||||
|
WHERE status = 'RUNNING';
|
||||||
@@ -44,6 +44,7 @@ class OcrControllerTest {
|
|||||||
@MockitoBean UserService userService;
|
@MockitoBean UserService userService;
|
||||||
@MockitoBean CustomUserDetailsService customUserDetailsService;
|
@MockitoBean CustomUserDetailsService customUserDetailsService;
|
||||||
@MockitoBean TrainingDataExportService trainingDataExportService;
|
@MockitoBean TrainingDataExportService trainingDataExportService;
|
||||||
|
@MockitoBean OcrTrainingService ocrTrainingService;
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@WithMockUser(authorities = "WRITE_ALL")
|
@WithMockUser(authorities = "WRITE_ALL")
|
||||||
@@ -163,6 +164,68 @@ class OcrControllerTest {
|
|||||||
result.getResponse().getContentType()).contains("application/zip"));
|
result.getResponse().getContentType()).contains("application/zip"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── POST /api/ocr/train ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void triggerTraining_returns401_whenUnauthenticated() throws Exception {
|
||||||
|
mockMvc.perform(post("/api/ocr/train"))
|
||||||
|
.andExpect(status().isUnauthorized());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@WithMockUser(authorities = "READ_ALL")
|
||||||
|
void triggerTraining_returns403_whenNotAdmin() throws Exception {
|
||||||
|
mockMvc.perform(post("/api/ocr/train"))
|
||||||
|
.andExpect(status().isForbidden());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@WithMockUser(authorities = "ADMIN")
|
||||||
|
void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
|
||||||
|
when(ocrTrainingService.triggerTraining(any()))
|
||||||
|
.thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
|
||||||
|
|
||||||
|
mockMvc.perform(post("/api/ocr/train"))
|
||||||
|
.andExpect(status().isConflict());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@WithMockUser(authorities = "ADMIN")
|
||||||
|
void triggerTraining_returns201_withRunInfo() throws Exception {
|
||||||
|
UUID runId = UUID.randomUUID();
|
||||||
|
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||||
|
.id(runId).status(TrainingStatus.DONE)
|
||||||
|
.blockCount(10).documentCount(3).modelName("german_kurrent").build();
|
||||||
|
when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
|
||||||
|
|
||||||
|
mockMvc.perform(post("/api/ocr/train"))
|
||||||
|
.andExpect(status().isCreated())
|
||||||
|
.andExpect(jsonPath("$.status").value("DONE"))
|
||||||
|
.andExpect(jsonPath("$.blockCount").value(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── GET /api/ocr/training-info ───────────────────────────────────────────
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@WithMockUser(authorities = "READ_ALL")
|
||||||
|
void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
|
||||||
|
mockMvc.perform(get("/api/ocr/training-info"))
|
||||||
|
.andExpect(status().isForbidden());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@WithMockUser(authorities = "ADMIN")
|
||||||
|
void getTrainingInfo_returns200_withInfo() throws Exception {
|
||||||
|
OcrTrainingService.TrainingInfoResponse info =
|
||||||
|
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, true, null, List.of());
|
||||||
|
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||||
|
|
||||||
|
mockMvc.perform(get("/api/ocr/training-info"))
|
||||||
|
.andExpect(status().isOk())
|
||||||
|
.andExpect(jsonPath("$.availableBlocks").value(5))
|
||||||
|
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@WithMockUser(authorities = "READ_ALL")
|
@WithMockUser(authorities = "READ_ALL")
|
||||||
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
|
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
|
||||||
|
|||||||
@@ -0,0 +1,168 @@
|
|||||||
|
package org.raddatz.familienarchiv.service;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.raddatz.familienarchiv.exception.DomainException;
|
||||||
|
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||||
|
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||||
|
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||||
|
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||||
|
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||||
|
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.argThat;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
class OcrTrainingServiceTest {
|
||||||
|
|
||||||
|
OcrTrainingRunRepository runRepository;
|
||||||
|
TrainingDataExportService exportService;
|
||||||
|
OcrClient ocrClient;
|
||||||
|
OcrHealthClient healthClient;
|
||||||
|
TranscriptionBlockRepository blockRepository;
|
||||||
|
OcrTrainingService service;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
runRepository = mock(OcrTrainingRunRepository.class);
|
||||||
|
exportService = mock(TrainingDataExportService.class);
|
||||||
|
ocrClient = mock(OcrClient.class);
|
||||||
|
healthClient = mock(OcrHealthClient.class);
|
||||||
|
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||||
|
|
||||||
|
service = new OcrTrainingService(runRepository, exportService, ocrClient, healthClient, blockRepository);
|
||||||
|
|
||||||
|
when(blockRepository.findAll()).thenReturn(List.of());
|
||||||
|
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Concurrent guard ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void triggerTraining_throws409_whenRunningRunExists() {
|
||||||
|
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
|
||||||
|
.thenReturn(Optional.of(OcrTrainingRun.builder()
|
||||||
|
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||||
|
.blockCount(5).documentCount(2).modelName("german_kurrent").build()));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||||
|
.isInstanceOf(DomainException.class)
|
||||||
|
.extracting("status")
|
||||||
|
.satisfies(s -> assertThat(s.toString()).contains("409"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Threshold guard ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void triggerTraining_throws422_whenFewerThan5Blocks() {
|
||||||
|
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||||
|
when(exportService.queryEligibleBlocks()).thenReturn(List.of(
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
|
||||||
|
));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||||
|
.isInstanceOf(DomainException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Happy path ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
|
||||||
|
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
UUID docA = UUID.randomUUID();
|
||||||
|
UUID docB = UUID.randomUUID();
|
||||||
|
List<TranscriptionBlock> blocks = List.of(
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
|
||||||
|
);
|
||||||
|
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||||
|
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||||
|
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
|
||||||
|
|
||||||
|
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||||
|
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||||
|
.blockCount(5).documentCount(2).modelName("german_kurrent").build();
|
||||||
|
when(runRepository.save(any())).thenReturn(saved);
|
||||||
|
|
||||||
|
service.triggerTraining(null);
|
||||||
|
|
||||||
|
// Verify run created with correct counts and then updated to DONE
|
||||||
|
verify(runRepository, times(2)).save(argThat(run ->
|
||||||
|
run.getBlockCount() == 5 || run.getStatus() == TrainingStatus.DONE));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Failure path ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
|
||||||
|
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
UUID docA = UUID.randomUUID();
|
||||||
|
List<TranscriptionBlock> blocks = List.of(
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||||
|
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
|
||||||
|
);
|
||||||
|
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||||
|
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||||
|
when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
|
||||||
|
|
||||||
|
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||||
|
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||||
|
.blockCount(5).documentCount(1).modelName("german_kurrent").build();
|
||||||
|
when(runRepository.save(any())).thenReturn(saved);
|
||||||
|
|
||||||
|
service.triggerTraining(null);
|
||||||
|
|
||||||
|
verify(runRepository, atLeastOnce()).save(argThat(run ->
|
||||||
|
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
|
||||||
|
OcrTrainingRun orphan = OcrTrainingRun.builder()
|
||||||
|
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||||
|
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||||
|
.createdAt(Instant.now().minusSeconds(7200))
|
||||||
|
.build();
|
||||||
|
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
|
||||||
|
when(runRepository.save(any())).thenReturn(orphan);
|
||||||
|
|
||||||
|
service.recoverOrphanedRuns();
|
||||||
|
|
||||||
|
verify(runRepository).save(argThat(run ->
|
||||||
|
run.getStatus() == TrainingStatus.FAILED
|
||||||
|
&& run.getErrorMessage().contains("Abgebrochen")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
|
||||||
|
OcrTrainingRun recent = OcrTrainingRun.builder()
|
||||||
|
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||||
|
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||||
|
.createdAt(Instant.now().minusSeconds(60))
|
||||||
|
.build();
|
||||||
|
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
|
||||||
|
|
||||||
|
service.recoverOrphanedRuns();
|
||||||
|
|
||||||
|
verify(runRepository, never()).save(any());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user