feat(ocr): add training history + POST /train + GET /training-info endpoints
- OcrTrainingRun entity + V30 migration (partial unique index prevents concurrent runs at DB level)
- OcrTrainingService: concurrent-run guard, 5-block threshold, MDC log correlation, orphan recovery on ApplicationReadyEvent
- POST /api/ocr/train (ADMIN) + GET /api/ocr/training-info (ADMIN)
- TRAINING_ALREADY_RUNNING ErrorCode
- 6 OcrTrainingServiceTest + 6 OcrControllerTest tests for the new endpoints

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,11 +7,13 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||
import org.raddatz.familienarchiv.model.AppUser;
|
||||
import org.raddatz.familienarchiv.model.OcrJob;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.security.Permission;
|
||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||
import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||
import org.raddatz.familienarchiv.service.OcrProgressService;
|
||||
import org.raddatz.familienarchiv.service.OcrService;
|
||||
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
||||
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
||||
import org.raddatz.familienarchiv.service.UserService;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -37,6 +39,7 @@ public class OcrController {
|
||||
private final OcrProgressService ocrProgressService;
|
||||
private final UserService userService;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final OcrTrainingService ocrTrainingService;
|
||||
|
||||
@PostMapping("/api/documents/{documentId}/ocr")
|
||||
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||
@@ -93,6 +96,20 @@ public class OcrController {
|
||||
.body(body);
|
||||
}
|
||||
|
||||
@PostMapping("/api/ocr/train")
|
||||
@ResponseStatus(HttpStatus.CREATED)
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingRun triggerTraining(Authentication authentication) {
|
||||
UUID userId = resolveUserId(authentication);
|
||||
return ocrTrainingService.triggerTraining(userId);
|
||||
}
|
||||
|
||||
@GetMapping("/api/ocr/training-info")
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
|
||||
return ocrTrainingService.getTrainingInfo();
|
||||
}
|
||||
|
||||
private UUID resolveUserId(Authentication authentication) {
|
||||
if (authentication == null || !authentication.isAuthenticated()) return null;
|
||||
try {
|
||||
|
||||
@@ -75,6 +75,8 @@ public enum ErrorCode {
|
||||
OCR_DOCUMENT_NOT_UPLOADED,
|
||||
/** OCR processing failed for the document. 500 */
|
||||
OCR_PROCESSING_FAILED,
|
||||
/** A training run is already in progress. 409 */
|
||||
TRAINING_ALREADY_RUNNING,
|
||||
|
||||
// --- Generic ---
|
||||
/** Request validation failed (missing or malformed fields). 400 */
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.persistence.*;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.hibernate.annotations.CreationTimestamp;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.UUID;
|
||||
|
||||
@Entity
|
||||
@Table(name = "ocr_training_runs")
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class OcrTrainingRun {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.UUID)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private UUID id;
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private TrainingStatus status;
|
||||
|
||||
@Column(name = "block_count", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private int blockCount;
|
||||
|
||||
@Column(name = "document_count", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private int documentCount;
|
||||
|
||||
@Column(name = "model_name", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private String modelName;
|
||||
|
||||
@Column(name = "error_message")
|
||||
private String errorMessage;
|
||||
|
||||
@Column(name = "triggered_by")
|
||||
private UUID triggeredBy;
|
||||
|
||||
@CreationTimestamp
|
||||
@Column(name = "created_at", nullable = false, updatable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private Instant createdAt;
|
||||
|
||||
@Column(name = "completed_at")
|
||||
private Instant completedAt;
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
/**
 * Lifecycle states of an OCR training run.
 *
 * <p>Values are persisted by name, so constants must not be renamed.
 */
public enum TrainingStatus {

    /** The run has been created and has not finished yet. */
    RUNNING,

    /** The run finished without an error. */
    DONE,

    /** The run ended with an error or was abandoned. */
    FAILED
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {
|
||||
|
||||
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
|
||||
|
||||
List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.slf4j.MDC;
|
||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class OcrTrainingService {
|
||||
|
||||
private final OcrTrainingRunRepository trainingRunRepository;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final OcrClient ocrClient;
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
|
||||
public record TrainingInfoResponse(
|
||||
int availableBlocks,
|
||||
int totalOcrBlocks,
|
||||
int availableDocuments,
|
||||
boolean ocrServiceAvailable,
|
||||
OcrTrainingRun lastRun,
|
||||
List<OcrTrainingRun> runs
|
||||
) {}
|
||||
|
||||
@Transactional
|
||||
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
if (eligibleBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
|
||||
}
|
||||
|
||||
long documentCount = eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(eligibleBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("german_kurrent")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
run = trainingRunRepository.save(run);
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started training run {} with {} blocks from {} documents",
|
||||
runId, eligibleBlocks.size(), documentCount);
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
trainingDataExportService.exportToZip().writeTo(baos);
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||
ocrClient.trainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed successfully", runId);
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
public TrainingInfoResponse getTrainingInfo() {
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
int availableDocuments = (int) eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
int totalOcrBlocks = blockRepository.findAll().size();
|
||||
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
|
||||
|
||||
return new TrainingInfoResponse(
|
||||
eligibleBlocks.size(),
|
||||
totalOcrBlocks,
|
||||
availableDocuments,
|
||||
ocrHealthClient.isHealthy(),
|
||||
lastRun,
|
||||
recentRuns
|
||||
);
|
||||
}
|
||||
|
||||
@EventListener(ApplicationReadyEvent.class)
|
||||
@Transactional
|
||||
public void recoverOrphanedRuns() {
|
||||
var cutoff = Instant.now().minusSeconds(3600);
|
||||
trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
|
||||
if (run.getCreatedAt().isBefore(cutoff)) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
|
||||
run.setCompletedAt(Instant.now());
|
||||
trainingRunRepository.save(run);
|
||||
log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
|
||||
return Map.of(
|
||||
"availableBlocks", info.availableBlocks(),
|
||||
"totalOcrBlocks", info.totalOcrBlocks(),
|
||||
"availableDocuments", info.availableDocuments(),
|
||||
"ocrServiceAvailable", info.ocrServiceAvailable(),
|
||||
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
|
||||
"runs", info.runs()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
-- History of OCR model training runs: one row per triggered run.
CREATE TABLE ocr_training_runs (
    id             UUID          PRIMARY KEY DEFAULT gen_random_uuid(),
    status         VARCHAR(20)   NOT NULL DEFAULT 'RUNNING',
    block_count    INT           NOT NULL,
    document_count INT           NOT NULL,
    model_name     VARCHAR(100)  NOT NULL,
    error_message  TEXT,
    -- Deleting the user keeps the run and clears the reference.
    triggered_by   UUID          REFERENCES users(id) ON DELETE SET NULL,
    created_at     TIMESTAMPTZ   NOT NULL DEFAULT now(),
    completed_at   TIMESTAMPTZ
);

-- Partial unique index: only rows with status 'RUNNING' are indexed, so at most
-- one such row can exist. The application-level check is the UX layer; this is
-- the enforcement at the database layer.
CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
    ON ocr_training_runs (status)
    WHERE status = 'RUNNING';
|
||||
Reference in New Issue
Block a user