feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232

Merged
marcel merged 40 commits from feat/issue-226-227-ocr-pipeline-polygon into main 2026-04-14 10:31:35 +02:00
9 changed files with 493 additions and 0 deletions
Showing only changes of commit 88e005eb49 - Show all commits

View File

@@ -7,11 +7,13 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.model.AppUser;
import org.raddatz.familienarchiv.model.OcrJob;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.security.Permission;
import org.raddatz.familienarchiv.security.RequirePermission;
import org.raddatz.familienarchiv.service.OcrBatchService;
import org.raddatz.familienarchiv.service.OcrProgressService;
import org.raddatz.familienarchiv.service.OcrService;
import org.raddatz.familienarchiv.service.OcrTrainingService;
import org.raddatz.familienarchiv.service.TrainingDataExportService;
import org.raddatz.familienarchiv.service.UserService;
import org.springframework.http.HttpHeaders;
@@ -37,6 +39,7 @@ public class OcrController {
private final OcrProgressService ocrProgressService;
private final UserService userService;
private final TrainingDataExportService trainingDataExportService;
private final OcrTrainingService ocrTrainingService;
@PostMapping("/api/documents/{documentId}/ocr")
@ResponseStatus(HttpStatus.ACCEPTED)
@@ -93,6 +96,20 @@ public class OcrController {
.body(body);
}
@PostMapping("/api/ocr/train")
@ResponseStatus(HttpStatus.CREATED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerTraining(Authentication authentication) {
UUID userId = resolveUserId(authentication);
return ocrTrainingService.triggerTraining(userId);
}
@GetMapping("/api/ocr/training-info")
@RequirePermission(Permission.ADMIN)
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
return ocrTrainingService.getTrainingInfo();
}
private UUID resolveUserId(Authentication authentication) {
if (authentication == null || !authentication.isAuthenticated()) return null;
try {

View File

@@ -75,6 +75,8 @@ public enum ErrorCode {
OCR_DOCUMENT_NOT_UPLOADED,
/** OCR processing failed for the document. 500 */
OCR_PROCESSING_FAILED,
/** A training run is already in progress. 409 */
TRAINING_ALREADY_RUNNING,
// --- Generic ---
/** Request validation failed (missing or malformed fields). 400 */

View File

@@ -0,0 +1,57 @@
package org.raddatz.familienarchiv.model;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;
import java.time.Instant;
import java.util.UUID;
@Entity
@Table(name = "ocr_training_runs")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrTrainingRun {
@Id
@GeneratedValue(strategy = GenerationType.UUID)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID id;
@Enumerated(EnumType.STRING)
@Column(nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private TrainingStatus status;
@Column(name = "block_count", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int blockCount;
@Column(name = "document_count", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int documentCount;
@Column(name = "model_name", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private String modelName;
@Column(name = "error_message")
private String errorMessage;
@Column(name = "triggered_by")
private UUID triggeredBy;
@CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private Instant createdAt;
@Column(name = "completed_at")
private Instant completedAt;
}

View File

@@ -0,0 +1,7 @@
package org.raddatz.familienarchiv.model;
public enum TrainingStatus {
RUNNING,
DONE,
FAILED
}

View File

@@ -0,0 +1,16 @@
package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
}

View File

@@ -0,0 +1,147 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class OcrTrainingService {
private final OcrTrainingRunRepository trainingRunRepository;
private final TrainingDataExportService trainingDataExportService;
private final OcrClient ocrClient;
private final OcrHealthClient ocrHealthClient;
private final TranscriptionBlockRepository blockRepository;
public record TrainingInfoResponse(
int availableBlocks,
int totalOcrBlocks,
int availableDocuments,
boolean ocrServiceAvailable,
OcrTrainingRun lastRun,
List<OcrTrainingRun> runs
) {}
@Transactional
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
if (eligibleBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
}
long documentCount = eligibleBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
OcrTrainingRun run = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.blockCount(eligibleBlocks.size())
.documentCount((int) documentCount)
.modelName("german_kurrent")
.triggeredBy(triggeredBy)
.build();
run = trainingRunRepository.save(run);
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started training run {} with {} blocks from {} documents",
runId, eligibleBlocks.size(), documentCount);
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportToZip().writeTo(baos);
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
ocrClient.trainModel(zipBytes);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run = trainingRunRepository.save(run);
log.info("[trainingRun={}] Training completed successfully", runId);
} catch (Exception e) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
run = trainingRunRepository.save(run);
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
} finally {
MDC.remove("trainingRunId");
}
return run;
}
public TrainingInfoResponse getTrainingInfo() {
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
int availableDocuments = (int) eligibleBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
int totalOcrBlocks = blockRepository.findAll().size();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
return new TrainingInfoResponse(
eligibleBlocks.size(),
totalOcrBlocks,
availableDocuments,
ocrHealthClient.isHealthy(),
lastRun,
recentRuns
);
}
@EventListener(ApplicationReadyEvent.class)
@Transactional
public void recoverOrphanedRuns() {
var cutoff = Instant.now().minusSeconds(3600);
trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
if (run.getCreatedAt().isBefore(cutoff)) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
}
});
}
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
return Map.of(
"availableBlocks", info.availableBlocks(),
"totalOcrBlocks", info.totalOcrBlocks(),
"availableDocuments", info.availableDocuments(),
"ocrServiceAvailable", info.ocrServiceAvailable(),
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
"runs", info.runs()
);
}
}

