feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232
@@ -7,11 +7,13 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||
import org.raddatz.familienarchiv.model.AppUser;
|
||||
import org.raddatz.familienarchiv.model.OcrJob;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.security.Permission;
|
||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||
import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||
import org.raddatz.familienarchiv.service.OcrProgressService;
|
||||
import org.raddatz.familienarchiv.service.OcrService;
|
||||
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
||||
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
||||
import org.raddatz.familienarchiv.service.UserService;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -37,6 +39,7 @@ public class OcrController {
|
||||
private final OcrProgressService ocrProgressService;
|
||||
private final UserService userService;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final OcrTrainingService ocrTrainingService;
|
||||
|
||||
@PostMapping("/api/documents/{documentId}/ocr")
|
||||
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||
@@ -93,6 +96,20 @@ public class OcrController {
|
||||
.body(body);
|
||||
}
|
||||
|
||||
@PostMapping("/api/ocr/train")
|
||||
@ResponseStatus(HttpStatus.CREATED)
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingRun triggerTraining(Authentication authentication) {
|
||||
UUID userId = resolveUserId(authentication);
|
||||
return ocrTrainingService.triggerTraining(userId);
|
||||
}
|
||||
|
||||
@GetMapping("/api/ocr/training-info")
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
|
||||
return ocrTrainingService.getTrainingInfo();
|
||||
}
|
||||
|
||||
private UUID resolveUserId(Authentication authentication) {
|
||||
if (authentication == null || !authentication.isAuthenticated()) return null;
|
||||
try {
|
||||
|
||||
@@ -75,6 +75,8 @@ public enum ErrorCode {
|
||||
OCR_DOCUMENT_NOT_UPLOADED,
|
||||
/** OCR processing failed for the document. 500 */
|
||||
OCR_PROCESSING_FAILED,
|
||||
/** A training run is already in progress. 409 */
|
||||
TRAINING_ALREADY_RUNNING,
|
||||
|
||||
// --- Generic ---
|
||||
/** Request validation failed (missing or malformed fields). 400 */
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.persistence.*;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.hibernate.annotations.CreationTimestamp;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.UUID;
|
||||
|
||||
@Entity
|
||||
@Table(name = "ocr_training_runs")
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class OcrTrainingRun {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.UUID)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private UUID id;
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private TrainingStatus status;
|
||||
|
||||
@Column(name = "block_count", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private int blockCount;
|
||||
|
||||
@Column(name = "document_count", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private int documentCount;
|
||||
|
||||
@Column(name = "model_name", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private String modelName;
|
||||
|
||||
@Column(name = "error_message")
|
||||
private String errorMessage;
|
||||
|
||||
@Column(name = "triggered_by")
|
||||
private UUID triggeredBy;
|
||||
|
||||
@CreationTimestamp
|
||||
@Column(name = "created_at", nullable = false, updatable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private Instant createdAt;
|
||||
|
||||
@Column(name = "completed_at")
|
||||
private Instant completedAt;
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
public enum TrainingStatus {
|
||||
RUNNING,
|
||||
DONE,
|
||||
FAILED
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {
|
||||
|
||||
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
|
||||
|
||||
List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.slf4j.MDC;
|
||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class OcrTrainingService {
|
||||
|
||||
private final OcrTrainingRunRepository trainingRunRepository;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final OcrClient ocrClient;
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
|
||||
public record TrainingInfoResponse(
|
||||
int availableBlocks,
|
||||
int totalOcrBlocks,
|
||||
int availableDocuments,
|
||||
boolean ocrServiceAvailable,
|
||||
OcrTrainingRun lastRun,
|
||||
List<OcrTrainingRun> runs
|
||||
) {}
|
||||
|
||||
@Transactional
|
||||
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
if (eligibleBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
|
||||
}
|
||||
|
||||
long documentCount = eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(eligibleBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("german_kurrent")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
run = trainingRunRepository.save(run);
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started training run {} with {} blocks from {} documents",
|
||||
runId, eligibleBlocks.size(), documentCount);
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
trainingDataExportService.exportToZip().writeTo(baos);
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||
ocrClient.trainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed successfully", runId);
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
public TrainingInfoResponse getTrainingInfo() {
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
int availableDocuments = (int) eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
int totalOcrBlocks = blockRepository.findAll().size();
|
||||
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
|
||||
|
||||
return new TrainingInfoResponse(
|
||||
eligibleBlocks.size(),
|
||||
totalOcrBlocks,
|
||||
availableDocuments,
|
||||
ocrHealthClient.isHealthy(),
|
||||
lastRun,
|
||||
recentRuns
|
||||
);
|
||||
}
|
||||
|
||||
@EventListener(ApplicationReadyEvent.class)
|
||||
@Transactional
|
||||
public void recoverOrphanedRuns() {
|
||||
var cutoff = Instant.now().minusSeconds(3600);
|
||||
trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
|
||||
if (run.getCreatedAt().isBefore(cutoff)) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
|
||||
run.setCompletedAt(Instant.now());
|
||||
trainingRunRepository.save(run);
|
||||
log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
|
||||
return Map.of(
|
||||
"availableBlocks", info.availableBlocks(),
|
||||
"totalOcrBlocks", info.totalOcrBlocks(),
|
||||
"availableDocuments", info.availableDocuments(),
|
||||
"ocrServiceAvailable", info.ocrServiceAvailable(),
|
||||
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
|
||||
"runs", info.runs()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
CREATE TABLE ocr_training_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'RUNNING',
|
||||
block_count INT NOT NULL,
|
||||
document_count INT NOT NULL,
|
||||
model_name VARCHAR(100) NOT NULL,
|
||||
error_message TEXT,
|
||||
triggered_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
completed_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
-- Enforce single active run at the DB layer (application check is the UX layer)
|
||||
CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
|
||||
ON ocr_training_runs (status)
|
||||
WHERE status = 'RUNNING';
|
||||
@@ -44,6 +44,7 @@ class OcrControllerTest {
|
||||
@MockitoBean UserService userService;
|
||||
@MockitoBean CustomUserDetailsService customUserDetailsService;
|
||||
@MockitoBean TrainingDataExportService trainingDataExportService;
|
||||
@MockitoBean OcrTrainingService ocrTrainingService;
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
@@ -163,6 +164,68 @@ class OcrControllerTest {
|
||||
result.getResponse().getContentType()).contains("application/zip"));
|
||||
}
|
||||
|
||||
// ─── POST /api/ocr/train ───────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_returns401_whenUnauthenticated() throws Exception {
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isUnauthorized());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void triggerTraining_returns403_whenNotAdmin() throws Exception {
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
|
||||
when(ocrTrainingService.triggerTraining(any()))
|
||||
.thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
|
||||
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isConflict());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void triggerTraining_returns201_withRunInfo() throws Exception {
|
||||
UUID runId = UUID.randomUUID();
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.id(runId).status(TrainingStatus.DONE)
|
||||
.blockCount(10).documentCount(3).modelName("german_kurrent").build();
|
||||
when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
|
||||
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isCreated())
|
||||
.andExpect(jsonPath("$.status").value("DONE"))
|
||||
.andExpect(jsonPath("$.blockCount").value(10));
|
||||
}
|
||||
|
||||
// ─── GET /api/ocr/training-info ───────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void getTrainingInfo_returns200_withInfo() throws Exception {
|
||||
OcrTrainingService.TrainingInfoResponse info =
|
||||
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, true, null, List.of());
|
||||
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(jsonPath("$.availableBlocks").value(5))
|
||||
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.argThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class OcrTrainingServiceTest {
|
||||
|
||||
OcrTrainingRunRepository runRepository;
|
||||
TrainingDataExportService exportService;
|
||||
OcrClient ocrClient;
|
||||
OcrHealthClient healthClient;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
OcrTrainingService service;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
runRepository = mock(OcrTrainingRunRepository.class);
|
||||
exportService = mock(TrainingDataExportService.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
healthClient = mock(OcrHealthClient.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
|
||||
service = new OcrTrainingService(runRepository, exportService, ocrClient, healthClient, blockRepository);
|
||||
|
||||
when(blockRepository.findAll()).thenReturn(List.of());
|
||||
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
}
|
||||
|
||||
// ─── Concurrent guard ─────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_throws409_whenRunningRunExists() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
|
||||
.thenReturn(Optional.of(OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("german_kurrent").build()));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.extracting("status")
|
||||
.satisfies(s -> assertThat(s.toString()).contains("409"));
|
||||
}
|
||||
|
||||
// ─── Threshold guard ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_throws422_whenFewerThan5Blocks() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
|
||||
));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||
.isInstanceOf(DomainException.class);
|
||||
}
|
||||
|
||||
// ─── Happy path ───────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
UUID docB = UUID.randomUUID();
|
||||
List<TranscriptionBlock> blocks = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("german_kurrent").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerTraining(null);
|
||||
|
||||
// Verify run created with correct counts and then updated to DONE
|
||||
verify(runRepository, times(2)).save(argThat(run ->
|
||||
run.getBlockCount() == 5 || run.getStatus() == TrainingStatus.DONE));
|
||||
}
|
||||
|
||||
// ─── Failure path ─────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
List<TranscriptionBlock> blocks = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerTraining(null);
|
||||
|
||||
verify(runRepository, atLeastOnce()).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
|
||||
OcrTrainingRun orphan = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||
.createdAt(Instant.now().minusSeconds(7200))
|
||||
.build();
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
|
||||
when(runRepository.save(any())).thenReturn(orphan);
|
||||
|
||||
service.recoverOrphanedRuns();
|
||||
|
||||
verify(runRepository).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.FAILED
|
||||
&& run.getErrorMessage().contains("Abgebrochen")));
|
||||
}
|
||||
|
||||
@Test
|
||||
void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
|
||||
OcrTrainingRun recent = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||
.createdAt(Instant.now().minusSeconds(60))
|
||||
.build();
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
|
||||
|
||||
service.recoverOrphanedRuns();
|
||||
|
||||
verify(runRepository, never()).save(any());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user