feat(ocr): add training history + POST /train + GET /training-info endpoints

- OcrTrainingRun entity + V30 migration (partial unique index prevents
  concurrent runs at DB level)
- OcrTrainingService: concurrent-run guard, 5-block threshold, MDC log
  correlation, orphan recovery on ApplicationReadyEvent
- POST /api/ocr/train (ADMIN) + GET /api/ocr/training-info (ADMIN)
- TRAINING_ALREADY_RUNNING ErrorCode
- 6 OcrTrainingServiceTest + 6 OcrControllerTest tests for the new endpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 14:47:56 +02:00
parent bc97a2dade
commit 88e005eb49
9 changed files with 493 additions and 0 deletions

View File

@@ -7,11 +7,13 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.model.AppUser;
import org.raddatz.familienarchiv.model.OcrJob;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.security.Permission;
import org.raddatz.familienarchiv.security.RequirePermission;
import org.raddatz.familienarchiv.service.OcrBatchService;
import org.raddatz.familienarchiv.service.OcrProgressService;
import org.raddatz.familienarchiv.service.OcrService;
import org.raddatz.familienarchiv.service.OcrTrainingService;
import org.raddatz.familienarchiv.service.TrainingDataExportService;
import org.raddatz.familienarchiv.service.UserService;
import org.springframework.http.HttpHeaders;
@@ -37,6 +39,7 @@ public class OcrController {
private final OcrProgressService ocrProgressService;
private final UserService userService;
private final TrainingDataExportService trainingDataExportService;
private final OcrTrainingService ocrTrainingService;
@PostMapping("/api/documents/{documentId}/ocr")
@ResponseStatus(HttpStatus.ACCEPTED)
@@ -93,6 +96,20 @@ public class OcrController {
.body(body);
}
@PostMapping("/api/ocr/train")
@ResponseStatus(HttpStatus.CREATED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerTraining(Authentication authentication) {
UUID userId = resolveUserId(authentication);
return ocrTrainingService.triggerTraining(userId);
}
@GetMapping("/api/ocr/training-info")
@RequirePermission(Permission.ADMIN)
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
return ocrTrainingService.getTrainingInfo();
}
private UUID resolveUserId(Authentication authentication) {
if (authentication == null || !authentication.isAuthenticated()) return null;
try {

View File

@@ -75,6 +75,8 @@ public enum ErrorCode {
OCR_DOCUMENT_NOT_UPLOADED,
/** OCR processing failed for the document. 500 */
OCR_PROCESSING_FAILED,
/** A training run is already in progress. 409 */
TRAINING_ALREADY_RUNNING,
// --- Generic ---
/** Request validation failed (missing or malformed fields). 400 */

View File

@@ -0,0 +1,57 @@
package org.raddatz.familienarchiv.model;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;
import java.time.Instant;
import java.util.UUID;
@Entity
@Table(name = "ocr_training_runs")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrTrainingRun {
@Id
@GeneratedValue(strategy = GenerationType.UUID)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID id;
@Enumerated(EnumType.STRING)
@Column(nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private TrainingStatus status;
@Column(name = "block_count", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int blockCount;
@Column(name = "document_count", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int documentCount;
@Column(name = "model_name", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private String modelName;
@Column(name = "error_message")
private String errorMessage;
@Column(name = "triggered_by")
private UUID triggeredBy;
@CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private Instant createdAt;
@Column(name = "completed_at")
private Instant completedAt;
}

View File

@@ -0,0 +1,7 @@
package org.raddatz.familienarchiv.model;
public enum TrainingStatus {
RUNNING,
DONE,
FAILED
}

View File

@@ -0,0 +1,16 @@
package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
}

View File

@@ -0,0 +1,147 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class OcrTrainingService {
private final OcrTrainingRunRepository trainingRunRepository;
private final TrainingDataExportService trainingDataExportService;
private final OcrClient ocrClient;
private final OcrHealthClient ocrHealthClient;
private final TranscriptionBlockRepository blockRepository;
public record TrainingInfoResponse(
int availableBlocks,
int totalOcrBlocks,
int availableDocuments,
boolean ocrServiceAvailable,
OcrTrainingRun lastRun,
List<OcrTrainingRun> runs
) {}
@Transactional
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
if (eligibleBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
}
long documentCount = eligibleBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
OcrTrainingRun run = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.blockCount(eligibleBlocks.size())
.documentCount((int) documentCount)
.modelName("german_kurrent")
.triggeredBy(triggeredBy)
.build();
run = trainingRunRepository.save(run);
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started training run {} with {} blocks from {} documents",
runId, eligibleBlocks.size(), documentCount);
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportToZip().writeTo(baos);
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
ocrClient.trainModel(zipBytes);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run = trainingRunRepository.save(run);
log.info("[trainingRun={}] Training completed successfully", runId);
} catch (Exception e) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
run = trainingRunRepository.save(run);
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
} finally {
MDC.remove("trainingRunId");
}
return run;
}
public TrainingInfoResponse getTrainingInfo() {
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
int availableDocuments = (int) eligibleBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
int totalOcrBlocks = blockRepository.findAll().size();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
return new TrainingInfoResponse(
eligibleBlocks.size(),
totalOcrBlocks,
availableDocuments,
ocrHealthClient.isHealthy(),
lastRun,
recentRuns
);
}
@EventListener(ApplicationReadyEvent.class)
@Transactional
public void recoverOrphanedRuns() {
var cutoff = Instant.now().minusSeconds(3600);
trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
if (run.getCreatedAt().isBefore(cutoff)) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
}
});
}
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
return Map.of(
"availableBlocks", info.availableBlocks(),
"totalOcrBlocks", info.totalOcrBlocks(),
"availableDocuments", info.availableDocuments(),
"ocrServiceAvailable", info.ocrServiceAvailable(),
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
"runs", info.runs()
);
}
}

View File

@@ -0,0 +1,16 @@
CREATE TABLE ocr_training_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
status VARCHAR(20) NOT NULL DEFAULT 'RUNNING',
block_count INT NOT NULL,
document_count INT NOT NULL,
model_name VARCHAR(100) NOT NULL,
error_message TEXT,
triggered_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
completed_at TIMESTAMPTZ
);
-- Enforce single active run at the DB layer (application check is the UX layer)
CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
ON ocr_training_runs (status)
WHERE status = 'RUNNING';