View File

@@ -0,0 +1,16 @@
CREATE TABLE ocr_training_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
status VARCHAR(20) NOT NULL DEFAULT 'RUNNING',
block_count INT NOT NULL,
document_count INT NOT NULL,
model_name VARCHAR(100) NOT NULL,
error_message TEXT,
triggered_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
completed_at TIMESTAMPTZ
);
-- Enforce single active run at the DB layer (application check is the UX layer)
CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
ON ocr_training_runs (status)
WHERE status = 'RUNNING';

View File

@@ -44,6 +44,7 @@ class OcrControllerTest {
@MockitoBean UserService userService;
@MockitoBean CustomUserDetailsService customUserDetailsService;
@MockitoBean TrainingDataExportService trainingDataExportService;
@MockitoBean OcrTrainingService ocrTrainingService;
@Test
@WithMockUser(authorities = "WRITE_ALL")
@@ -163,6 +164,68 @@ class OcrControllerTest {
result.getResponse().getContentType()).contains("application/zip"));
}
// ─── POST /api/ocr/train ───────────────────────────────────────────────────
@Test
void triggerTraining_returns401_whenUnauthenticated() throws Exception {
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isUnauthorized());
}
@Test
@WithMockUser(authorities = "READ_ALL")
void triggerTraining_returns403_whenNotAdmin() throws Exception {
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(authorities = "ADMIN")
void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
when(ocrTrainingService.triggerTraining(any()))
.thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isConflict());
}
@Test
@WithMockUser(authorities = "ADMIN")
void triggerTraining_returns201_withRunInfo() throws Exception {
UUID runId = UUID.randomUUID();
OcrTrainingRun run = OcrTrainingRun.builder()
.id(runId).status(TrainingStatus.DONE)
.blockCount(10).documentCount(3).modelName("german_kurrent").build();
when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isCreated())
.andExpect(jsonPath("$.status").value("DONE"))
.andExpect(jsonPath("$.blockCount").value(10));
}
// ─── GET /api/ocr/training-info ───────────────────────────────────────────
@Test
@WithMockUser(authorities = "READ_ALL")
void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_withInfo() throws Exception {
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, true, null, List.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.availableBlocks").value(5))
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
}
@Test
@WithMockUser(authorities = "READ_ALL")
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {

View File

@@ -0,0 +1,168 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
class OcrTrainingServiceTest {
OcrTrainingRunRepository runRepository;
TrainingDataExportService exportService;
OcrClient ocrClient;
OcrHealthClient healthClient;
TranscriptionBlockRepository blockRepository;
OcrTrainingService service;
@BeforeEach
void setUp() {
runRepository = mock(OcrTrainingRunRepository.class);
exportService = mock(TrainingDataExportService.class);
ocrClient = mock(OcrClient.class);
healthClient = mock(OcrHealthClient.class);
blockRepository = mock(TranscriptionBlockRepository.class);
service = new OcrTrainingService(runRepository, exportService, ocrClient, healthClient, blockRepository);
when(blockRepository.findAll()).thenReturn(List.of());
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
}
// ─── Concurrent guard ─────────────────────────────────────────────────────
@Test
void triggerTraining_throws409_whenRunningRunExists() {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
.thenReturn(Optional.of(OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(2).modelName("german_kurrent").build()));
assertThatThrownBy(() -> service.triggerTraining(null))
.isInstanceOf(DomainException.class)
.extracting("status")
.satisfies(s -> assertThat(s.toString()).contains("409"));
}
// ─── Threshold guard ──────────────────────────────────────────────────────
@Test
void triggerTraining_throws422_whenFewerThan5Blocks() {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
when(exportService.queryEligibleBlocks()).thenReturn(List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
));
assertThatThrownBy(() -> service.triggerTraining(null))
.isInstanceOf(DomainException.class);
}
// ─── Happy path ───────────────────────────────────────────────────────────
@Test
void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
UUID docA = UUID.randomUUID();
UUID docB = UUID.randomUUID();
List<TranscriptionBlock> blocks = List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
);
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(2).modelName("german_kurrent").build();
when(runRepository.save(any())).thenReturn(saved);
service.triggerTraining(null);
// Verify run created with correct counts and then updated to DONE
verify(runRepository, times(2)).save(argThat(run ->
run.getBlockCount() == 5 || run.getStatus() == TrainingStatus.DONE));
}
// ─── Failure path ─────────────────────────────────────────────────────────
@Test
void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
UUID docA = UUID.randomUUID();
List<TranscriptionBlock> blocks = List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
);
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build();
when(runRepository.save(any())).thenReturn(saved);
service.triggerTraining(null);
verify(runRepository, atLeastOnce()).save(argThat(run ->
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test
void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
OcrTrainingRun orphan = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent")
.createdAt(Instant.now().minusSeconds(7200))
.build();
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
when(runRepository.save(any())).thenReturn(orphan);
service.recoverOrphanedRuns();
verify(runRepository).save(argThat(run ->
run.getStatus() == TrainingStatus.FAILED
&& run.getErrorMessage().contains("Abgebrochen")));
}
@Test
void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
OcrTrainingRun recent = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent")
.createdAt(Instant.now().minusSeconds(60))
.build();
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
service.recoverOrphanedRuns();
verify(runRepository, never()).save(any());
}
}