feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232

Merged
marcel merged 40 commits from feat/issue-226-227-ocr-pipeline-polygon into main 2026-04-14 10:31:35 +02:00
66 changed files with 4017 additions and 404 deletions

36
.env.example Normal file
View File

@@ -0,0 +1,36 @@
# Datenbank (PostgreSQL)
POSTGRES_USER=archive_user
POSTGRES_PASSWORD=change-me
POSTGRES_DB=family_archive_db
# Object Storage (MinIO)
MINIO_ROOT_USER=minio_admin
MINIO_ROOT_PASSWORD=change-me
MINIO_DEFAULT_BUCKETS=archive-documents
# Ports (für Zugriff vom Host/NAS)
PORT_DB=5432
PORT_MINIO_API=9000
PORT_MINIO_CONSOLE=9001
PORT_BACKEND=8080
PORT_FRONTEND=5173
# Mailpit — local mail catcher (dev only, included in docker-compose)
# Web UI: http://localhost:8025
# SMTP: localhost:1025 (used automatically by the backend container)
PORT_MAILPIT_UI=8100
PORT_MAILPIT_SMTP=1025
# OCR Training — set a secret token to protect the /train and /segtrain endpoints on the
# Python OCR microservice. Leave empty to disable token authentication (development only).
# OCR_TRAINING_TOKEN=change-me-in-production
# Production SMTP — uncomment and fill in to send real emails instead of catching them
# APP_BASE_URL=https://your-domain.example.com
# MAIL_HOST=smtp.example.com
# MAIL_PORT=587
# MAIL_USERNAME=your-smtp-user
# MAIL_PASSWORD=your-smtp-password
# MAIL_SMTP_AUTH=true
# MAIL_STARTTLS_ENABLE=true
# APP_MAIL_FROM=noreply@your-domain.example.com

View File

@@ -152,6 +152,13 @@
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
<version>3.0.2</version>
</dependency>
<!-- PDF rendering for training data export -->
<dependency>
<groupId>org.apache.pdfbox</groupId>
<artifactId>pdfbox</artifactId>
<version>3.0.4</version>
</dependency>
</dependencies>

View File

@@ -12,6 +12,7 @@ import java.util.UUID;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import org.raddatz.familienarchiv.dto.DocumentSearchResult;
import org.raddatz.familienarchiv.dto.DocumentUpdateDTO;
import org.raddatz.familienarchiv.dto.DocumentVersionSummary;
@@ -21,6 +22,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.dto.DocumentSort;
import org.raddatz.familienarchiv.model.DocumentStatus;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.DocumentVersion;
import org.raddatz.familienarchiv.security.Permission;
import org.raddatz.familienarchiv.security.RequirePermission;
@@ -35,9 +37,11 @@ import org.springframework.core.io.InputStreamResource;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PatchMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
@@ -208,6 +212,30 @@ public class DocumentController {
return ResponseEntity.ok(DocumentSearchResult.of(results));
}
// --- TRAINING LABELS ---
/** Request body for toggling a document's enrollment in a training data set. */
public record TrainingLabelRequest(String label, boolean enrolled) {}

/**
 * Enrolls a document into — or withdraws it from — a training label set
 * (e.g. KURRENT_RECOGNITION). Idempotent: enrolling twice or removing a
 * label that is not present is a no-op in the service layer.
 *
 * @param id  document id
 * @param req label name (must match a TrainingLabel constant) and desired state
 * @return 204 No Content on success; 400 for a missing or unknown label
 */
@PatchMapping("/{id}/training-labels")
@RequirePermission(Permission.WRITE_ALL)
@ApiResponse(responseCode = "204")
public ResponseEntity<Void> patchTrainingLabel(
        @PathVariable UUID id,
        @RequestBody TrainingLabelRequest req) {
    // Enum.valueOf(null) throws NullPointerException, which the catch below
    // would NOT handle and would surface as a 500. Reject missing labels as 400.
    if (req.label() == null || req.label().isBlank()) {
        throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Missing training label");
    }
    TrainingLabel label;
    try {
        label = TrainingLabel.valueOf(req.label());
    } catch (IllegalArgumentException e) {
        throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown training label: " + req.label());
    }
    if (req.enrolled()) {
        documentService.addTrainingLabel(id, label);
    } else {
        documentService.removeTrainingLabel(id, label);
    }
    return ResponseEntity.noContent().build();
}
// --- VERSIONS ---
@GetMapping("/{id}/versions")

View File

@@ -7,16 +7,23 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.model.AppUser;
import org.raddatz.familienarchiv.model.OcrJob;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.security.Permission;
import org.raddatz.familienarchiv.security.RequirePermission;
import org.raddatz.familienarchiv.service.OcrBatchService;
import org.raddatz.familienarchiv.service.OcrProgressService;
import org.raddatz.familienarchiv.service.OcrService;
import org.raddatz.familienarchiv.service.OcrTrainingService;
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
import org.raddatz.familienarchiv.service.TrainingDataExportService;
import org.raddatz.familienarchiv.service.UserService;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import jakarta.validation.Valid;
@@ -32,6 +39,9 @@ public class OcrController {
private final OcrBatchService ocrBatchService;
private final OcrProgressService ocrProgressService;
private final UserService userService;
private final TrainingDataExportService trainingDataExportService;
private final SegmentationTrainingExportService segmentationTrainingExportService;
private final OcrTrainingService ocrTrainingService;
@PostMapping("/api/documents/{documentId}/ocr")
@ResponseStatus(HttpStatus.ACCEPTED)
@@ -41,7 +51,8 @@ public class OcrController {
@RequestBody TriggerOcrDTO dto,
Authentication authentication) {
UUID userId = resolveUserId(authentication);
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId,
Boolean.TRUE.equals(dto.getUseExistingAnnotations()));
return Map.of("jobId", jobId);
}
@@ -75,6 +86,54 @@ public class OcrController {
return ocrService.getDocumentOcrStatus(documentId);
}
/**
 * Streams the recognition training data as a ZIP download.
 * Returns 204 No Content when no eligible blocks exist yet.
 */
@GetMapping("/api/ocr/training-data/export")
@RequirePermission(Permission.ADMIN)
public ResponseEntity<StreamingResponseBody> exportTrainingData() {
    boolean hasEligibleBlocks = !trainingDataExportService.queryEligibleBlocks().isEmpty();
    if (hasEligibleBlocks) {
        return ResponseEntity.ok()
                .contentType(MediaType.parseMediaType("application/zip"))
                .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"training-data.zip\"")
                .body(trainingDataExportService.exportToZip());
    }
    return ResponseEntity.noContent().build();
}
/**
 * Streams the segmentation training data as a ZIP download.
 * Returns 204 No Content when no segmentation blocks exist yet.
 */
@GetMapping("/api/ocr/segmentation-training-data/export")
@RequirePermission(Permission.ADMIN)
public ResponseEntity<StreamingResponseBody> exportSegmentationTrainingData() {
    boolean hasSegBlocks = !segmentationTrainingExportService.querySegmentationBlocks().isEmpty();
    if (hasSegBlocks) {
        return ResponseEntity.ok()
                .contentType(MediaType.parseMediaType("application/zip"))
                .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"segmentation-data.zip\"")
                .body(segmentationTrainingExportService.exportToZip());
    }
    return ResponseEntity.noContent().build();
}
/**
 * Kicks off a recognition-model training run on the OCR service.
 * Responds 201 with the created run row; the run id of the triggering
 * user is recorded on the run (null for unauthenticated callers).
 */
@PostMapping("/api/ocr/train")
@ResponseStatus(HttpStatus.CREATED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerTraining(Authentication authentication) {
    return ocrTrainingService.triggerTraining(resolveUserId(authentication));
}
/**
 * Kicks off a segmentation-model training run on the OCR service.
 * Responds 201 with the created run row.
 */
@PostMapping("/api/ocr/segtrain")
@ResponseStatus(HttpStatus.CREATED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerSegTraining(Authentication authentication) {
    return ocrTrainingService.triggerSegTraining(resolveUserId(authentication));
}
/**
 * Aggregated dashboard data for the training UI: eligible block/document
 * counts, OCR-service health, and the most recent runs.
 */
@GetMapping("/api/ocr/training-info")
@RequirePermission(Permission.ADMIN)
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
    return ocrTrainingService.getTrainingInfo();
}
private UUID resolveUserId(Authentication authentication) {
if (authentication == null || !authentication.isAuthenticated()) return null;
try {

View File

@@ -10,4 +10,5 @@ import org.raddatz.familienarchiv.model.ScriptType;
@AllArgsConstructor
public class TriggerOcrDTO {
    // Optional script-type override; null leaves the document's stored script type untouched.
    private ScriptType scriptType;
    // Guided-mode flag: when TRUE, OCR recognizes only inside the document's
    // pre-drawn annotations instead of running full-page layout detection.
    // NOTE(review): the initializer only takes effect on the no-args construction
    // path; @AllArgsConstructor assigns whatever the caller passes. Callers guard
    // with Boolean.TRUE.equals(...) so a null value safely means "false".
    private Boolean useExistingAnnotations = false;
}

View File

@@ -49,8 +49,6 @@ public enum ErrorCode {
// --- Annotations ---
/** The annotation with the given ID does not exist. 404 */
ANNOTATION_NOT_FOUND,
/** The new annotation overlaps an existing one on the same page. 409 */
ANNOTATION_OVERLAP,
// --- Transcription Blocks ---
/** The transcription block with the given ID does not exist. 404 */
@@ -75,6 +73,8 @@ public enum ErrorCode {
OCR_DOCUMENT_NOT_UPLOADED,
/** OCR processing failed for the document. 500 */
OCR_PROCESSING_FAILED,
/** A training run is already in progress. 409 */
TRAINING_ALREADY_RUNNING,
// --- Generic ---
/** Request validation failed (missing or malformed fields). 400 */

View File

@@ -110,4 +110,11 @@ public class Document {
@JoinTable(name = "document_tags", joinColumns = @JoinColumn(name = "document_id"), inverseJoinColumns = @JoinColumn(name = "tag_id"))
@Builder.Default
private Set<Tag> tags = new HashSet<>();
@ElementCollection(fetch = FetchType.EAGER)
@CollectionTable(name = "document_training_labels", joinColumns = @JoinColumn(name = "document_id"))
@Column(name = "label")
@Enumerated(EnumType.STRING)
@Builder.Default
private Set<TrainingLabel> trainingLabels = new HashSet<>();
}

View File

@@ -0,0 +1,69 @@
package org.raddatz.familienarchiv.model;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;

import java.time.Instant;
import java.util.UUID;

/**
 * One fine-tuning run of an OCR model, persisted for auditing and for the
 * training dashboard. A row is created in status RUNNING when training is
 * triggered and later flipped to DONE (with metrics) or FAILED (with an
 * error message).
 */
@Entity
@Table(name = "ocr_training_runs")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrTrainingRun {

    // Database-generated UUID primary key.
    @Id
    @GeneratedValue(strategy = GenerationType.UUID)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID id;

    // Lifecycle state (RUNNING / DONE / FAILED), stored as its enum name.
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private TrainingStatus status;

    // Number of transcription blocks included in this run's training set.
    @Column(name = "block_count", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private int blockCount;

    // Number of distinct documents those blocks came from.
    @Column(name = "document_count", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private int documentCount;

    // Name of the model being fine-tuned (e.g. "german_kurrent" for recognition,
    // "blla" for segmentation — see OcrTrainingService).
    @Column(name = "model_name", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private String modelName;

    // Result metrics reported by the OCR service; null while running or on failure.
    @Column(name = "cer")
    private Double cer;

    @Column(name = "loss")
    private Double loss;

    @Column(name = "accuracy")
    private Double accuracy;

    @Column(name = "epochs")
    private Integer epochs;

    // Failure reason; only set when status is FAILED.
    @Column(name = "error_message")
    private String errorMessage;

    // Id of the user who started the run; may be null (e.g. unresolvable principal).
    @Column(name = "triggered_by")
    private UUID triggeredBy;

    // Set by Hibernate on insert; never updated afterwards.
    @CreationTimestamp
    @Column(name = "created_at", nullable = false, updatable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private Instant createdAt;

    // Null while the run is in progress; set when it finishes or fails.
    @Column(name = "completed_at")
    private Instant completedAt;
}

View File

@@ -0,0 +1,6 @@
package org.raddatz.familienarchiv.model;

/**
 * Training data sets a document can be enrolled in. Stored per document
 * as an element collection (see Document.trainingLabels).
 */
public enum TrainingLabel {
    // Document's reviewed blocks feed recognition fine-tuning of the Kurrent model.
    KURRENT_RECOGNITION,
    // Document's manual annotations feed segmentation (layout) fine-tuning.
    KURRENT_SEGMENTATION
}

View File

@@ -0,0 +1,7 @@
package org.raddatz.familienarchiv.model;

/** Lifecycle state of an {@code OcrTrainingRun}. */
public enum TrainingStatus {
    // Training in progress; used as a guard so only one run executes at a time.
    RUNNING,
    // Completed successfully; metrics are populated on the run row.
    DONE,
    // Aborted with an error; errorMessage holds the reason.
    FAILED
}

View File

@@ -30,8 +30,7 @@ public class TranscriptionBlock {
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID documentId;
@Column(nullable = false, columnDefinition = "TEXT")
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Column(columnDefinition = "TEXT")
private String text;
@Column(length = 200)

View File

@@ -0,0 +1,16 @@
package org.raddatz.familienarchiv.repository;

import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.springframework.data.jpa.repository.JpaRepository;

import java.util.List;
import java.util.Optional;
import java.util.UUID;

/** Spring Data repository for {@link OcrTrainingRun} rows. */
public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {

    // Used as an "is any run in this state?" guard (queried with RUNNING) and
    // for startup recovery of orphaned runs.
    Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);

    // The five most recent runs, newest first, for the training dashboard.
    List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
}

View File

@@ -2,6 +2,7 @@ package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import java.util.List;
import java.util.Optional;
@@ -13,5 +14,27 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
Optional<TranscriptionBlock> findByAnnotationId(UUID annotationId);
void deleteByAnnotationId(UUID annotationId);
// Number of transcription blocks on a document.
int countByDocumentId(UUID documentId);

/**
 * Reviewed blocks on documents enrolled in the KURRENT_RECOGNITION label.
 * The inner joins to DocumentAnnotation and Document are existence filters:
 * only blocks whose annotation and document rows still exist are returned.
 */
@Query("""
SELECT b FROM TranscriptionBlock b
JOIN DocumentAnnotation a ON a.id = b.annotationId
JOIN Document d ON d.id = b.documentId
WHERE b.reviewed = true
AND 'KURRENT_RECOGNITION' MEMBER OF d.trainingLabels
""")
List<TranscriptionBlock> findEligibleKurrentBlocks();

/**
 * Manually-created blocks on documents enrolled in the KURRENT_SEGMENTATION
 * label; same join-as-existence-filter pattern as above.
 */
@Query("""
SELECT b FROM TranscriptionBlock b
JOIN DocumentAnnotation a ON a.id = b.annotationId
JOIN Document d ON d.id = b.documentId
WHERE b.source = 'MANUAL'
AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels
""")
List<TranscriptionBlock> findSegmentationBlocks();
}

View File

@@ -6,6 +6,7 @@ import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -17,6 +18,7 @@ import java.util.UUID;
public class AnnotationService {
private final AnnotationRepository annotationRepository;
private final TranscriptionBlockRepository blockRepository;
public List<DocumentAnnotation> listAnnotations(UUID documentId) {
return annotationRepository.findByDocumentId(documentId);
@@ -24,15 +26,6 @@ public class AnnotationService {
@Transactional
public DocumentAnnotation createAnnotation(UUID documentId, CreateAnnotationDTO dto, UUID userId, String fileHash) {
List<DocumentAnnotation> existing =
annotationRepository.findByDocumentIdAndPageNumber(documentId, dto.getPageNumber());
boolean overlaps = existing.stream().anyMatch(a -> overlaps(a, dto));
if (overlaps) {
throw DomainException.conflict(
ErrorCode.ANNOTATION_OVERLAP, "Annotation overlaps an existing one on this page");
}
DocumentAnnotation annotation = DocumentAnnotation.builder()
.documentId(documentId)
.pageNumber(dto.getPageNumber())
@@ -79,6 +72,7 @@ public class AnnotationService {
throw DomainException.forbidden("Only the annotation author can delete it");
}
blockRepository.deleteByAnnotationId(annotationId);
annotationRepository.delete(annotation);
}
@@ -90,14 +84,4 @@ public class AnnotationService {
});
}
// ─── private helpers ──────────────────────────────────────────────────────
private boolean overlaps(DocumentAnnotation existing, CreateAnnotationDTO dto) {
double ex2 = existing.getX() + existing.getWidth();
double ey2 = existing.getY() + existing.getHeight();
double dx2 = dto.getX() + dto.getWidth();
double dy2 = dto.getY() + dto.getHeight();
return existing.getX() < dx2 && ex2 > dto.getX()
&& existing.getY() < dy2 && ey2 > dto.getY();
}
}

View File

@@ -9,6 +9,7 @@ import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.dto.DocumentSort;
import org.raddatz.familienarchiv.model.DocumentStatus;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.Tag;
import org.raddatz.familienarchiv.repository.DocumentRepository;
@@ -385,6 +386,20 @@ public class DocumentService {
documentRepository.save(doc);
}
/**
 * Enrolls a document in the given training label set.
 * Idempotent: adding an already-present label leaves the set unchanged.
 *
 * @throws org.raddatz.familienarchiv.exception.DomainException when the document does not exist
 */
@Transactional
public void addTrainingLabel(UUID documentId, TrainingLabel label) {
    Document document = getDocumentById(documentId);
    document.getTrainingLabels().add(label);
    documentRepository.save(document);
}
/**
 * Withdraws a document from the given training label set.
 * Idempotent: removing an absent label leaves the set unchanged.
 *
 * @throws org.raddatz.familienarchiv.exception.DomainException when the document does not exist
 */
@Transactional
public void removeTrainingLabel(UUID documentId, TrainingLabel label) {
    Document document = getDocumentById(documentId);
    document.getTrainingLabels().remove(label);
    documentRepository.save(document);
}
public Document getDocumentById(UUID id) {
return documentRepository.findById(id)
.orElseThrow(() -> DomainException.notFound(ErrorCode.DOCUMENT_NOT_FOUND, "Document not found: " + id));

View File

@@ -114,7 +114,8 @@ public class FileService {
/**
* Generates a presigned URL for downloading an object from S3/MinIO.
* Valid for 15 minutes — enough for OCR processing on CPU.
* Valid for 1 hour — covers multi-page documents on CPU-only OCR hardware
* (a 100-page document at 10 s/page takes ~17 min; 1 h gives ample headroom).
*/
public String generatePresignedUrl(String s3Key) {
GetObjectRequest getObjectRequest = GetObjectRequest.builder()
@@ -123,7 +124,7 @@ public class FileService {
.build();
GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
.signatureDuration(Duration.ofMinutes(15))
.signatureDuration(Duration.ofHours(1))
.getObjectRequest(getObjectRequest)
.build();

View File

@@ -32,6 +32,11 @@ public class OcrAsyncRunner {
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
runSingleDocument(jobId, documentId, userId, false);
}
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId, boolean useExistingAnnotations) {
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
if (job == null) return;
@@ -49,7 +54,18 @@ public class OcrAsyncRunner {
try {
updateProgress(job, "LOADING");
clearExistingBlocks(documentId);
List<OcrClient.OcrRegion> regions = null;
if (useExistingAnnotations) {
regions = annotationService.listAnnotations(documentId).stream()
.map(a -> new OcrClient.OcrRegion(
a.getId().toString(), a.getPageNumber(),
a.getX(), a.getY(), a.getWidth(), a.getHeight()))
.toList();
} else {
clearExistingBlocks(documentId);
}
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
AtomicInteger blockCounter = new AtomicInteger(0);
@@ -57,7 +73,7 @@ public class OcrAsyncRunner {
AtomicInteger skippedPages = new AtomicInteger(0);
AtomicInteger totalPages = new AtomicInteger(0);
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
switch (event) {
case OcrStreamEvent.Start start -> {
totalPages.set(start.totalPages());
@@ -204,14 +220,21 @@ public class OcrAsyncRunner {
void createSingleBlock(UUID documentId, OcrBlockResult block,
UUID userId, String fileHash, int sortOrder) {
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
block.pageNumber(), block.x(), block.y(),
block.width(), block.height(), OCR_ANNOTATION_COLOR);
if (block.annotationId() != null) {
// Guided mode — annotation already exists; upsert the text block only
transcriptionService.upsertGuidedBlock(
documentId, UUID.fromString(block.annotationId()), block.text(), userId);
} else {
// Normal mode — create a new annotation and a new OCR block
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
block.pageNumber(), block.x(), block.y(),
block.width(), block.height(), OCR_ANNOTATION_COLOR);
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
documentId, annotationDTO, userId, fileHash, block.polygon());
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
documentId, annotationDTO, userId, fileHash, block.polygon());
transcriptionService.createOcrBlock(documentId, annotation.getId(),
block.text(), sortOrder, userId);
transcriptionService.createOcrBlock(documentId, annotation.getId(),
block.text(), sortOrder, userId);
}
}
}

View File

@@ -12,5 +12,6 @@ public record OcrBlockResult(
double width,
double height,
List<List<Double>> polygon,
String text
String text,
String annotationId // null in normal mode; set in guided mode to link back to existing annotation
) {}

View File

@@ -10,12 +10,42 @@ import java.util.function.Consumer;
public interface OcrClient {
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
/**
* A pre-drawn annotation region to use as guidance for OCR.
* When regions are provided, the OCR engine crops to each region and
* runs recognition only within that area, skipping full-page layout detection.
*/
record OcrRegion(String annotationId, int pageNumber,
double x, double y, double width, double height) {}
/**
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
*
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService
* @return training result metrics (loss, accuracy, epochs)
*/
TrainingResult trainModel(byte[] trainingDataZip);
record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {}
/**
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.
*
* @param trainingDataZip raw ZIP bytes produced by SegmentationTrainingExportService
* @return training result metrics
*/
TrainingResult segtrainModel(byte[] trainingDataZip);
/**
* Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration
* — it calls extractBlocks() and synthesizes events from the collected result.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
*/
default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();

View File

@@ -52,6 +52,11 @@ public class OcrService {
}
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
return startOcr(documentId, scriptTypeOverride, userId, false);
}
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId,
boolean useExistingAnnotations) {
Document doc = documentService.getDocumentById(documentId);
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
@@ -66,6 +71,9 @@ public class OcrService {
if (scriptTypeOverride != null) {
documentService.updateScriptType(documentId, scriptTypeOverride);
if (scriptTypeOverride == ScriptType.HANDWRITING_KURRENT) {
documentService.addTrainingLabel(documentId, TrainingLabel.KURRENT_RECOGNITION);
}
}
OcrJob job = OcrJob.builder()
@@ -82,7 +90,7 @@ public class OcrService {
.build();
ocrJobDocumentRepository.save(jobDoc);
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId, useExistingAnnotations);
return job.getId();
}
}

View File

@@ -0,0 +1,236 @@
package org.raddatz.familienarchiv.service;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionTemplate;

import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;

/**
 * Orchestrates OCR model fine-tuning: creates a RUNNING run row, exports the
 * training data as a ZIP, sends it to the Python OCR service synchronously,
 * and records the outcome (metrics on success, error message on failure).
 *
 * Transactions are deliberately kept narrow (via {@link TransactionTemplate})
 * around the DB writes only, because the OCR HTTP call can take minutes.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class OcrTrainingService {

    private final OcrTrainingRunRepository trainingRunRepository;
    private final TrainingDataExportService trainingDataExportService;
    private final SegmentationTrainingExportService segmentationTrainingExportService;
    private final OcrClient ocrClient;
    private final OcrHealthClient ocrHealthClient;
    private final TranscriptionBlockRepository blockRepository;
    private final TransactionTemplate txTemplate;

    /** Aggregated dashboard payload returned by {@link #getTrainingInfo()}. */
    public record TrainingInfoResponse(
            int availableBlocks,
            int totalOcrBlocks,
            int availableDocuments,
            int availableSegBlocks,
            boolean ocrServiceAvailable,
            OcrTrainingRun lastRun,
            List<OcrTrainingRun> runs
    ) {}

    // Not safe for horizontal scaling: training reloads the Kraken model in-process on the
    // Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
    // index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
    // from diverging on model state. Deploy as a single instance only. See ADR-001.

    /**
     * Runs a recognition-model training cycle and returns the finished run row
     * (status DONE with metrics, or FAILED with the error message).
     *
     * @param triggeredBy id of the initiating user; may be null
     * @throws DomainException TRAINING_ALREADY_RUNNING when a RUNNING row exists,
     *                         VALIDATION_ERROR when fewer than 5 eligible blocks exist
     */
    public OcrTrainingRun triggerTraining(UUID triggeredBy) {
        // Short transaction: guard check + create RUNNING row, then commit immediately.
        // The DB connection is released before the OCR HTTP call, which can take several minutes.
        OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
            if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
                throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
                        "A training run is already in progress");
            }
            var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
            if (eligibleBlocks.size() < 5) {
                throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
                        "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
            }
            long documentCount = eligibleBlocks.stream()
                    .map(b -> b.getDocumentId())
                    .distinct()
                    .count();
            OcrTrainingRun newRun = OcrTrainingRun.builder()
                    .status(TrainingStatus.RUNNING)
                    .blockCount(eligibleBlocks.size())
                    .documentCount((int) documentCount)
                    .modelName("german_kurrent")
                    .triggeredBy(triggeredBy)
                    .build();
            return trainingRunRepository.save(newRun);
        }));

        // Tag all log output of this (long) call chain with the run id.
        String runId = run.getId().toString();
        MDC.put("trainingRunId", runId);
        log.info("Started training run {} with {} blocks from {} documents",
                runId, run.getBlockCount(), run.getDocumentCount());
        try {
            // NOTE(review): exportToZip() re-queries the eligible blocks, so the
            // ZIP contents may differ slightly from the counts recorded above if
            // data changed in between — presumably acceptable; confirm.
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            trainingDataExportService.exportToZip().writeTo(baos);
            byte[] zipBytes = baos.toByteArray();
            log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);

            // Blocking HTTP call; no transaction is open here.
            OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);

            // Second short transaction: persist the successful outcome.
            return Objects.requireNonNull(txTemplate.execute(status -> {
                run.setStatus(TrainingStatus.DONE);
                run.setCompletedAt(Instant.now());
                run.setCer(result.cer());
                run.setLoss(result.loss());
                run.setAccuracy(result.accuracy());
                run.setEpochs(result.epochs());
                OcrTrainingRun updated = trainingRunRepository.save(run);
                log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
                return updated;
            }));
        } catch (Exception e) {
            // Any failure (export, HTTP, persistence of DONE) marks the run FAILED
            // so the RUNNING guard is released for the next attempt.
            return Objects.requireNonNull(txTemplate.execute(status -> {
                run.setStatus(TrainingStatus.FAILED);
                run.setErrorMessage(e.getMessage());
                run.setCompletedAt(Instant.now());
                OcrTrainingRun failed = trainingRunRepository.save(run);
                log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
                return failed;
            }));
        } finally {
            MDC.remove("trainingRunId");
        }
    }

    /**
     * Runs a segmentation-model training cycle; mirrors {@link #triggerTraining(UUID)}
     * but exports segmentation data and calls the segtrain endpoint.
     * NOTE(review): this is a near-duplicate of triggerTraining — consider
     * extracting a shared helper parameterized by export service / model name.
     */
    public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
        // Same pattern as triggerTraining: narrow transactions around DB writes only.
        OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
            if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
                throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
                        "A training run is already in progress");
            }
            var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
            if (segBlocks.size() < 5) {
                throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
                        "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
            }
            long documentCount = segBlocks.stream()
                    .map(b -> b.getDocumentId())
                    .distinct()
                    .count();
            OcrTrainingRun newRun = OcrTrainingRun.builder()
                    .status(TrainingStatus.RUNNING)
                    .blockCount(segBlocks.size())
                    .documentCount((int) documentCount)
                    .modelName("blla")
                    .triggeredBy(triggeredBy)
                    .build();
            return trainingRunRepository.save(newRun);
        }));

        String runId = run.getId().toString();
        MDC.put("trainingRunId", runId);
        log.info("Started segmentation training run {} with {} segments from {} documents",
                runId, run.getBlockCount(), run.getDocumentCount());
        try {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            segmentationTrainingExportService.exportToZip().writeTo(baos);
            byte[] zipBytes = baos.toByteArray();
            log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);

            OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);

            return Objects.requireNonNull(txTemplate.execute(status -> {
                run.setStatus(TrainingStatus.DONE);
                run.setCompletedAt(Instant.now());
                // NOTE(review): unlike triggerTraining, CER is not recorded here —
                // presumably segtrain does not report it; confirm against the OCR service.
                run.setLoss(result.loss());
                run.setAccuracy(result.accuracy());
                run.setEpochs(result.epochs());
                OcrTrainingRun updated = trainingRunRepository.save(run);
                log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
                return updated;
            }));
        } catch (Exception e) {
            return Objects.requireNonNull(txTemplate.execute(status -> {
                run.setStatus(TrainingStatus.FAILED);
                run.setErrorMessage(e.getMessage());
                run.setCompletedAt(Instant.now());
                OcrTrainingRun failed = trainingRunRepository.save(run);
                log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
                return failed;
            }));
        } finally {
            MDC.remove("trainingRunId");
        }
    }

    /**
     * Builds the dashboard payload: eligible block/document counts for both
     * training modes, total OCR block count, OCR-service health, and the
     * five most recent runs (lastRun is the newest of them, or null).
     */
    public TrainingInfoResponse getTrainingInfo() {
        var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
        int availableDocuments = (int) eligibleBlocks.stream()
                .map(b -> b.getDocumentId())
                .distinct()
                .count();
        int totalOcrBlocks = (int) blockRepository.count();
        int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
        List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
        OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
        return new TrainingInfoResponse(
                eligibleBlocks.size(),
                totalOcrBlocks,
                availableDocuments,
                availableSegBlocks,
                ocrHealthClient.isHealthy(),
                lastRun,
                recentRuns
        );
    }

    /**
     * Startup recovery: a run left in RUNNING for over an hour (e.g. because
     * the service crashed mid-training) is marked FAILED so the RUNNING guard
     * does not block new runs forever. Younger RUNNING rows are left alone —
     * they may belong to a training call still in flight.
     */
    @EventListener(ApplicationReadyEvent.class)
    @Transactional
    public void recoverOrphanedRuns() {
        var cutoff = Instant.now().minusSeconds(3600);
        trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
            if (run.getCreatedAt().isBefore(cutoff)) {
                run.setStatus(TrainingStatus.FAILED);
                run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
                run.setCompletedAt(Instant.now());
                trainingRunRepository.save(run);
                log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
            }
        });
    }

    /**
     * Converts a {@link TrainingInfoResponse} into a plain map.
     * lastRun is substituted with an empty map when absent because
     * {@code Map.of} rejects null values.
     */
    public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
        return Map.of(
                "availableBlocks", info.availableBlocks(),
                "totalOcrBlocks", info.totalOcrBlocks(),
                "availableDocuments", info.availableDocuments(),
                "availableSegBlocks", info.availableSegBlocks(),
                "ocrServiceAvailable", info.ocrServiceAvailable(),
                "lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
                "runs", info.runs()
        );
    }
}

View File

@@ -9,9 +9,14 @@ import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.ScriptType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClient;
import java.io.BufferedReader;
@@ -36,11 +41,16 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
private final RestClient restClient;
private final RestClient trainingRestClient;
private final HttpClient streamingHttpClient;
private final String baseUrl;
private final String trainingToken;
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) {
public RestClientOcrClient(
@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl,
@Value("${app.ocr.training-token:}") String trainingToken) {
this.baseUrl = baseUrl;
this.trainingToken = trainingToken;
HttpClient httpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
@@ -54,6 +64,17 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.requestFactory(requestFactory)
.build();
HttpClient trainingHttpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10))
.build();
JdkClientHttpRequestFactory trainingRequestFactory = new JdkClientHttpRequestFactory(trainingHttpClient);
trainingRequestFactory.setReadTimeout(Duration.ofMinutes(10));
this.trainingRestClient = RestClient.builder()
.baseUrl(baseUrl)
.requestFactory(trainingRequestFactory)
.build();
this.streamingHttpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10))
@@ -81,6 +102,64 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.toList();
}
@Override
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
    return postTrainingZip("/train", trainingDataZip, "training-data.zip");
}

@Override
public OcrClient.TrainingResult segtrainModel(byte[] trainingDataZip) {
    return postTrainingZip("/segtrain", trainingDataZip, "segmentation-data.zip");
}

/**
 * Shared upload path for both training endpoints. Posts the ZIP bytes as a
 * multipart {@code file} part with an explicit {@code application/zip}
 * content type, attaches the optional {@code X-Training-Token} header when a
 * token is configured, and maps the service's JSON metrics onto the
 * client-facing {@link OcrClient.TrainingResult} record (all-null result
 * when the service returns an empty body).
 *
 * @param uri      training endpoint path ("/train" or "/segtrain")
 * @param zipBytes the training data archive
 * @param filename multipart filename reported to the service
 */
private OcrClient.TrainingResult postTrainingZip(String uri, byte[] zipBytes, String filename) {
    ByteArrayResource zipResource = new ByteArrayResource(zipBytes) {
        @Override
        public String getFilename() { return filename; }
    };
    MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
    HttpHeaders partHeaders = new HttpHeaders();
    partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
    body.add("file", new HttpEntity<>(zipResource, partHeaders));
    var spec = trainingRestClient.post()
        .uri(uri)
        .contentType(MediaType.MULTIPART_FORM_DATA);
    if (trainingToken != null && !trainingToken.isBlank()) {
        spec = spec.header("X-Training-Token", trainingToken);
    }
    TrainingResultJson result = spec
        .body(body)
        .retrieve()
        .body(TrainingResultJson.class);
    if (result == null) {
        return new OcrClient.TrainingResult(null, null, null, null);
    }
    return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
}
@Override
public boolean isHealthy() {
try {
@@ -96,13 +175,18 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
}
@Override
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
public void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
String body;
try {
body = NDJSON_MAPPER.writeValueAsString(Map.of(
"pdfUrl", pdfUrl,
"scriptType", scriptType.name(),
"language", "de"));
var requestMap = new java.util.LinkedHashMap<String, Object>();
requestMap.put("pdfUrl", pdfUrl);
requestMap.put("scriptType", scriptType.name());
requestMap.put("language", "de");
if (regions != null && !regions.isEmpty()) {
requestMap.put("regions", regions);
}
body = NDJSON_MAPPER.writeValueAsString(requestMap);
} catch (IOException e) {
throw new RuntimeException("Failed to serialize OCR request", e);
}
@@ -120,7 +204,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (response.statusCode() == 404) {
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
return;
}
@@ -171,6 +255,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
}
}
record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {}
record OcrBlockJson(
@JsonProperty("pageNumber") int pageNumber,
double x,
@@ -178,10 +264,11 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
double width,
double height,
List<List<Double>> polygon,
String text
String text,
String annotationId
) {
OcrBlockResult toResult() {
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text);
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text, annotationId);
}
}
}

View File

@@ -0,0 +1,174 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.pdfbox.Loader;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.rendering.PDFRenderer;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.DocumentRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
/**
 * Exports segmentation training data as a streamed ZIP: one full-page PNG
 * (rendered at 300 DPI) plus one PAGE-XML file listing the annotated text
 * regions for each page that has eligible blocks.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class SegmentationTrainingExportService {

    private final TranscriptionBlockRepository blockRepository;
    private final AnnotationRepository annotationRepository;
    private final DocumentRepository documentRepository;
    private final FileService fileService;

    /** Blocks eligible as segmentation training samples (repository-defined query). */
    public List<TranscriptionBlock> querySegmentationBlocks() {
        return blockRepository.findSegmentationBlocks();
    }

    /**
     * Builds the streaming ZIP body. All repository data is loaded eagerly
     * before the lambda runs, so no database transaction is held open while
     * the response streams. Documents whose PDF cannot be fetched or rendered
     * are skipped with a warning rather than failing the whole export.
     */
    public StreamingResponseBody exportToZip() {
        List<TranscriptionBlock> blocks = querySegmentationBlocks();
        if (blocks.isEmpty()) {
            return out -> {}; // caller checks isEmpty() for a 204 response
        }

        // Group by documentId so we download each PDF only once
        Map<UUID, List<TranscriptionBlock>> byDoc = new LinkedHashMap<>();
        for (TranscriptionBlock b : blocks) {
            byDoc.computeIfAbsent(b.getDocumentId(), k -> new ArrayList<>()).add(b);
        }

        // Batch-fetch annotations in a single query instead of one findById per block
        Map<UUID, DocumentAnnotation> annotations = new HashMap<>();
        List<UUID> annotationIds = blocks.stream()
            .map(TranscriptionBlock::getAnnotationId)
            .distinct()
            .toList();
        annotationRepository.findAllById(annotationIds).forEach(a -> annotations.put(a.getId(), a));

        // Batch-fetch documents in a single query
        Map<UUID, Document> documents = new HashMap<>();
        documentRepository.findAllById(byDoc.keySet()).forEach(d -> documents.put(d.getId(), d));

        return out -> {
            try (ZipOutputStream zip = new ZipOutputStream(out)) {
                for (Map.Entry<UUID, List<TranscriptionBlock>> entry : byDoc.entrySet()) {
                    UUID docId = entry.getKey();
                    Document doc = documents.get(docId);
                    if (doc == null || doc.getFilePath() == null) {
                        log.warn("Skipping document {} — no file path", docId);
                        continue;
                    }
                    byte[] pdfBytes;
                    try {
                        pdfBytes = fileService.downloadFileBytes(doc.getFilePath());
                    } catch (FileService.StorageFileNotFoundException | IOException e) {
                        log.warn("Skipping document {} — S3 download failed: {}", docId, e.getMessage());
                        continue;
                    }
                    // Group blocks by page number for this document
                    Map<Integer, List<TranscriptionBlock>> byPage = new LinkedHashMap<>();
                    for (TranscriptionBlock b : entry.getValue()) {
                        DocumentAnnotation ann = annotations.get(b.getAnnotationId());
                        if (ann != null) {
                            byPage.computeIfAbsent(ann.getPageNumber(), k -> new ArrayList<>()).add(b);
                        }
                    }
                    try (PDDocument pdf = Loader.loadPDF(pdfBytes)) {
                        PDFRenderer renderer = new PDFRenderer(pdf);
                        for (Map.Entry<Integer, List<TranscriptionBlock>> pageEntry : byPage.entrySet()) {
                            int pageNumber = pageEntry.getKey();
                            int pageIdx = pageNumber - 1; // pageNumber is 1-based
                            if (pageIdx < 0 || pageIdx >= pdf.getNumberOfPages()) continue;
                            BufferedImage pageImage = renderer.renderImageWithDPI(pageIdx, 300);
                            String basename = "page-" + docId + "-" + pageNumber;
                            // Collect annotations for this page
                            List<DocumentAnnotation> pageAnnotations = new ArrayList<>();
                            for (TranscriptionBlock b : pageEntry.getValue()) {
                                DocumentAnnotation ann = annotations.get(b.getAnnotationId());
                                if (ann != null) pageAnnotations.add(ann);
                            }
                            writePngEntry(zip, basename, pageImage);
                            writePageXmlEntry(zip, basename, pageImage, pageAnnotations);
                        }
                    } catch (Exception e) {
                        log.warn("Skipping document {} — rendering failed: {}", docId, e.getMessage());
                    }
                }
            }
        };
    }

    /** Writes one page image as a {@code <basename>.png} ZIP entry. */
    private void writePngEntry(ZipOutputStream zip, String basename, BufferedImage image) throws IOException {
        zip.putNextEntry(new ZipEntry(basename + ".png"));
        ImageIO.write(image, "PNG", zip);
        zip.closeEntry();
    }

    /**
     * Writes a minimal PAGE-XML file ({@code <basename>.xml}) containing one
     * {@code TextRegion} per annotation, with pixel coordinates matching the
     * rendered page image.
     */
    private void writePageXmlEntry(ZipOutputStream zip, String basename,
                                   BufferedImage pageImage,
                                   List<DocumentAnnotation> annotations) throws IOException {
        int imgW = pageImage.getWidth();
        int imgH = pageImage.getHeight();
        StringBuilder regions = new StringBuilder();
        for (DocumentAnnotation ann : annotations) {
            String coords = buildPolygonCoords(ann, imgW, imgH);
            String regionId = ann.getId().toString();
            regions.append("    <TextRegion id=\"").append(regionId).append("\">\n");
            regions.append("      <Coords points=\"").append(coords).append("\"/>\n");
            regions.append("    </TextRegion>\n");
        }
        String xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
            + "<PcGts xmlns=\"http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15\">\n"
            + "  <Page imageFilename=\"" + basename + ".png\""
            + " imageWidth=\"" + imgW + "\""
            + " imageHeight=\"" + imgH + "\">\n"
            + regions
            + "  </Page>\n"
            + "</PcGts>\n";
        zip.putNextEntry(new ZipEntry(basename + ".xml"));
        zip.write(xml.getBytes(StandardCharsets.UTF_8));
        zip.closeEntry();
    }

    /**
     * Converts an annotation's normalized (0..1) geometry into PAGE-XML
     * {@code Coords points} pixel coordinates. Prefers the explicit polygon
     * when present; otherwise derives a quadrilateral from the bounding box.
     */
    String buildPolygonCoords(DocumentAnnotation ann, int imgW, int imgH) {
        List<List<Double>> polygon = ann.getPolygon();
        if (polygon != null && !polygon.isEmpty()) {
            // Use explicit polygon — de-normalize to pixel coordinates
            StringBuilder sb = new StringBuilder();
            for (List<Double> pt : polygon) {
                if (sb.length() > 0) sb.append(' ');
                int px = (int) (pt.get(0) * imgW);
                int py = (int) (pt.get(1) * imgH);
                sb.append(px).append(',').append(py);
            }
            return sb.toString();
        }
        // Fall back to bounding box from x/y/width/height
        int x = (int) (ann.getX() * imgW);
        int y = (int) (ann.getY() * imgH);
        int w = (int) (ann.getWidth() * imgW);
        int h = (int) (ann.getHeight() * imgH);
        return x + "," + y + " " + (x + w) + "," + y + " " + (x + w) + "," + (y + h) + " " + x + "," + (y + h);
    }
}

View File

@@ -0,0 +1,173 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.pdfbox.Loader;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.rendering.PDFRenderer;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.DocumentRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
/**
 * Exports recognition training data as a streamed ZIP: for each eligible
 * transcription block, a cropped line image (PNG) plus a PAGE-XML file
 * carrying the ground-truth text.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class TrainingDataExportService {

    private final TranscriptionBlockRepository blockRepository;
    private final AnnotationRepository annotationRepository;
    private final DocumentRepository documentRepository;
    private final FileService fileService;

    /** Blocks eligible as Kurrent recognition training samples (repository-defined query). */
    public List<TranscriptionBlock> queryEligibleBlocks() {
        return blockRepository.findEligibleKurrentBlocks();
    }

    /**
     * Builds the streaming ZIP body. All repository data is collected before
     * entering the lambda — no open DB transaction during streaming. Documents
     * whose PDF cannot be fetched or rendered are skipped with a warning.
     */
    public StreamingResponseBody exportToZip() {
        List<TranscriptionBlock> blocks = queryEligibleBlocks();
        if (blocks.isEmpty()) {
            return out -> {}; // caller checks isEmpty() for 204 response
        }

        // Group blocks by documentId so we only download each PDF once
        Map<UUID, List<TranscriptionBlock>> byDoc = new LinkedHashMap<>();
        for (TranscriptionBlock b : blocks) {
            byDoc.computeIfAbsent(b.getDocumentId(), k -> new ArrayList<>()).add(b);
        }

        // Batch-fetch annotations in a single query instead of one findById per block
        Map<UUID, DocumentAnnotation> annotations = new HashMap<>();
        List<UUID> annotationIds = blocks.stream()
            .map(TranscriptionBlock::getAnnotationId)
            .distinct()
            .toList();
        annotationRepository.findAllById(annotationIds).forEach(a -> annotations.put(a.getId(), a));

        // Batch-fetch documents in a single query
        Map<UUID, Document> documents = new HashMap<>();
        documentRepository.findAllById(byDoc.keySet()).forEach(d -> documents.put(d.getId(), d));

        return out -> {
            try (ZipOutputStream zip = new ZipOutputStream(out)) {
                for (Map.Entry<UUID, List<TranscriptionBlock>> entry : byDoc.entrySet()) {
                    UUID docId = entry.getKey();
                    Document doc = documents.get(docId);
                    if (doc == null || doc.getFilePath() == null) {
                        log.warn("Skipping document {} — no file path", docId);
                        continue;
                    }
                    byte[] pdfBytes;
                    try {
                        pdfBytes = fileService.downloadFileBytes(doc.getFilePath());
                    } catch (FileService.StorageFileNotFoundException | IOException e) {
                        log.warn("Skipping document {} — S3 download failed: {}", docId, e.getMessage());
                        continue;
                    }
                    try (PDDocument pdf = Loader.loadPDF(pdfBytes)) {
                        PDFRenderer renderer = new PDFRenderer(pdf);
                        // Group this document's blocks by page index so each page is
                        // rendered ONCE at 300 DPI, not once per block — rendering is
                        // by far the dominant cost of this export.
                        Map<Integer, List<TranscriptionBlock>> byPage = new LinkedHashMap<>();
                        for (TranscriptionBlock block : entry.getValue()) {
                            DocumentAnnotation ann = annotations.get(block.getAnnotationId());
                            if (ann == null) continue;
                            int pageIdx = ann.getPageNumber() - 1; // pageNumber is 1-based
                            if (pageIdx < 0 || pageIdx >= pdf.getNumberOfPages()) continue;
                            byPage.computeIfAbsent(pageIdx, k -> new ArrayList<>()).add(block);
                        }
                        for (Map.Entry<Integer, List<TranscriptionBlock>> pageEntry : byPage.entrySet()) {
                            BufferedImage pageImage = renderPageImage(renderer, pageEntry.getKey());
                            for (TranscriptionBlock block : pageEntry.getValue()) {
                                DocumentAnnotation ann = annotations.get(block.getAnnotationId());
                                BufferedImage cropped = cropBlockImage(pageImage, ann);
                                writeTrainingPair(zip, block.getId(), cropped, block.getText());
                            }
                        }
                    } catch (Exception e) {
                        log.warn("Skipping document {} — rendering failed: {}", docId, e.getMessage());
                    }
                }
            }
        };
    }

    /** Renders one PDF page at 300 DPI. Package-private for testability. */
    BufferedImage renderPageImage(PDFRenderer renderer, int pageIdx) throws IOException {
        return renderer.renderImageWithDPI(pageIdx, 300);
    }

    /**
     * Crops the annotation's bounding box (stored as normalized 0..1 values)
     * out of the rendered page, clamping to image bounds so a slightly
     * out-of-range annotation cannot crash the export.
     */
    BufferedImage cropBlockImage(BufferedImage page, DocumentAnnotation ann) {
        int imgW = page.getWidth();
        int imgH = page.getHeight();
        int x = (int) (ann.getX() * imgW);
        int y = (int) (ann.getY() * imgH);
        int w = (int) (ann.getWidth() * imgW);
        int h = (int) (ann.getHeight() * imgH);
        // Clamp to image bounds
        x = Math.max(0, Math.min(x, imgW - 1));
        y = Math.max(0, Math.min(y, imgH - 1));
        w = Math.max(1, Math.min(w, imgW - x));
        h = Math.max(1, Math.min(h, imgH - y));
        return page.getSubimage(x, y, w, h);
    }

    /**
     * Writes one training pair: {@code <blockId>.png} (cropped line image)
     * and {@code <blockId>.xml} (PAGE-XML with the ground-truth text).
     */
    void writeTrainingPair(ZipOutputStream zip, UUID blockId, BufferedImage image, String text) throws IOException {
        String base = blockId.toString();
        int w = image.getWidth();
        int h = image.getHeight();
        // Baseline at 75 % height — typical text baseline position in a cropped line image
        int baselineY = (h * 3) / 4;
        // Write PNG
        zip.putNextEntry(new ZipEntry(base + ".png"));
        ImageIO.write(image, "PNG", zip);
        zip.closeEntry();
        // Write PAGE XML (Kraken 7+ dropped the legacy "path" format)
        String safeText = escapeXml(text != null ? text : "");
        String xml = String.format(
            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" +
            "<PcGts xmlns=\"http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15\">\n" +
            "  <Metadata><Creator>familienarchiv</Creator></Metadata>\n" +
            "  <Page imageFilename=\"%s.png\" imageWidth=\"%d\" imageHeight=\"%d\">\n" +
            "    <TextRegion id=\"r0\" type=\"paragraph\">\n" +
            "      <Coords points=\"0,0 %d,0 %d,%d 0,%d\"/>\n" +
            "      <TextLine id=\"l0\">\n" +
            "        <Coords points=\"0,0 %d,0 %d,%d 0,%d\"/>\n" +
            "        <Baseline points=\"0,%d %d,%d\"/>\n" +
            "        <TextEquiv><Unicode>%s</Unicode></TextEquiv>\n" +
            "      </TextLine>\n" +
            "    </TextRegion>\n" +
            "  </Page>\n" +
            "</PcGts>\n",
            base, w, h,
            w - 1, w - 1, h - 1, h - 1,
            w - 1, w - 1, h - 1, h - 1,
            baselineY, w - 1, baselineY,
            safeText);
        zip.putNextEntry(new ZipEntry(base + ".xml"));
        zip.write(xml.getBytes(StandardCharsets.UTF_8));
        zip.closeEntry();
    }

    // Escapes the three characters that matter inside XML element content;
    // quotes need no escaping here because the text never lands in an attribute.
    private static String escapeXml(String text) {
        return text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;");
    }
}

View File

@@ -94,6 +94,27 @@ public class TranscriptionService {
return saved;
}
/**
 * Upsert an OCR transcription block for a pre-existing annotation (guided OCR mode).
 * If the annotation already has a non-empty MANUAL block, it is left unchanged.
 * If it has an OCR block (or an empty manual one), the text is updated in-place.
 * If it has no block yet, a new OCR block is created.
 *
 * @param documentId   owning document
 * @param annotationId annotation the block is attached to
 * @param text         recognized text (sanitized before persisting)
 * @param userId       user recorded as the editor for versioning
 */
@Transactional
public TranscriptionBlock upsertGuidedBlock(UUID documentId, UUID annotationId,
                                            String text, UUID userId) {
    return blockRepository.findByAnnotationId(annotationId).map(existing -> {
        // text is nullable since migration V31 (segmentation-only blocks),
        // so guard against NPE before calling isBlank()
        String existingText = existing.getText();
        if (existing.getSource() == BlockSource.MANUAL
                && existingText != null && !existingText.isBlank()) {
            return existing; // never overwrite non-empty manual transcription
        }
        existing.setText(sanitizeText(text));
        existing.setUpdatedBy(userId);
        TranscriptionBlock saved = blockRepository.save(existing);
        saveVersion(saved, userId);
        return saved;
    }).orElseGet(() -> createOcrBlock(documentId, annotationId, text, 0, userId));
}
@Transactional
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
UpdateTranscriptionBlockDTO dto, UUID userId) {

View File

@@ -0,0 +1,5 @@
-- Join table marking which documents are enrolled for a training purpose
-- (label values come from the application enum, e.g. KURRENT_RECOGNITION,
-- KURRENT_SEGMENTATION — see the training-labels endpoint).
-- Rows are deleted together with their document via ON DELETE CASCADE.
CREATE TABLE document_training_labels (
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
label VARCHAR(50) NOT NULL,
-- Composite PK: a document can carry each label at most once.
PRIMARY KEY (document_id, label)
);

View File

@@ -0,0 +1,16 @@
-- One row per OCR model training run (recognition or segmentation).
CREATE TABLE ocr_training_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
-- Application-managed lifecycle state (e.g. RUNNING, FAILED — see TrainingStatus enum)
status VARCHAR(20) NOT NULL DEFAULT 'RUNNING',
-- Snapshot of how much training material the run was started with
block_count INT NOT NULL,
document_count INT NOT NULL,
model_name VARCHAR(100) NOT NULL,
-- Populated only for failed runs
error_message TEXT,
-- Kept (as NULL) even if the triggering user account is deleted
triggered_by UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
completed_at TIMESTAMPTZ
);
-- Enforce single active run at the DB layer (application check is the UX layer)
CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
ON ocr_training_runs (status)
WHERE status = 'RUNNING';

View File

@@ -0,0 +1,5 @@
-- Intentional: segmentation-only blocks have no text.
-- This migration is irreversible without a data cleanup step
-- (cannot re-add NOT NULL if null rows exist).
ALTER TABLE transcription_blocks ALTER COLUMN text DROP NOT NULL;
-- New rows that omit text still default to the empty string rather than NULL.
ALTER TABLE transcription_blocks ALTER COLUMN text SET DEFAULT '';

View File

@@ -0,0 +1,5 @@
-- Training metrics reported by the OCR service when a run completes.
-- All nullable: runs created before this migration and failed runs have none.
ALTER TABLE ocr_training_runs
ADD COLUMN cer DOUBLE PRECISION, -- character error rate (CER)
ADD COLUMN loss DOUBLE PRECISION,
ADD COLUMN accuracy DOUBLE PRECISION,
ADD COLUMN epochs INT;

View File

@@ -123,15 +123,19 @@ class AnnotationControllerTest {
@Test
@WithMockUser(authorities = "ANNOTATE_ALL")
void createAnnotation_returns409_whenOverlap() throws Exception {
void createAnnotation_returns201_whenAnnotationsOverlap() throws Exception {
// Overlapping annotations are allowed — historical letter lines often intersect
UUID docId = UUID.randomUUID();
DocumentAnnotation saved = DocumentAnnotation.builder()
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
.x(0.1).y(0.1).width(0.3).height(0.3).color("#ff0000").build();
when(documentService.getDocumentById(any())).thenReturn(Document.builder().build());
when(annotationService.createAnnotation(any(), any(), any(), any()))
.thenThrow(DomainException.conflict(ErrorCode.ANNOTATION_OVERLAP, "Overlap"));
when(annotationService.createAnnotation(any(), any(), any(), any())).thenReturn(saved);
mockMvc.perform(post("/api/documents/" + UUID.randomUUID() + "/annotations")
mockMvc.perform(post("/api/documents/" + docId + "/annotations")
.contentType(MediaType.APPLICATION_JSON)
.content(ANNOTATION_JSON))
.andExpect(status().isConflict());
.andExpect(status().isCreated());
}
// ─── DELETE /api/documents/{documentId}/annotations/{annotationId} ─────────

View File

@@ -14,6 +14,7 @@ import org.raddatz.familienarchiv.service.FileService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.webmvc.test.autoconfigure.WebMvcTest;
import org.raddatz.familienarchiv.config.SecurityConfig;
import org.springframework.http.MediaType;
import org.springframework.boot.autoconfigure.aop.AopAutoConfiguration;
import org.springframework.context.annotation.Import;
import org.springframework.security.test.context.support.WithMockUser;
@@ -33,6 +34,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.multipart;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@@ -485,6 +487,58 @@ class DocumentControllerTest {
.andExpect(jsonPath("$[0].editorName").value("Emma Müller"));
}
// ─── PATCH /api/documents/{id}/training-labels ───────────────────────────
// Anonymous requests are rejected before any permission check.
@Test
void patchTrainingLabels_returns401_whenUnauthenticated() throws Exception {
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
.contentType(MediaType.APPLICATION_JSON)
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
.andExpect(status().isUnauthorized());
}
// Authenticated but without WRITE_ALL → forbidden.
@Test
@WithMockUser
void patchTrainingLabels_returns403_whenMissingWritePermission() throws Exception {
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
.contentType(MediaType.APPLICATION_JSON)
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
.andExpect(status().isForbidden());
}
// enrolled=true routes to addTrainingLabel and answers 204 (no body).
@Test
@WithMockUser(authorities = "WRITE_ALL")
void patchTrainingLabels_returns204_whenAddingLabel() throws Exception {
UUID id = UUID.randomUUID();
mockMvc.perform(patch("/api/documents/" + id + "/training-labels")
.contentType(MediaType.APPLICATION_JSON)
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
.andExpect(status().isNoContent());
verify(documentService).addTrainingLabel(eq(id), any());
}
// enrolled=false routes to removeTrainingLabel and answers 204.
@Test
@WithMockUser(authorities = "WRITE_ALL")
void patchTrainingLabels_returns204_whenRemovingLabel() throws Exception {
UUID id = UUID.randomUUID();
mockMvc.perform(patch("/api/documents/" + id + "/training-labels")
.contentType(MediaType.APPLICATION_JSON)
.content("{\"label\":\"KURRENT_SEGMENTATION\",\"enrolled\":false}"))
.andExpect(status().isNoContent());
verify(documentService).removeTrainingLabel(eq(id), any());
}
// A label value outside the known enum is a client error (400).
@Test
@WithMockUser(authorities = "WRITE_ALL")
void patchTrainingLabels_returns400_whenUnknownLabel() throws Exception {
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
.contentType(MediaType.APPLICATION_JSON)
.content("{\"label\":\"UNKNOWN_GARBAGE\",\"enrolled\":true}"))
.andExpect(status().isBadRequest());
}
// ─── GET /api/documents/{id}/versions/{versionId} ────────────────────────
@Test

View File

@@ -24,6 +24,7 @@ import java.util.List;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -43,15 +44,18 @@ class OcrControllerTest {
@MockitoBean OcrProgressService ocrProgressService;
@MockitoBean UserService userService;
@MockitoBean CustomUserDetailsService customUserDetailsService;
@MockitoBean TrainingDataExportService trainingDataExportService;
@MockitoBean SegmentationTrainingExportService segmentationTrainingExportService;
@MockitoBean OcrTrainingService ocrTrainingService;
@Test
@WithMockUser(authorities = "WRITE_ALL")
void triggerOcr_returns202_withJobId() throws Exception {
UUID docId = UUID.randomUUID();
UUID jobId = UUID.randomUUID();
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER, false);
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any(), anyBoolean())).thenReturn(jobId);
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
.contentType(MediaType.APPLICATION_JSON)
@@ -64,7 +68,7 @@ class OcrControllerTest {
@WithMockUser(authorities = "WRITE_ALL")
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
UUID docId = UUID.randomUUID();
when(ocrService.startOcr(eq(docId), any(), any()))
when(ocrService.startOcr(eq(docId), any(), any(), anyBoolean()))
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
@@ -121,6 +125,109 @@ class OcrControllerTest {
.andExpect(jsonPath("$.jobId").value(jobId.toString()));
}
// ─── GET /api/ocr/training-data/export ───────────────────────────────────
@Test
void exportTrainingData_returns401_whenUnauthenticated() throws Exception {
mockMvc.perform(get("/api/ocr/training-data/export"))
.andExpect(status().isUnauthorized());
}
@Test
@WithMockUser(authorities = "READ_ALL")
void exportTrainingData_returns403_whenNotAdmin() throws Exception {
mockMvc.perform(get("/api/ocr/training-data/export"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(authorities = "ADMIN")
void exportTrainingData_returns204_whenNoEligibleBlocks() throws Exception {
when(trainingDataExportService.queryEligibleBlocks()).thenReturn(List.of());
mockMvc.perform(get("/api/ocr/training-data/export"))
.andExpect(status().isNoContent());
}
@Test
@WithMockUser(authorities = "ADMIN")
void exportTrainingData_returns200_withZipContentType_whenBlocksExist() throws Exception {
org.raddatz.familienarchiv.model.TranscriptionBlock block =
org.raddatz.familienarchiv.model.TranscriptionBlock.builder()
.id(UUID.randomUUID()).documentId(UUID.randomUUID())
.annotationId(UUID.randomUUID()).text("x").sortOrder(0).build();
when(trainingDataExportService.queryEligibleBlocks()).thenReturn(List.of(block));
when(trainingDataExportService.exportToZip()).thenReturn(out -> {});
mockMvc.perform(get("/api/ocr/training-data/export"))
.andExpect(status().isOk())
.andExpect(result ->
org.assertj.core.api.Assertions.assertThat(
result.getResponse().getContentType()).contains("application/zip"));
}
// ─── POST /api/ocr/train ───────────────────────────────────────────────────
@Test
void triggerTraining_returns401_whenUnauthenticated() throws Exception {
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isUnauthorized());
}
@Test
@WithMockUser(authorities = "READ_ALL")
void triggerTraining_returns403_whenNotAdmin() throws Exception {
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(authorities = "ADMIN")
void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
when(ocrTrainingService.triggerTraining(any()))
.thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isConflict());
}
@Test
@WithMockUser(authorities = "ADMIN")
void triggerTraining_returns201_withRunInfo() throws Exception {
UUID runId = UUID.randomUUID();
OcrTrainingRun run = OcrTrainingRun.builder()
.id(runId).status(TrainingStatus.DONE)
.blockCount(10).documentCount(3).modelName("german_kurrent").build();
when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
mockMvc.perform(post("/api/ocr/train"))
.andExpect(status().isCreated())
.andExpect(jsonPath("$.status").value("DONE"))
.andExpect(jsonPath("$.blockCount").value(10));
}
// ─── GET /api/ocr/training-info ───────────────────────────────────────────
@Test
@WithMockUser(authorities = "READ_ALL")
void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isForbidden());
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_withInfo() throws Exception {
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.availableBlocks").value(5))
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
}
@Test
@WithMockUser(authorities = "READ_ALL")
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {

View File

@@ -0,0 +1,136 @@
package org.raddatz.familienarchiv.repository;
import jakarta.persistence.EntityManager;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.PostgresContainerConfig;
import org.raddatz.familienarchiv.config.FlywayConfig;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentStatus;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.data.jpa.test.autoconfigure.DataJpaTest;
import org.springframework.boot.jdbc.test.autoconfigure.AutoConfigureTestDatabase;
import org.springframework.context.annotation.Import;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Integration tests that verify DB-level constraints introduced in the OCR pipeline migrations
* are actually enforced by PostgreSQL. These tests exercise constraints that cannot be verified
* by unit tests alone.
*/
@DataJpaTest
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
@Import({PostgresContainerConfig.class, FlywayConfig.class})
class MigrationIntegrationTest {
@Autowired JdbcTemplate jdbc;
@Autowired DocumentRepository documentRepository;
@Autowired EntityManager em;
// ─── V23: chk_annotation_polygon_quad CHECK constraint ───────────────────
// The CHECK constraint must reject polygons that are not quadrilaterals.
@Test
void v23_polygonCheckConstraint_rejectsNonQuadrilateral() {
UUID docId = createDocument();
// A 3-point polygon violates chk_annotation_polygon_quad (must be exactly 4 points or NULL)
assertThatThrownBy(() ->
jdbc.update(
"""
INSERT INTO document_annotations
(id, document_id, page_number, x, y, width, height, color, polygon)
VALUES (gen_random_uuid(), ?, 1, 0.1, 0.1, 0.5, 0.1, '#00C7B1',
'[[0.1,0.1],[0.9,0.1],[0.9,0.2]]'::jsonb)
""",
docId)
).isInstanceOf(DataIntegrityViolationException.class);
}
// NULL polygon (plain bounding-box annotation) must remain insertable.
@Test
void v23_polygonCheckConstraint_allowsNullPolygon() {
UUID docId = createDocument();
int rows = jdbc.update(
"""
INSERT INTO document_annotations
(id, document_id, page_number, x, y, width, height, color, polygon)
VALUES (gen_random_uuid(), ?, 1, 0.1, 0.1, 0.5, 0.1, '#00C7B1', NULL)
""",
docId);
assertThat(rows).isEqualTo(1);
}
// A proper 4-point polygon must be accepted by the constraint.
@Test
void v23_polygonCheckConstraint_allowsQuadrilateral() {
UUID docId = createDocument();
int rows = jdbc.update(
"""
INSERT INTO document_annotations
(id, document_id, page_number, x, y, width, height, color, polygon)
VALUES (gen_random_uuid(), ?, 1, 0.1, 0.1, 0.5, 0.1, '#00C7B1',
'[[0.1,0.1],[0.9,0.1],[0.9,0.2],[0.1,0.2]]'::jsonb)
""",
docId);
assertThat(rows).isEqualTo(1);
}
// ─── V30: idx_ocr_training_runs_one_running partial unique index ──────────
@Test
@Transactional(propagation = Propagation.NOT_SUPPORTED)
void v30_partialUniqueIndex_preventsTwoRunningTrainingRuns() {
jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
VALUES (gen_random_uuid(), 'RUNNING', 10, 2, 'kurrent_v1')
""");
// A second RUNNING row violates the partial unique index
assertThatThrownBy(() ->
jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
VALUES (gen_random_uuid(), 'RUNNING', 5, 1, 'kurrent_v1')
""")
).isInstanceOf(DataIntegrityViolationException.class);
// Clean up — runs outside the DataJpaTest transaction, so must be explicit
jdbc.update("DELETE FROM ocr_training_runs WHERE status = 'RUNNING'");
}
@Test
void v30_partialUniqueIndex_allowsMultipleDoneRuns() {
int rows1 = jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
VALUES (gen_random_uuid(), 'DONE', 10, 2, 'kurrent_v1')
""");
int rows2 = jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
VALUES (gen_random_uuid(), 'DONE', 15, 3, 'kurrent_v2')
""");
assertThat(rows1).isEqualTo(1);
assertThat(rows2).isEqualTo(1);
}
// ─── helpers ─────────────────────────────────────────────────────────────
private UUID createDocument() {
Document doc = documentRepository.save(Document.builder()
.title("Testdokument")
.originalFilename("test.pdf")
.status(DocumentStatus.UPLOADED)
.build());
// Flush so the row is visible to subsequent JdbcTemplate queries within the same transaction
em.flush();
return doc.getId();
}
}

View File

@@ -0,0 +1,135 @@
package org.raddatz.familienarchiv.repository;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.PostgresContainerConfig;
import org.raddatz.familienarchiv.config.FlywayConfig;
import org.raddatz.familienarchiv.model.*;
import java.util.Set;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.jdbc.test.autoconfigure.AutoConfigureTestDatabase;
import org.springframework.boot.data.jpa.test.autoconfigure.DataJpaTest;
import org.springframework.context.annotation.Import;
import java.util.List;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Verifies the eligibility rules behind {@code findEligibleKurrentBlocks}: only blocks
 * that are reviewed AND belong to a document enrolled with the KURRENT_RECOGNITION
 * training label may serve as training ground truth — regardless of whether the text
 * was typed by a human (MANUAL) or produced by OCR.
 */
@DataJpaTest
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
@Import({PostgresContainerConfig.class, FlywayConfig.class})
class TrainingBlockQueryTest {

    @Autowired TranscriptionBlockRepository blockRepository;
    @Autowired DocumentRepository documentRepository;
    @Autowired AnnotationRepository annotationRepository;

    private UUID kurrentDocId;
    private UUID typewriterDocId;
    private UUID kurrentAnnotationId;
    private UUID typewriterAnnotationId;

    @BeforeEach
    void setUp() {
        // One document enrolled for Kurrent training…
        kurrentDocId = documentRepository.save(Document.builder()
                .title("Kurrent Brief")
                .originalFilename("kurrent.pdf")
                .status(DocumentStatus.UPLOADED)
                .trainingLabels(new java.util.HashSet<>(Set.of(TrainingLabel.KURRENT_RECOGNITION)))
                .build()).getId();

        // …and one control document without any training label.
        typewriterDocId = documentRepository.save(Document.builder()
                .title("Getippter Brief")
                .originalFilename("typed.pdf")
                .status(DocumentStatus.UPLOADED)
                .build()).getId();

        kurrentAnnotationId = annotationRepository.save(newAnnotation(kurrentDocId)).getId();
        typewriterAnnotationId = annotationRepository.save(newAnnotation(typewriterDocId)).getId();
    }

    @Test
    void findEligibleKurrentBlocks_includesReviewedManualBlock() {
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, true);

        List<TranscriptionBlock> eligible = blockRepository.findEligibleKurrentBlocks();

        assertThat(eligible).hasSize(1);
        assertThat(eligible.get(0).getSource()).isEqualTo(BlockSource.MANUAL);
    }

    @Test
    void findEligibleKurrentBlocks_excludesUnreviewedManualBlock() {
        // MANUAL blocks filled by guided OCR but not yet reviewed should not count as checked text
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, false);

        assertThat(blockRepository.findEligibleKurrentBlocks()).isEmpty();
    }

    @Test
    void findEligibleKurrentBlocks_includesReviewedOcrBlock() {
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, true);

        List<TranscriptionBlock> eligible = blockRepository.findEligibleKurrentBlocks();

        assertThat(eligible).hasSize(1);
        assertThat(eligible.get(0).isReviewed()).isTrue();
    }

    @Test
    void findEligibleKurrentBlocks_excludesUnreviewedOcrBlock() {
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, false);

        assertThat(blockRepository.findEligibleKurrentBlocks()).isEmpty();
    }

    @Test
    void findEligibleKurrentBlocks_excludesNonEnrolledDocument() {
        // Block belongs to a document without the KURRENT_RECOGNITION label
        persistBlock(typewriterDocId, typewriterAnnotationId, BlockSource.MANUAL, false);

        assertThat(blockRepository.findEligibleKurrentBlocks()).isEmpty();
    }

    @Test
    void findEligibleKurrentBlocks_returnsAllReviewedBlocksAcrossBothSources() {
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, true);
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, true);
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, false); // excluded
        persistBlock(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, false);    // excluded

        assertThat(blockRepository.findEligibleKurrentBlocks()).hasSize(2);
    }

    // ─── helpers ─────────────────────────────────────────────────────────────

    /** Builds a minimal rectangle annotation on page 1 of the given document. */
    private DocumentAnnotation newAnnotation(UUID docId) {
        return DocumentAnnotation.builder()
                .documentId(docId)
                .pageNumber(1)
                .x(0.1).y(0.2).width(0.3).height(0.4)
                .color("#00C7B1")
                .build();
    }

    /** Builds and saves a transcription block with the given source/review state. */
    private void persistBlock(UUID docId, UUID annotId, BlockSource source, boolean reviewed) {
        blockRepository.save(TranscriptionBlock.builder()
                .annotationId(annotId)
                .documentId(docId)
                .text("Liebe Tante")
                .sortOrder(0)
                .source(source)
                .reviewed(reviewed)
                .build());
    }
}

View File

@@ -9,6 +9,7 @@ import org.raddatz.familienarchiv.dto.CreateAnnotationDTO;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import java.util.List;
import java.util.Optional;
@@ -20,7 +21,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.http.HttpStatus.CONFLICT;
import static org.springframework.http.HttpStatus.FORBIDDEN;
import static org.springframework.http.HttpStatus.NOT_FOUND;
@@ -28,39 +28,19 @@ import static org.springframework.http.HttpStatus.NOT_FOUND;
class AnnotationServiceTest {
@Mock AnnotationRepository annotationRepository;
@Mock TranscriptionBlockRepository blockRepository;
@InjectMocks AnnotationService annotationService;
// ─── createAnnotation ─────────────────────────────────────────────────────
@Test
void createAnnotation_throwsConflict_whenAnnotationOverlapsExisting() {
void createAnnotation_savesAnnotation() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#ff0000");
DocumentAnnotation existing = DocumentAnnotation.builder()
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
.x(0.2).y(0.2).width(0.3).height(0.3).color("#00ff00").build();
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1))
.thenReturn(List.of(existing));
assertThatThrownBy(() -> annotationService.createAnnotation(docId, dto, userId, null))
.isInstanceOf(DomainException.class)
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(CONFLICT));
verify(annotationRepository, never()).save(any());
}
@Test
void createAnnotation_savesAndReturns_whenNoOverlap() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of());
DocumentAnnotation saved = DocumentAnnotation.builder()
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
.x(0.0).y(0.0).width(0.05).height(0.05).color("#ff0000").createdBy(userId).build();
.x(0.1).y(0.1).width(0.3).height(0.3).color("#ff0000").createdBy(userId).build();
when(annotationRepository.save(any())).thenReturn(saved);
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, null);
@@ -69,6 +49,77 @@ class AnnotationServiceTest {
verify(annotationRepository).save(any());
}
@Test
void createAnnotation_allowsOverlappingAnnotations() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#ff0000");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
// Should not throw even when overlapping annotations exist on the same page
annotationService.createAnnotation(docId, dto, userId, null);
verify(annotationRepository).save(any());
}
@Test
void createAnnotation_setsFileHash_whenProvided() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, "abc123");
assertThat(result.getFileHash()).isEqualTo("abc123");
}
@Test
void createAnnotation_setsNullFileHash_whenNoneProvided() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, null);
assertThat(result.getFileHash()).isNull();
}
// ─── createOcrAnnotation ──────────────────────────────────────────────────
@Test
void createOcrAnnotation_savesWithPolygon() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
List<List<Double>> polygon = List.of(
List.of(0.1, 0.1), List.of(0.9, 0.11),
List.of(0.89, 0.14), List.of(0.11, 0.13));
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createOcrAnnotation(
docId, dto, userId, "filehash", polygon);
assertThat(result.getPolygon()).isEqualTo(polygon);
assertThat(result.getDocumentId()).isEqualTo(docId);
verify(annotationRepository).save(any());
}
@Test
void createOcrAnnotation_savesWithNullPolygon_whenPolygonNotProvided() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createOcrAnnotation(
docId, dto, userId, "filehash", null);
assertThat(result.getPolygon()).isNull();
verify(annotationRepository).save(any());
}
// ─── deleteAnnotation ─────────────────────────────────────────────────────
@Test
@@ -114,36 +165,42 @@ class AnnotationServiceTest {
annotationService.deleteAnnotation(docId, annotId, ownerId);
verify(blockRepository).deleteByAnnotationId(annotId);
verify(annotationRepository).delete(annotation);
}
@Test
void createAnnotation_setsFileHash_whenProvided() {
void deleteAnnotation_deletesLinkedTranscriptionBlock_beforeAnnotation() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
String fileHash = "abc123";
UUID annotId = UUID.randomUUID();
UUID ownerId = UUID.randomUUID();
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of());
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation annotation = DocumentAnnotation.builder()
.id(annotId).documentId(docId).createdBy(ownerId).build();
when(annotationRepository.findByIdAndDocumentId(annotId, docId))
.thenReturn(Optional.of(annotation));
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, fileHash);
annotationService.deleteAnnotation(docId, annotId, ownerId);
assertThat(result.getFileHash()).isEqualTo(fileHash);
var inOrder = org.mockito.Mockito.inOrder(blockRepository, annotationRepository);
inOrder.verify(blockRepository).deleteByAnnotationId(annotId);
inOrder.verify(annotationRepository).delete(annotation);
}
@Test
void createAnnotation_setsNullFileHash_whenNoneProvided() {
void deleteAnnotation_throwsForbidden_whenUserIdIsNull() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
UUID annotId = UUID.randomUUID();
UUID ownerId = UUID.randomUUID();
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of());
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation annotation = DocumentAnnotation.builder()
.id(annotId).documentId(docId).createdBy(ownerId).build();
when(annotationRepository.findByIdAndDocumentId(annotId, docId))
.thenReturn(Optional.of(annotation));
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, null);
assertThat(result.getFileHash()).isNull();
assertThatThrownBy(() -> annotationService.deleteAnnotation(docId, annotId, null))
.isInstanceOf(DomainException.class)
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(FORBIDDEN));
}
// ─── listAnnotations ──────────────────────────────────────────────────────
@@ -183,149 +240,4 @@ class AnnotationServiceTest {
verify(annotationRepository, never()).save(any());
}
// ─── deleteAnnotation — null userId ───────────────────────────────────────
@Test
void deleteAnnotation_throwsForbidden_whenUserIdIsNull() {
UUID docId = UUID.randomUUID();
UUID annotId = UUID.randomUUID();
UUID ownerId = UUID.randomUUID();
DocumentAnnotation annotation = DocumentAnnotation.builder()
.id(annotId).documentId(docId).createdBy(ownerId).build();
when(annotationRepository.findByIdAndDocumentId(annotId, docId))
.thenReturn(Optional.of(annotation));
assertThatThrownBy(() -> annotationService.deleteAnnotation(docId, annotId, null))
.isInstanceOf(DomainException.class)
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(FORBIDDEN));
}
// ─── overlaps — partial overlap cases ────────────────────────────────────
@Test
void createAnnotation_noConflict_whenAnnotationIsToTheLeft() {
// existing: x=0.5, w=0.3 (x2=0.8); dto: x=0.0, w=0.4 (dx2=0.4)
// existing.getX() < dx2 → 0.5 < 0.4 → FALSE → no overlap (first && fails)
UUID docId = UUID.randomUUID();
DocumentAnnotation existing = DocumentAnnotation.builder()
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
.x(0.5).y(0.0).width(0.3).height(0.5).color("#ff0000").build();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.4, 0.5, "#0000ff");
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
verify(annotationRepository).save(any());
}
@Test
void createAnnotation_noConflict_whenAnnotationIsToTheRight() {
// existing: x=0.0, w=0.1 (ex2=0.1); dto: x=0.2, w=0.3 (dx2=0.5)
// existing.getX() < dx2 → 0.0 < 0.5 → TRUE
// ex2 > dto.getX() → 0.1 > 0.2 → FALSE → no overlap (second && fails)
UUID docId = UUID.randomUUID();
DocumentAnnotation existing = DocumentAnnotation.builder()
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
.x(0.0).y(0.0).width(0.1).height(0.5).color("#ff0000").build();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.2, 0.0, 0.3, 0.5, "#0000ff");
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
verify(annotationRepository).save(any());
}
@Test
void createAnnotation_noConflict_whenAnnotationIsBelow() {
// x ranges overlap, but y ranges don't
// existing: x=0.0, w=0.5, y=0.5, h=0.2 (ey2=0.7)
// dto: x=0.1, w=0.3 (dx2=0.4), y=0.0, h=0.4 (dy2=0.4)
// existing.getX() < dx2 → 0.0 < 0.4 → TRUE
// ex2 > dto.getX() → 0.5 > 0.1 → TRUE
// existing.getY() < dy2 → 0.5 < 0.4 → FALSE → no overlap (third && fails)
UUID docId = UUID.randomUUID();
DocumentAnnotation existing = DocumentAnnotation.builder()
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
.x(0.0).y(0.5).width(0.5).height(0.2).color("#ff0000").build();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.0, 0.3, 0.4, "#0000ff");
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
verify(annotationRepository).save(any());
}
// ─── createOcrAnnotation ──────────────────────────────────────────────────
@Test
void createOcrAnnotation_skipsOverlapCheck_andSavesWithPolygon() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
List<List<Double>> polygon = List.of(
List.of(0.1, 0.1), List.of(0.9, 0.11),
List.of(0.89, 0.14), List.of(0.11, 0.13));
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createOcrAnnotation(
docId, dto, userId, "filehash", polygon);
assertThat(result.getPolygon()).isEqualTo(polygon);
assertThat(result.getDocumentId()).isEqualTo(docId);
verify(annotationRepository).save(any());
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
}
@Test
void createOcrAnnotation_savesWithNullPolygon_whenPolygonNotProvided() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createOcrAnnotation(
docId, dto, userId, "filehash", null);
assertThat(result.getPolygon()).isNull();
verify(annotationRepository).save(any());
}
@Test
void createOcrAnnotation_doesNotCheckOverlap_evenWhenOverlappingAnnotationExists() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#00C7B1");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
annotationService.createOcrAnnotation(docId, dto, userId, "hash", null);
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
}
// ─── overlaps — partial overlap cases ────────────────────────────────────
@Test
void createAnnotation_noConflict_whenAnnotationIsAbove() {
// x ranges overlap, y ranges don't — existing is ABOVE the new annotation
// existing: x=0.0, w=0.5, y=0.0, h=0.1 (ey2=0.1)
// dto: x=0.1, w=0.3 (dx2=0.4), y=0.2, h=0.3 (dy2=0.5)
// A: 0.0 < 0.4 → TRUE, B: 0.5 > 0.1 → TRUE, C: 0.0 < 0.5 → TRUE
// D: ey2 > dto.getY() → 0.1 > 0.2 → FALSE → no overlap (fourth && fails)
UUID docId = UUID.randomUUID();
DocumentAnnotation existing = DocumentAnnotation.builder()
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
.x(0.0).y(0.0).width(0.5).height(0.1).color("#ff0000").build();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.2, 0.3, 0.3, "#0000ff");
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
verify(annotationRepository).save(any());
}
}

View File

@@ -60,8 +60,8 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -84,7 +84,7 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -112,12 +112,12 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -142,7 +142,7 @@ class OcrAsyncRunnerTest {
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -174,18 +174,18 @@ class OcrAsyncRunnerTest {
List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2", null))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Page(1, List.of(
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3", null))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Done(3, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -215,14 +215,14 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Error(1, "failed"));
handler.accept(new OcrStreamEvent.Page(2, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -251,13 +251,13 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(2));
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

View File

@@ -12,13 +12,19 @@ class OcrClientDefaultStreamTest {
@Test
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
OcrClient client = (pdfUrl, scriptType) -> List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
OcrClient client = new OcrClient() {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
return List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3", null));
}
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
assertThat(events).hasSize(4);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
@@ -42,10 +48,14 @@ class OcrClientDefaultStreamTest {
@Test
void defaultStreamBlocksHandlesEmptyResults() {
OcrClient client = (pdfUrl, scriptType) -> List.of();
OcrClient client = new OcrClient() {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
assertThat(events).hasSize(2);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);

View File

@@ -142,7 +142,7 @@ class OcrServiceTest {
UUID result = ocrService.startOcr(docId, null, userId);
assertThat(result).isEqualTo(jobId);
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId, false);
}
@Test

View File

@@ -17,7 +17,7 @@ class OcrStreamEventTest {
@Test
void pageRecordHoldsBlocksAndPageNumber() {
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test", null);
var page = new OcrStreamEvent.Page(0, List.of(block));
assertThat(page.pageNumber()).isEqualTo(0);
assertThat(page.blocks()).hasSize(1);

View File

@@ -0,0 +1,181 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
class OcrTrainingServiceTest {
OcrTrainingRunRepository runRepository;
TrainingDataExportService exportService;
SegmentationTrainingExportService segExportService;
OcrClient ocrClient;
OcrHealthClient healthClient;
TranscriptionBlockRepository blockRepository;
TransactionTemplate txTemplate;
OcrTrainingService service;
@BeforeEach
void setUp() {
    // Pure Mockito unit test: every collaborator of OcrTrainingService is a mock,
    // so no Spring context, database, or OCR microservice is needed.
    runRepository = mock(OcrTrainingRunRepository.class);
    exportService = mock(TrainingDataExportService.class);
    segExportService = mock(SegmentationTrainingExportService.class);
    ocrClient = mock(OcrClient.class);
    healthClient = mock(OcrHealthClient.class);
    blockRepository = mock(TranscriptionBlockRepository.class);
    txTemplate = mock(TransactionTemplate.class);
    // Execute transaction callbacks inline so unit tests run without a real DataSource
    // (the callback receives a null TransactionStatus, which the service must tolerate).
    when(txTemplate.execute(any())).thenAnswer(inv -> {
        TransactionCallback<?> callback = inv.getArgument(0);
        return callback.doInTransaction(null);
    });
    service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
    // Benign defaults so each test only has to stub what it actually exercises.
    when(blockRepository.count()).thenReturn(0L);
    when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
    when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
}
// ─── Concurrent guard ─────────────────────────────────────────────────────
@Test
void triggerTraining_throws409_whenRunningRunExists() {
    // A training run is already RUNNING — a second trigger must be rejected with 409 CONFLICT.
    when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
            .thenReturn(Optional.of(OcrTrainingRun.builder()
                    .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
                    .blockCount(5).documentCount(2).modelName("german_kurrent").build()));

    // Assert the numeric status code directly instead of reflective field extraction
    // ("status") plus toString().contains("409"), which is brittle: it breaks silently
    // on a field rename and could match an unrelated substring of the message.
    assertThatThrownBy(() -> service.triggerTraining(null))
            .isInstanceOf(DomainException.class)
            .satisfies(e -> assertThat(((DomainException) e).getStatus().value()).isEqualTo(409));
}
// ─── Threshold guard ──────────────────────────────────────────────────────
@Test
void triggerTraining_throws422_whenFewerThan5Blocks() {
    when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
    // Only 2 eligible blocks — below the minimum of 5 required to start a training run.
    when(exportService.queryEligibleBlocks()).thenReturn(List.of(
            TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
            TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
    ));

    // The method name promises a 422 — pin the status code rather than only the
    // exception type, so a wrong status (e.g. 400 or 500) fails the test.
    assertThatThrownBy(() -> service.triggerTraining(null))
            .isInstanceOf(DomainException.class)
            .satisfies(e -> assertThat(((DomainException) e).getStatus().value()).isEqualTo(422));
}
// ─── Happy path ───────────────────────────────────────────────────────────
@Test
void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
UUID docA = UUID.randomUUID();
UUID docB = UUID.randomUUID();
List<TranscriptionBlock> blocks = List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
);
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(2).modelName("german_kurrent").build();
when(runRepository.save(any())).thenReturn(saved);
service.triggerTraining(null);
// Verify run created with correct counts and then updated to DONE
verify(runRepository, times(2)).save(argThat(run ->
run.getBlockCount() == 5 || run.getStatus() == TrainingStatus.DONE));
}
// ─── Failure path ─────────────────────────────────────────────────────────
@Test
void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
UUID docA = UUID.randomUUID();
List<TranscriptionBlock> blocks = List.of(
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
);
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build();
when(runRepository.save(any())).thenReturn(saved);
service.triggerTraining(null);
verify(runRepository, atLeastOnce()).save(argThat(run ->
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test
void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
OcrTrainingRun orphan = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent")
.createdAt(Instant.now().minusSeconds(7200))
.build();
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
when(runRepository.save(any())).thenReturn(orphan);
service.recoverOrphanedRuns();
verify(runRepository).save(argThat(run ->
run.getStatus() == TrainingStatus.FAILED
&& run.getErrorMessage().contains("Abgebrochen")));
}
@Test
void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
OcrTrainingRun recent = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent")
.createdAt(Instant.now().minusSeconds(60))
.build();
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
service.recoverOrphanedRuns();
verify(runRepository, never()).save(any());
}
}

View File

@@ -0,0 +1,290 @@
package org.raddatz.familienarchiv.service;
import org.apache.pdfbox.Loader;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.pdmodel.PDPage;
import org.apache.pdfbox.pdmodel.common.PDRectangle;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.PostgresContainerConfig;
import org.raddatz.familienarchiv.config.FlywayConfig;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.DocumentRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase;
import org.springframework.boot.test.autoconfigure.orm.jpa.DataJpaTest;
import org.springframework.context.annotation.Import;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
/**
 * Integration tests for {@code TrainingDataExportService} against a real
 * PostgreSQL instance (Testcontainers) with Flyway-managed schema.
 *
 * <p>Verifies which transcription blocks are eligible for export (document
 * enrollment, block source and review-state filters), the structure of the
 * produced ZIP (one PNG line image plus one PAGE-XML file per block), XML
 * escaping of special characters, and resilience when the storage download
 * fails for a single document.
 */
@DataJpaTest
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
@Import({PostgresContainerConfig.class, FlywayConfig.class})
class TrainingDataExportServiceTest {

    @Autowired TranscriptionBlockRepository blockRepository;
    @Autowired DocumentRepository documentRepository;
    @Autowired AnnotationRepository annotationRepository;

    // Single-page A4 PDF rendered once and shared by all tests as the fake storage payload.
    static byte[] minimalPdfBytes;

    @BeforeAll
    static void createMinimalPdf() throws Exception {
        try (PDDocument doc = new PDDocument()) {
            doc.addPage(new PDPage(PDRectangle.A4));
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            doc.save(out);
            minimalPdfBytes = out.toByteArray();
        }
    }

    // ─── Query: enrollment filter ─────────────────────────────────────────────

    @Test
    void export_includesManualBlockFromEnrolledDocument() throws Exception {
        UUID docId = enrolledDoc("enrolled.pdf");
        UUID annotId = annotation(docId);
        blockRepository.save(manualBlock(docId, annotId, "Liebe Mutter"));

        StreamingResponseBody body = newService(mockFileService()).exportToZip();

        byte[] zipBytes = stream(body);
        assertThat(zipEntryNames(zipBytes)).isNotEmpty();
    }

    @Test
    void export_excludesUnreviewedManualBlockFromEnrolledDocument() throws Exception {
        // MANUAL blocks whose text hasn't been verified yet should not count as training data.
        UUID docId = enrolledDoc("unreviewed-manual.pdf");
        UUID annotId = annotation(docId);
        TranscriptionBlock block = TranscriptionBlock.builder()
                .annotationId(annotId).documentId(docId)
                .text("Liebe Tante").sortOrder(0)
                .source(BlockSource.MANUAL).reviewed(false).build();
        blockRepository.save(block);

        StreamingResponseBody body = newService(mockFileService()).exportToZip();

        assertThat(zipEntryNames(stream(body))).isEmpty();
    }

    @Test
    void export_excludesManualBlockFromNonEnrolledDocument() throws Exception {
        // Documents without the KURRENT_RECOGNITION training label are never exported.
        UUID docId = nonEnrolledDoc("notenrolled.pdf");
        UUID annotId = annotation(docId);
        blockRepository.save(manualBlock(docId, annotId, "Liebe Tante"));

        StreamingResponseBody body = newService(mockFileService()).exportToZip();

        assertThat(zipEntryNames(stream(body))).isEmpty();
    }

    @Test
    void export_includesReviewedOcrBlockFromEnrolledDocument() throws Exception {
        UUID docId = enrolledDoc("ocr-reviewed.pdf");
        UUID annotId = annotation(docId);
        TranscriptionBlock block = TranscriptionBlock.builder()
                .annotationId(annotId).documentId(docId)
                .text("OCR text").sortOrder(0)
                .source(BlockSource.OCR).reviewed(true).build();
        blockRepository.save(block);

        StreamingResponseBody body = newService(mockFileService()).exportToZip();

        assertThat(zipEntryNames(stream(body))).isNotEmpty();
    }

    @Test
    void export_excludesUnreviewedOcrBlockFromEnrolledDocument() throws Exception {
        UUID docId = enrolledDoc("ocr-unreviewed.pdf");
        UUID annotId = annotation(docId);
        TranscriptionBlock block = TranscriptionBlock.builder()
                .annotationId(annotId).documentId(docId)
                .text("Raw OCR").sortOrder(0)
                .source(BlockSource.OCR).reviewed(false).build();
        blockRepository.save(block);

        StreamingResponseBody body = newService(mockFileService()).exportToZip();

        assertThat(zipEntryNames(stream(body))).isEmpty();
    }

    // ─── ZIP structure ────────────────────────────────────────────────────────

    @Test
    void export_producesExactly2EntriesPerBlock_pngAndXml() {
        // Renamed from "...pngAndTxt": the export emits PAGE-XML, not plain .txt files.
        UUID docId = enrolledDoc("zip-struct.pdf");
        UUID annotId = annotation(docId);
        blockRepository.save(manualBlock(docId, annotId, "Erste Zeile"));
        blockRepository.save(manualBlock(docId, annotId, "Zweite Zeile"));

        assertThatNoException().isThrownBy(() -> {
            byte[] zipBytes = stream(newService(mockFileService()).exportToZip());
            var names = zipEntryNames(zipBytes);
            assertThat(names).hasSize(4); // 2 blocks × 2 entries each (PNG + PAGE-XML)
            assertThat(names.stream().filter(n -> n.endsWith(".png")).count()).isEqualTo(2);
            assertThat(names.stream().filter(n -> n.endsWith(".xml")).count()).isEqualTo(2);
        });
    }

    @Test
    void export_pageXmlContainsBlockText() throws Exception {
        UUID docId = enrolledDoc("txt-content.pdf");
        UUID annotId = annotation(docId);
        String expectedText = "Sehr geehrte Frau";
        blockRepository.save(manualBlock(docId, annotId, expectedText));

        byte[] zipBytes = stream(newService(mockFileService()).exportToZip());

        String xmlContent = readZipEntry(zipBytes, ".xml");
        assertThat(xmlContent).contains("<Unicode>" + expectedText + "</Unicode>");
    }

    @Test
    void export_pageXmlEscapesSpecialCharacters() throws Exception {
        UUID docId = enrolledDoc("special-chars.pdf");
        UUID annotId = annotation(docId);
        blockRepository.save(manualBlock(docId, annotId, "A & B < C > D"));

        byte[] zipBytes = stream(newService(mockFileService()).exportToZip());

        String xmlContent = readZipEntry(zipBytes, ".xml");
        assertThat(xmlContent).contains("<Unicode>A &amp; B &lt; C &gt; D</Unicode>");
    }

    // ─── S3 failure resilience ────────────────────────────────────────────────

    @Test
    void export_skipsDocumentWhenS3DownloadFails_andStillIncludesOtherDocuments() throws Exception {
        UUID failDocId = enrolledDoc("fail.pdf");
        UUID okDocId = enrolledDoc("ok.pdf");
        UUID failAnnotId = annotation(failDocId);
        UUID okAnnotId = annotation(okDocId);
        blockRepository.save(manualBlock(failDocId, failAnnotId, "Will fail"));
        blockRepository.save(manualBlock(okDocId, okAnnotId, "Will succeed"));
        FileService fileService = mock(FileService.class);
        when(fileService.downloadFileBytes("fail.pdf")).thenThrow(new FileService.StorageFileNotFoundException("missing"));
        when(fileService.downloadFileBytes("ok.pdf")).thenReturn(minimalPdfBytes);

        byte[] zipBytes = stream(newService(fileService).exportToZip());

        // ok.pdf block produces 2 entries; fail.pdf block is skipped without aborting the export.
        assertThat(zipEntryNames(zipBytes)).hasSize(2);
    }

    // ─── Empty export ─────────────────────────────────────────────────────────

    @Test
    void queryEligibleBlocks_returnsEmpty_whenNoEnrolledDocuments() {
        assertThat(newService(mockFileService()).queryEligibleBlocks()).isEmpty();
    }

    // ─── helpers ─────────────────────────────────────────────────────────────

    /** Builds the service under test; extracted to avoid repeating the 4-arg constructor. */
    private TrainingDataExportService newService(FileService fileService) {
        return new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
    }

    /** Persists a document enrolled for Kurrent-recognition training and returns its id. */
    private UUID enrolledDoc(String filename) {
        Document doc = documentRepository.save(Document.builder()
                .title(filename).originalFilename(filename).filePath(filename)
                .status(DocumentStatus.UPLOADED)
                .trainingLabels(new HashSet<>(Set.of(TrainingLabel.KURRENT_RECOGNITION)))
                .build());
        return doc.getId();
    }

    /** Persists a document WITHOUT any training label and returns its id. */
    private UUID nonEnrolledDoc(String filename) {
        Document doc = documentRepository.save(Document.builder()
                .title(filename).originalFilename(filename).filePath(filename)
                .status(DocumentStatus.UPLOADED)
                .build());
        return doc.getId();
    }

    /** Persists a page-1 region annotation on the given document and returns its id. */
    private UUID annotation(UUID docId) {
        return annotationRepository.save(DocumentAnnotation.builder()
                .documentId(docId).pageNumber(1)
                .x(0.1).y(0.1).width(0.8).height(0.1).color("#00C7B1")
                .build()).getId();
    }

    /** Builds a reviewed MANUAL block — the baseline "eligible" training sample. */
    private TranscriptionBlock manualBlock(UUID docId, UUID annotId, String text) {
        return TranscriptionBlock.builder()
                .annotationId(annotId).documentId(docId)
                .text(text).sortOrder(0)
                .source(BlockSource.MANUAL).reviewed(true).build();
    }

    /** FileService stub that serves the shared minimal PDF for every path. */
    private FileService mockFileService() {
        FileService fs = mock(FileService.class);
        try {
            when(fs.downloadFileBytes(anyString())).thenReturn(minimalPdfBytes);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return fs;
    }

    /** Materializes a streaming body into a byte array. */
    private static byte[] stream(StreamingResponseBody body) throws Exception {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        body.writeTo(out);
        return out.toByteArray();
    }

    /** Lists all entry names contained in the given ZIP bytes. */
    private static java.util.List<String> zipEntryNames(byte[] zipBytes) throws Exception {
        var names = new java.util.ArrayList<String>();
        try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(zipBytes))) {
            ZipEntry entry;
            while ((entry = zis.getNextEntry()) != null) {
                names.add(entry.getName());
                zis.closeEntry();
            }
        }
        return names;
    }

    /** Returns the UTF-8 content of the first entry whose name ends with {@code suffix}, or null. */
    private static String readZipEntry(byte[] zipBytes, String suffix) throws Exception {
        try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(zipBytes))) {
            ZipEntry entry;
            while ((entry = zis.getNextEntry()) != null) {
                if (entry.getName().endsWith(suffix)) {
                    // readAllBytes on a ZipInputStream stops at the end of the current entry.
                    return new String(zis.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
                }
                zis.closeEntry();
            }
        }
        return null;
    }
}

View File

@@ -0,0 +1,111 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockVersionRepository;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
 * Unit tests for the guided-mode upsert of {@code TranscriptionService}.
 *
 * <p>Guided OCR writes results per annotation: it creates a block when none
 * exists, refreshes an existing OCR block, fills an empty MANUAL block, and
 * never overwrites MANUAL text a user has already typed.
 */
class TranscriptionServiceGuidedTest {

    TranscriptionBlockRepository blockRepository;
    TranscriptionBlockVersionRepository versionRepository;
    AnnotationRepository annotationRepository;
    AnnotationService annotationService;
    DocumentService documentService;
    TranscriptionService service;

    UUID docId = UUID.randomUUID();
    UUID annId = UUID.randomUUID();
    UUID userId = UUID.randomUUID();

    @BeforeEach
    void setUp() {
        blockRepository = mock(TranscriptionBlockRepository.class);
        versionRepository = mock(TranscriptionBlockVersionRepository.class);
        annotationRepository = mock(AnnotationRepository.class);
        annotationService = mock(AnnotationService.class);
        documentService = mock(DocumentService.class);
        service = new TranscriptionService(blockRepository, versionRepository,
                annotationRepository, annotationService, documentService);
        // Echo back whatever entity is persisted so assertions can inspect the result.
        when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    }

    /** Builds a persisted-looking block attached to the shared annotation/document ids. */
    private TranscriptionBlock existingBlock(String text, BlockSource source) {
        return TranscriptionBlock.builder()
                .id(UUID.randomUUID())
                .annotationId(annId)
                .documentId(docId)
                .text(text)
                .source(source)
                .sortOrder(0)
                .build();
    }

    @Test
    void upsertGuidedBlock_createsNewBlock_whenAnnotationHasNoBlock() {
        when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.empty());

        TranscriptionBlock created = service.upsertGuidedBlock(docId, annId, "Hallo", userId);

        assertThat(created.getText()).isEqualTo("Hallo");
        assertThat(created.getSource()).isEqualTo(BlockSource.OCR);
        assertThat(created.getAnnotationId()).isEqualTo(annId);
        verify(blockRepository).save(any());
    }

    @Test
    void upsertGuidedBlock_updatesExistingOcrBlock_whenAnnotationHasOcrBlock() {
        when(blockRepository.findByAnnotationId(annId))
                .thenReturn(Optional.of(existingBlock("old text", BlockSource.OCR)));

        TranscriptionBlock updated = service.upsertGuidedBlock(docId, annId, "new text", userId);

        assertThat(updated.getText()).isEqualTo("new text");
        verify(blockRepository).save(any());
    }

    @Test
    void upsertGuidedBlock_doesNotOverwriteNonEmptyManualBlock() {
        when(blockRepository.findByAnnotationId(annId))
                .thenReturn(Optional.of(existingBlock("manually written", BlockSource.MANUAL)));

        TranscriptionBlock untouched = service.upsertGuidedBlock(docId, annId, "ocr result", userId);

        // User-typed text wins: the OCR result is discarded and nothing is persisted.
        assertThat(untouched.getText()).isEqualTo("manually written");
        verify(blockRepository, never()).save(any());
    }

    @Test
    void upsertGuidedBlock_fillsEmptyManualBlock_withOcrText() {
        when(blockRepository.findByAnnotationId(annId))
                .thenReturn(Optional.of(existingBlock("", BlockSource.MANUAL)));

        TranscriptionBlock filled = service.upsertGuidedBlock(docId, annId, "ocr result", userId);

        assertThat(filled.getText()).isEqualTo("ocr result");
        verify(blockRepository).save(any());
    }
}

View File

@@ -72,6 +72,9 @@ services:
- archive-net
# --- OCR: Python microservice (Surya + Kraken) ---
# Single-node only: OCR training reloads the model in-process after each run.
# Running multiple replicas would cause training conflicts and model-state divergence.
# See ADR-001 for the architectural rationale.
ocr-service:
build:
context: ./ocr-service
@@ -87,6 +90,7 @@ services:
- ocr_cache:/root/.cache
environment:
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
OCR_CONFIDENCE_THRESHOLD: "0.3"
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
RECOGNITION_BATCH_SIZE: "16"

View File

@@ -1,7 +1,6 @@
{
"$schema": "https://inlang.com/schema/inlang-message-format",
"error_annotation_not_found": "Die Annotation wurde nicht gefunden.",
"error_annotation_overlap": "Die Annotation überschneidet sich mit einer vorhandenen.",
"annotation_outdated_notice": "Einige Annotationen beziehen sich auf eine frühere Dateiversion und werden nicht angezeigt.",
"error_document_not_found": "Das Dokument wurde nicht gefunden.",
"error_document_no_file": "Diesem Dokument ist noch keine Datei zugeordnet.",
@@ -480,6 +479,7 @@
"scan_collapse": "Scan verkleinern",
"transcription_empty_title": "Noch keine Transkription",
"transcription_empty_desc": "Zeichne Bereiche auf dem Scan und tippe den Text ab, um eine Transkription zu erstellen.",
"transcription_empty_draw_hint": "Zeichnen Sie Bereiche auf dem Dokument, um mit der Transkription zu beginnen.",
"transcription_panel_close": "Panel schließen",
"person_alias_heading": "Namensverlauf",
"person_alias_empty": "Noch keine Namensaenderungen erfasst.",
@@ -505,6 +505,7 @@
"error_ocr_job_not_found": "Der OCR-Auftrag wurde nicht gefunden.",
"error_ocr_document_not_uploaded": "Das Dokument hat keine Datei — OCR ist nicht möglich.",
"error_ocr_processing_failed": "Die OCR-Verarbeitung ist fehlgeschlagen.",
"error_training_already_running": "Es läuft bereits ein Trainings-Vorgang.",
"ocr_script_type_typewriter": "Schreibmaschine",
"ocr_script_type_handwriting_latin": "Handschrift (lateinisch)",
"ocr_script_type_handwriting_kurrent": "Handschrift (Kurrent/Sütterlin)",
@@ -530,7 +531,35 @@
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
"ocr_status_error": "OCR fehlgeschlagen",
"ocr_trigger_no_annotations": "Zeichnen Sie zuerst Bereiche auf dem Dokument ein.",
"ocr_section_heading": "OCR ausführen",
"transcription_block_review": "Als geprüft markieren",
"transcription_block_unreview": "Markierung aufheben",
"transcription_reviewed_count": "{reviewed} von {total} geprüft"
"transcription_reviewed_count": "{reviewed} von {total} geprüft",
"training_ocr_heading": "Kurrent-Erkennung trainieren",
"training_ocr_description": "Starte ein neues Training mit den bisher geprüften OCR-Blöcken, um die Erkennungsgenauigkeit für Kurrentschrift zu verbessern.",
"training_ocr_blocks_ready": "{blocks} geprüfte Blöcke bereit / {docs} Dokumente",
"training_ocr_blocks_total": "(von {total} OCR-Blöcken gesamt)",
"training_start_btn": "Training starten",
"training_in_progress": "…",
"training_success": "Training wurde gestartet und abgeschlossen.",
"training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).",
"training_service_down": "OCR-Dienst ist nicht erreichbar.",
"training_history_heading": "Verlauf",
"training_history_empty": "Noch keine Trainings-Läufe.",
"training_history_col_date": "Datum",
"training_history_col_status": "Status",
"training_history_col_blocks": "Blöcke",
"training_history_col_docs": "Dokumente",
"training_history_col_cer": "Fehlerrate",
"training_status_done": "Fertig",
"training_status_failed": "Fehler",
"training_status_running": "Läuft…",
"training_seg_heading": "Segmentierung trainieren",
"training_seg_description": "Starte ein neues Training mit annotierten Segmentierungsbereichen, um die Texterkennung zu verbessern.",
"training_seg_blocks_ready": "{blocks} Segmentierungsblöcke bereit",
"training_seg_too_few_blocks": "Mindestens 5 Segmentierungsblöcke erforderlich (aktuell: {available}).",
"transcription_block_segmentation_only": "Nur Segmentierung",
"training_chip_kurrent": "Kurrent-Erkennung",
"training_chip_segmentation": "Segmentierung"
}

View File

@@ -1,7 +1,6 @@
{
"$schema": "https://inlang.com/schema/inlang-message-format",
"error_annotation_not_found": "Annotation not found.",
"error_annotation_overlap": "The annotation overlaps an existing one.",
"annotation_outdated_notice": "Some annotations refer to an earlier file version and are not shown.",
"error_document_not_found": "Document not found.",
"error_document_no_file": "No file is associated with this document.",
@@ -480,6 +479,7 @@
"scan_collapse": "Collapse scan",
"transcription_empty_title": "No transcription yet",
"transcription_empty_desc": "Draw regions on the scan and type the text to create a transcription.",
"transcription_empty_draw_hint": "Draw regions on the document to start transcribing.",
"transcription_panel_close": "Close panel",
"person_alias_heading": "Name history",
"person_alias_empty": "No name changes recorded yet.",
@@ -505,6 +505,7 @@
"error_ocr_job_not_found": "The OCR job was not found.",
"error_ocr_document_not_uploaded": "The document has no file — OCR is not possible.",
"error_ocr_processing_failed": "OCR processing failed.",
"error_training_already_running": "A training run is already in progress.",
"ocr_script_type_typewriter": "Typewriter",
"ocr_script_type_handwriting_latin": "Handwriting (Latin)",
"ocr_script_type_handwriting_kurrent": "Handwriting (Kurrent/Sütterlin)",
@@ -530,7 +531,35 @@
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
"ocr_status_error": "OCR failed",
"ocr_trigger_no_annotations": "Draw regions on the document first.",
"ocr_section_heading": "Run OCR",
"transcription_block_review": "Mark as reviewed",
"transcription_block_unreview": "Unmark as reviewed",
"transcription_reviewed_count": "{reviewed} of {total} reviewed"
"transcription_reviewed_count": "{reviewed} of {total} reviewed",
"training_ocr_heading": "Train Kurrent recognition",
"training_ocr_description": "Start a new training run using the reviewed OCR blocks to improve recognition accuracy for Kurrent script.",
"training_ocr_blocks_ready": "{blocks} reviewed blocks ready / {docs} documents",
"training_ocr_blocks_total": "(of {total} OCR blocks total)",
"training_start_btn": "Start training",
"training_in_progress": "…",
"training_success": "Training started and completed.",
"training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).",
"training_service_down": "OCR service is unavailable.",
"training_history_heading": "History",
"training_history_empty": "No training runs yet.",
"training_history_col_date": "Date",
"training_history_col_status": "Status",
"training_history_col_blocks": "Blocks",
"training_history_col_docs": "Documents",
"training_history_col_cer": "Error Rate",
"training_status_done": "Done",
"training_status_failed": "Failed",
"training_status_running": "Running…",
"training_seg_heading": "Train segmentation",
"training_seg_description": "Start a new training run using annotated segmentation regions to improve text detection.",
"training_seg_blocks_ready": "{blocks} segmentation blocks ready",
"training_seg_too_few_blocks": "At least 5 segmentation blocks required (currently: {available}).",
"transcription_block_segmentation_only": "Segmentation only",
"training_chip_kurrent": "Kurrent recognition",
"training_chip_segmentation": "Segmentation"
}

View File

@@ -1,7 +1,6 @@
{
"$schema": "https://inlang.com/schema/inlang-message-format",
"error_annotation_not_found": "Anotación no encontrada.",
"error_annotation_overlap": "La anotación se superpone con una existente.",
"annotation_outdated_notice": "Algunas anotaciones hacen referencia a una versión anterior del archivo y no se muestran.",
"error_document_not_found": "Documento no encontrado.",
"error_document_no_file": "No hay ningún archivo asociado a este documento.",
@@ -480,6 +479,7 @@
"scan_collapse": "Reducir escaneo",
"transcription_empty_title": "Sin transcripcion",
"transcription_empty_desc": "Dibuja regiones en el escaneo y escribe el texto para crear una transcripcion.",
"transcription_empty_draw_hint": "Dibuje regiones en el documento para comenzar a transcribir.",
"transcription_panel_close": "Cerrar panel",
"person_alias_heading": "Historial de nombres",
"person_alias_empty": "Aun no se han registrado cambios de nombre.",
@@ -505,6 +505,7 @@
"error_ocr_job_not_found": "No se encontró el trabajo OCR.",
"error_ocr_document_not_uploaded": "El documento no tiene archivo — OCR no es posible.",
"error_ocr_processing_failed": "El procesamiento OCR ha fallado.",
"error_training_already_running": "Ya hay un proceso de entrenamiento en curso.",
"ocr_script_type_typewriter": "Máquina de escribir",
"ocr_script_type_handwriting_latin": "Escritura manuscrita (latina)",
"ocr_script_type_handwriting_kurrent": "Escritura manuscrita (Kurrent/Sütterlin)",
@@ -530,7 +531,35 @@
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
"ocr_status_error": "OCR fallido",
"ocr_trigger_no_annotations": "Dibuje regiones en el documento primero.",
"ocr_section_heading": "Ejecutar OCR",
"transcription_block_review": "Marcar como revisado",
"transcription_block_unreview": "Desmarcar como revisado",
"transcription_reviewed_count": "{reviewed} de {total} revisados"
"transcription_reviewed_count": "{reviewed} de {total} revisados",
"training_ocr_heading": "Entrenar reconocimiento Kurrent",
"training_ocr_description": "Inicia un nuevo entrenamiento con los bloques OCR revisados para mejorar la precisión de reconocimiento del script Kurrent.",
"training_ocr_blocks_ready": "{blocks} bloques revisados listos / {docs} documentos",
"training_ocr_blocks_total": "(de {total} bloques OCR en total)",
"training_start_btn": "Iniciar entrenamiento",
"training_in_progress": "…",
"training_success": "Entrenamiento iniciado y completado.",
"training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).",
"training_service_down": "El servicio OCR no está disponible.",
"training_history_heading": "Historial",
"training_history_empty": "Todavía no hay ejecuciones de entrenamiento.",
"training_history_col_date": "Fecha",
"training_history_col_status": "Estado",
"training_history_col_blocks": "Bloques",
"training_history_col_docs": "Documentos",
"training_history_col_cer": "Tasa de error",
"training_status_done": "Listo",
"training_status_failed": "Error",
"training_status_running": "Ejecutando…",
"training_seg_heading": "Entrenar segmentación",
"training_seg_description": "Inicia un nuevo entrenamiento con regiones de segmentación anotadas para mejorar la detección de texto.",
"training_seg_blocks_ready": "{blocks} bloques de segmentación listos",
"training_seg_too_few_blocks": "Se requieren al menos 5 bloques de segmentación (actualmente: {available}).",
"transcription_block_segmentation_only": "Solo segmentación",
"training_chip_kurrent": "Reconocimiento Kurrent",
"training_chip_segmentation": "Segmentación"
}

View File

@@ -72,7 +72,9 @@ $effect(() => {
></div>
</div>
<p class="mt-2 text-right text-sm text-gray-500">
{m.ocr_progress_page({ current: String(currentPage), total: String(totalPages) })}
<span aria-live="polite" aria-atomic="true">
{m.ocr_progress_page({ current: String(currentPage), total: String(totalPages) })}
</span>
</p>
</div>
{:else if status === 'error'}

View File

@@ -0,0 +1,92 @@
<script lang="ts">
	import TrainingHistory from './TrainingHistory.svelte';
	import { m } from '$lib/paraglide/messages.js';

	/** A single training run as returned by the backend training-info endpoint. */
	interface Run {
		id: string;
		status: 'RUNNING' | 'DONE' | 'FAILED';
		blockCount: number;
		documentCount: number;
		modelName: string;
		errorMessage?: string;
		triggeredBy?: string;
		createdAt: string;
		completedAt?: string;
	}

	/** Aggregate training state: eligible-block counts, service health and run history. */
	interface TrainingInfo {
		availableBlocks?: number;
		totalOcrBlocks?: number;
		availableDocuments?: number;
		ocrServiceAvailable?: boolean;
		lastRun?: Run | null;
		runs?: Run[];
	}

	interface Props {
		trainingInfo: TrainingInfo | null;
	}

	let { trainingInfo }: Props = $props();
	let training = $state(false);
	let successMessage = $state<string | null>(null);

	const available = $derived(trainingInfo?.availableBlocks ?? 0);
	// The backend requires at least 5 reviewed blocks before training may start.
	const tooFewBlocks = $derived(available < 5);
	const serviceDown = $derived(trainingInfo?.ocrServiceAvailable === false);
	const disabled = $derived(training || tooFewBlocks || serviceDown);

	/** Triggers a training run; the POST blocks until the run has started/completed. */
	async function startTraining() {
		training = true;
		successMessage = null;
		try {
			const res = await fetch('/api/ocr/train', { method: 'POST' });
			if (res.ok) {
				successMessage = m.training_success();
				// Auto-dismiss the success note after 5 s.
				setTimeout(() => {
					successMessage = null;
				}, 5000);
			}
		} catch (err) {
			// Network-level failure: without this catch the rejection would surface
			// as an unhandled promise rejection. The button re-enables via `finally`.
			console.error('OCR training request failed', err);
		} finally {
			training = false;
		}
	}
</script>

<div class="rounded-sm border border-line bg-surface p-6 shadow-sm">
	<h2 class="mb-1 font-sans text-sm font-bold text-ink">{m.training_ocr_heading()}</h2>
	<p class="mb-4 text-sm text-ink-2">{m.training_ocr_description()}</p>
	<p class="mb-3 text-sm text-ink">
		{m.training_ocr_blocks_ready({ blocks: available, docs: trainingInfo?.availableDocuments ?? 0 })}
		<span class="text-ink-2"
			>{m.training_ocr_blocks_total({ total: trainingInfo?.totalOcrBlocks ?? 0 })}</span
		>
	</p>
	<button
		onclick={startTraining}
		disabled={disabled}
		class="rounded-sm bg-primary px-5 py-2 font-sans text-xs font-bold tracking-widest text-primary-fg uppercase transition-opacity hover:opacity-80 focus-visible:ring-2 focus-visible:ring-brand-navy disabled:cursor-not-allowed disabled:opacity-50"
	>
		{training ? '…' : m.training_start_btn()}
	</button>
	{#if tooFewBlocks}
		<p class="mt-2 text-xs text-ink-3">
			{m.training_too_few_blocks({ available })}
		</p>
	{:else if serviceDown}
		<p class="mt-2 text-xs text-orange-600">{m.training_service_down()}</p>
	{/if}
	{#if successMessage}
		<p class="mt-2 text-xs text-green-700">{successMessage}</p>
	{/if}
	<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
		{m.training_history_heading()}
	</h3>
	<!-- "blla" is the segmentation model; this card shows recognition runs only. -->
	<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
</div>

View File

@@ -0,0 +1,96 @@
import { afterEach, describe, expect, it, vi } from 'vitest';
import { cleanup, render } from 'vitest-browser-svelte';
import { page } from 'vitest/browser';
import OcrTrainingCard from './OcrTrainingCard.svelte';

// Unmount components and undo vi.stubGlobal('fetch', …) between tests so a
// stubbed fetch from one test cannot leak into the next.
afterEach(cleanup);
afterEach(() => vi.restoreAllMocks());

// Baseline props representing a healthy state: enough blocks (>= 5) and a
// reachable OCR service, so the start button is enabled by default.
const baseInfo = {
	availableBlocks: 10,
	totalOcrBlocks: 20,
	availableDocuments: 3,
	ocrServiceAvailable: true,
	lastRun: null,
	runs: []
};

// The card disables its button when blocks are insufficient or the service
// is down; German UI strings are matched because tests run with the de locale.
describe('OcrTrainingCard — disabled states', () => {
	it('disables button and shows hint when availableBlocks is 0', async () => {
		render(OcrTrainingCard, { trainingInfo: { ...baseInfo, availableBlocks: 0 } });
		const btn = page.getByRole('button', { name: /Training starten/i });
		await expect.element(btn).toBeDisabled();
		await expect
			.element(page.getByText(/Mindestens 5 geprüfte Blöcke erforderlich/i))
			.toBeInTheDocument();
	});
	it('disables button and shows hint when availableBlocks is less than 5', async () => {
		// 3 < 5 — still below the backend's minimum block threshold.
		render(OcrTrainingCard, { trainingInfo: { ...baseInfo, availableBlocks: 3 } });
		const btn = page.getByRole('button', { name: /Training starten/i });
		await expect.element(btn).toBeDisabled();
		await expect.element(page.getByText(/Mindestens 5/i)).toBeInTheDocument();
	});
	it('disables button and shows service-down warning when ocrServiceAvailable is false', async () => {
		render(OcrTrainingCard, { trainingInfo: { ...baseInfo, ocrServiceAvailable: false } });
		const btn = page.getByRole('button', { name: /Training starten/i });
		await expect.element(btn).toBeDisabled();
		await expect.element(page.getByText(/OCR-Dienst ist nicht erreichbar/i)).toBeInTheDocument();
	});
	it('does not show service-down warning when blocks are insufficient', async () => {
		// tooFewBlocks hint takes priority over serviceDown hint
		render(OcrTrainingCard, {
			trainingInfo: { ...baseInfo, availableBlocks: 2, ocrServiceAvailable: false }
		});
		await expect.element(page.getByText(/Mindestens 5/i)).toBeInTheDocument();
		// serviceDown text should NOT appear because tooFewBlocks branch hides it
		const serviceMsg = document.querySelector('.text-orange-600');
		expect(serviceMsg).toBeNull();
	});
});

describe('OcrTrainingCard — enabled state', () => {
	it('enables button when availableBlocks >= 5 and service is up', async () => {
		render(OcrTrainingCard, { trainingInfo: baseInfo });
		const btn = page.getByRole('button', { name: /Training starten/i });
		await expect.element(btn).not.toBeDisabled();
	});
	it('shows block count info text', async () => {
		render(OcrTrainingCard, {
			trainingInfo: { ...baseInfo, availableBlocks: 7, totalOcrBlocks: 15 }
		});
		await expect.element(page.getByText(/7/)).toBeInTheDocument();
		await expect.element(page.getByText(/von 15 OCR-Blöcken/i)).toBeInTheDocument();
	});
});

describe('OcrTrainingCard — in-flight state', () => {
	it('shows "…" while POST is in-flight', async () => {
		// Manually-resolvable promise keeps the POST pending for the duration
		// of the assertion; resolveFetch is assigned synchronously by the
		// Promise executor (hence the definite-assignment "!").
		let resolveFetch!: (v: unknown) => void;
		const pendingFetch = new Promise((resolve) => {
			resolveFetch = resolve;
		});
		vi.stubGlobal('fetch', vi.fn().mockReturnValue(pendingFetch));
		render(OcrTrainingCard, { trainingInfo: baseInfo });
		const btn = page.getByRole('button', { name: /Training starten/i });
		await btn.click();
		// While fetch is still pending the button label becomes "…"
		await expect.element(page.getByRole('button', { name: '…' })).toBeInTheDocument();
		// Cleanup: resolve the pending promise
		resolveFetch({ ok: false });
	});
});

View File

@@ -1,37 +1,23 @@
<script lang="ts">
import { untrack } from 'svelte';
import { m } from '$lib/paraglide/messages.js';
import { getConfirmService } from '$lib/services/confirm.svelte';
import ScriptTypeSelect from './ScriptTypeSelect.svelte';
interface Props {
existingBlockCount: number;
blockCount: number;
storedScriptType: string;
onTrigger: (scriptType: string) => void;
onTrigger: (scriptType: string, useExistingAnnotations: boolean) => void;
}
let { existingBlockCount, storedScriptType, onTrigger }: Props = $props();
const { confirm } = getConfirmService();
let { blockCount, storedScriptType, onTrigger }: Props = $props();
import { untrack } from 'svelte';
let selectedScriptType: string = $state(
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
);
async function handleClick() {
function handleClick() {
if (!selectedScriptType) return;
if (existingBlockCount > 0) {
const confirmed = await confirm({
title: m.ocr_confirm_title(),
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
confirmLabel: m.ocr_confirm_btn(),
destructive: true
});
if (!confirmed) return;
}
onTrigger(selectedScriptType);
onTrigger(selectedScriptType, true);
}
</script>
@@ -39,11 +25,14 @@ async function handleClick() {
<ScriptTypeSelect bind:value={selectedScriptType} />
<button
type="button"
disabled={!selectedScriptType}
disabled={!selectedScriptType || blockCount === 0}
title={!selectedScriptType ? m.ocr_trigger_btn_disabled() : undefined}
onclick={handleClick}
class="min-h-[44px] w-full rounded-sm bg-brand-navy font-sans text-sm font-medium text-white transition-colors hover:bg-brand-navy/90 disabled:cursor-not-allowed disabled:opacity-50"
>
{m.ocr_trigger_btn()}
</button>
{#if blockCount === 0}
<p class="text-xs text-ink-3">{m.ocr_trigger_no_annotations()}</p>
{/if}
</div>

View File

@@ -0,0 +1,86 @@
<script lang="ts">
import TrainingHistory from './TrainingHistory.svelte';
import { m } from '$lib/paraglide/messages.js';

// One training run as returned by the backend (see /api/ocr/training-info).
interface Run {
	id: string;
	status: 'RUNNING' | 'DONE' | 'FAILED';
	blockCount: number;
	documentCount: number;
	modelName: string;
	errorMessage?: string;
	triggeredBy?: string;
	createdAt: string;
	completedAt?: string;
}

// Aggregate training state for the segmentation model.
interface TrainingInfo {
	availableSegBlocks?: number;
	ocrServiceAvailable?: boolean;
	runs?: Run[];
}

interface Props {
	trainingInfo: TrainingInfo | null;
}

let { trainingInfo }: Props = $props();

let training = $state(false);
let successMessage = $state<string | null>(null);
// Keep a handle on the auto-hide timer so a re-trigger cancels the stale one
// (otherwise an old timeout could wipe a freshly shown success message).
let hideTimer: ReturnType<typeof setTimeout> | null = null;

const available = $derived(trainingInfo?.availableSegBlocks ?? 0);
// Backend requires at least 5 annotated segmentation blocks per run.
const tooFewBlocks = $derived(available < 5);
const serviceDown = $derived(trainingInfo?.ocrServiceAvailable === false);
const disabled = $derived(training || tooFewBlocks || serviceDown);

/**
 * Starts a segmentation-training run via POST /api/ocr/segtrain.
 * On success a transient confirmation is shown for 5 s.
 * Network failures are caught and logged so the promise never rejects
 * unhandled and the button always returns to its idle state.
 */
async function startTraining() {
	training = true;
	successMessage = null;
	if (hideTimer) {
		clearTimeout(hideTimer);
		hideTimer = null;
	}
	try {
		const res = await fetch('/api/ocr/segtrain', { method: 'POST' });
		if (res.ok) {
			successMessage = m.training_success();
			hideTimer = setTimeout(() => {
				successMessage = null;
			}, 5000);
		}
	} catch (err) {
		// Previously a network error escaped as an unhandled rejection;
		// surface it in the console instead of crashing silently.
		console.error('Segmentation training request failed', err);
	} finally {
		training = false;
	}
}
</script>

<div class="rounded-sm border border-line bg-surface p-6 shadow-sm">
	<h2 class="mb-1 font-sans text-sm font-bold text-ink">{m.training_seg_heading()}</h2>
	<p class="mb-4 text-sm text-ink-2">{m.training_seg_description()}</p>
	<p class="mb-3 text-sm text-ink">
		{m.training_seg_blocks_ready({ blocks: available })}
	</p>
	<button
		onclick={startTraining}
		disabled={disabled}
		class="rounded-sm bg-primary px-5 py-2 font-sans text-xs font-bold tracking-widest text-primary-fg uppercase transition-opacity hover:opacity-80 focus-visible:ring-2 focus-visible:ring-brand-navy disabled:cursor-not-allowed disabled:opacity-50"
	>
		{training ? '…' : m.training_start_btn()}
	</button>
	{#if tooFewBlocks}
		<p class="mt-2 text-xs text-ink-3">
			{m.training_seg_too_few_blocks({ available })}
		</p>
	{:else if serviceDown}
		<p class="mt-2 text-xs text-orange-600">{m.training_service_down()}</p>
	{/if}
	{#if successMessage}
		<p class="mt-2 text-xs text-green-700">{successMessage}</p>
	{/if}
	<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
		{m.training_history_heading()}
	</h3>
	<!-- 'blla' is the segmentation model; this card lists only those runs. -->
	<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName === 'blla')} />
</div>

View File

@@ -0,0 +1,119 @@
<script lang="ts">
import { m } from '$lib/paraglide/messages.js';

// One training run row as delivered by /api/ocr/training-info.
interface Run {
	id: string;
	status: 'RUNNING' | 'DONE' | 'FAILED';
	blockCount: number;
	documentCount: number;
	modelName: string;
	errorMessage?: string;
	triggeredBy?: string;
	createdAt: string;
	completedAt?: string;
	// Character error rate as a fraction (0.031 → 3.1 %); absent while
	// the run is still RUNNING or when it FAILED.
	cer?: number;
}

interface Props {
	runs: Run[];
}

let { runs }: Props = $props();

// NOTE(review): locale is hard-coded to 'de-DE' although the app ships
// translated message catalogs (Spanish strings exist) — confirm whether
// the active Paraglide locale should drive date formatting instead.
const dateFormatter = new Intl.DateTimeFormat('de-DE', {
	day: 'numeric',
	month: 'short',
	year: 'numeric'
});

// Formats an ISO-8601 timestamp as a short date (day, abbreviated month, year).
function formatDate(iso: string): string {
	return dateFormatter.format(new Date(iso));
}

// Renders a fractional CER as a percentage with one decimal; em dash when
// the value is null/undefined.
function formatCer(cer: number | undefined | null): string {
	if (cer == null) return '—';
	return (cer * 100).toFixed(1) + ' %';
}
</script>

<table class="w-full text-sm">
	<thead>
		<tr class="border-b border-line text-xs font-bold tracking-widest text-ink-3 uppercase">
			<th class="pb-2 text-left">{m.training_history_col_date()}</th>
			<th class="pb-2 text-left">{m.training_history_col_status()}</th>
			<th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
			<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
			<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
		</tr>
	</thead>
	<tbody>
		{#if runs.length === 0}
			<tr>
				<td colspan="5" class="py-4 text-center text-sm text-ink-2">
					{m.training_history_empty()}
				</td>
			</tr>
		{:else}
			{#each runs as run (run.id)}
				<tr class="border-b border-line/50 last:border-0">
					<td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td>
					<td class="py-2">
						{#if run.status === 'DONE'}
							<!-- Green badge with checkmark icon -->
							<span
								class="inline-flex items-center gap-1 rounded-sm bg-green-100 px-1.5 py-0.5 text-xs font-medium text-green-700"
							>
								<svg
									aria-hidden="true"
									class="h-3 w-3 shrink-0"
									viewBox="0 0 20 20"
									fill="currentColor"
								>
									<path
										fill-rule="evenodd"
										d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
										clip-rule="evenodd"
									/>
								</svg>
								{m.training_status_done()}
							</span>
						{:else if run.status === 'FAILED'}
							<!-- Red badge with X icon; error details surfaced via title tooltip -->
							<span
								class="inline-flex items-center gap-1 rounded-sm bg-red-100 px-1.5 py-0.5 text-xs font-medium text-red-700"
								title={run.errorMessage}
							>
								<svg
									aria-hidden="true"
									class="h-3 w-3 shrink-0"
									viewBox="0 0 20 20"
									fill="currentColor"
								>
									<path
										fill-rule="evenodd"
										d="M4.293 4.293a1 1 0 011.414 0L10 8.586l4.293-4.293a1 1 0 111.414 1.414L11.414 10l4.293 4.293a1 1 0 01-1.414 1.414L10 11.414l-4.293 4.293a1 1 0 01-1.414-1.414L8.586 10 4.293 5.707a1 1 0 010-1.414z"
										clip-rule="evenodd"
									/>
								</svg>
								{m.training_status_failed()}
							</span>
						{:else}
							<!-- Yellow badge with pulsing dot for in-progress runs -->
							<span
								class="inline-flex items-center gap-1 rounded-sm bg-yellow-100 px-1.5 py-0.5 text-xs font-medium text-yellow-700"
							>
								<span
									aria-hidden="true"
									class="h-1.5 w-1.5 animate-pulse rounded-full bg-yellow-500"
								></span>
								{m.training_status_running()}
							</span>
						{/if}
					</td>
					<td class="py-2 text-right text-ink-2">{run.blockCount}</td>
					<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
					<!-- CER is only meaningful for completed runs -->
					<td class="hidden py-2 text-right md:table-cell"
						>{run.status === 'DONE' && run.cer != null ? formatCer(run.cer) : '—'}</td
					>
				</tr>
			{/each}
		{/if}
	</tbody>
</table>

View File

@@ -27,6 +27,7 @@ type Props = {
onMoveDown?: () => void;
isFirst?: boolean;
isLast?: boolean;
source?: 'MANUAL' | 'OCR';
};
let {
@@ -48,7 +49,8 @@ let {
onMoveUp,
onMoveDown,
isFirst = false,
isLast = false
isLast = false,
source = 'MANUAL'
}: Props = $props();
let localText = $state(text);
@@ -172,6 +174,11 @@ function handleTextareaMouseUp() {
{label}
</span>
{/if}
{#if (!text || text.trim() === '') && source === 'MANUAL'}
<span class="rounded bg-muted px-1.5 py-0.5 text-xs font-medium text-ink-3"
>{m.transcription_block_segmentation_only()}</span
>
{/if}
</div>
<!-- Textarea -->

View File

@@ -19,7 +19,10 @@ type Props = {
onSaveBlock: (blockId: string, text: string) => Promise<void>;
onDeleteBlock: (blockId: string) => Promise<void>;
onReviewToggle: (blockId: string) => Promise<void>;
onTriggerOcr?: (scriptType: string) => void;
onTriggerOcr?: (scriptType: string, useExistingAnnotations: boolean) => void;
canWrite?: boolean;
trainingLabels?: string[];
onToggleTrainingLabel?: (label: string, enrolled: boolean) => Promise<void>;
};
let {
@@ -34,10 +37,14 @@ let {
onSaveBlock,
onDeleteBlock,
onReviewToggle,
onTriggerOcr
onTriggerOcr,
canWrite = false,
trainingLabels = [],
onToggleTrainingLabel
}: Props = $props();
let activeBlockId: string | null = $state(null);
let localLabels: string[] = $derived.by(() => [...trainingLabels]);
// Sync: when an annotation is clicked on the PDF, activate the corresponding block
$effect(() => {
@@ -50,6 +57,9 @@ let debounceTimers = new SvelteMap<string, ReturnType<typeof setTimeout>>();
let pendingTexts = new SvelteMap<string, string>();
let sortedBlocks = $derived([...blocks].sort((a, b) => a.sortOrder - b.sortOrder));
let hasBlocks = $derived(blocks.length > 0);
let reviewedCount = $derived(blocks.filter((b) => b.reviewed).length);
let totalCount = $derived(blocks.length);
let reviewProgress = $derived(totalCount > 0 ? (reviewedCount / totalCount) * 100 : 0);
function getSaveState(blockId: string): SaveState {
return saveStates.get(blockId) ?? 'idle';
@@ -185,7 +195,7 @@ let dropTargetIdx: number | null = $state(null);
let dragOffsetY: number = $state(0);
let dragStartY = 0;
let capturedEl: HTMLElement | null = null;
let listEl: HTMLElement | null = null;
let listEl: HTMLElement | null = $state(null);
function handleGripDown(e: PointerEvent, blockId: string) {
if (!(e.target as HTMLElement).closest('[data-drag-handle]')) return;
@@ -237,6 +247,23 @@ function handlePointerUp() {
capturedEl = null;
}
async function handleLabelToggle(label: string) {
if (!onToggleTrainingLabel) return;
const enrolled = !localLabels.includes(label);
// Optimistic update
if (enrolled) {
localLabels = [...localLabels, label];
} else {
localLabels = localLabels.filter((l) => l !== label);
}
try {
await onToggleTrainingLabel(label, enrolled);
} catch {
// Revert on failure
localLabels = [...trainingLabels];
}
}
function flushViaBeacon() {
for (const [blockId, text] of pendingTexts) {
clearDebounce(blockId);
@@ -263,78 +290,91 @@ $effect(() => {
});
</script>
<div class="flex h-full flex-col overflow-y-auto bg-surface p-4">
<div class="flex h-full flex-col overflow-y-auto bg-surface">
{#if hasBlocks}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="flex flex-col gap-3"
bind:this={listEl}
onpointermove={handlePointerMove}
onpointerup={handlePointerUp}
>
{#each sortedBlocks as block, i (block.id)}
{#if dropTargetIdx === i}
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
{/if}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<!-- Sticky review progress header -->
<div class="sticky top-0 z-10 border-b border-line bg-surface px-4 pt-3 pb-2">
<p class="font-sans text-xs text-ink-2">
<span class="font-semibold text-ink">{reviewedCount} / {totalCount}</span> geprüft
</p>
<div class="bg-brand-sand mt-1.5 h-0.5 w-full overflow-hidden rounded-full">
<div
data-block-wrapper
onblur={handleBlur}
onpointerdown={(e) => handleGripDown(e, block.id)}
class="relative transition-all duration-150 {draggedBlockId === block.id ? 'z-10 rounded-lg shadow-xl ring-2 ring-turquoise/40' : ''}"
style={draggedBlockId === block.id ? `transform: translateY(${dragOffsetY}px) scale(1.02); opacity: 0.9;` : ''}
>
<TranscriptionBlock
blockId={block.id}
documentId={documentId}
blockNumber={i + 1}
text={block.text}
label={block.label}
active={activeBlockId === block.id}
reviewed={block.reviewed ?? false}
saveState={getSaveState(block.id)}
canComment={canComment}
currentUserId={currentUserId}
onTextChange={(text) => handleTextChange(block.id, text)}
onFocus={() => handleFocus(block.id)}
onDeleteClick={() => handleDelete(block.id)}
onRetry={() => handleRetry(block.id)}
onReviewToggle={() => onReviewToggle(block.id)}
onMoveUp={() => handleMoveUp(block.id)}
onMoveDown={() => handleMoveDown(block.id)}
isFirst={i === 0}
isLast={i === sortedBlocks.length - 1}
/>
</div>
{/each}
{#if dropTargetIdx === sortedBlocks.length}
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
{/if}
<!-- Next block CTA — dashed outline hint -->
<div
class="flex items-center justify-center rounded border border-dashed border-line px-4 py-5 text-center font-sans text-sm text-ink-3"
>
{m.transcription_next_block_cta({ number: sortedBlocks.length + 1 })}
class="h-full rounded-full bg-brand-mint transition-all duration-300"
style="width: {reviewProgress}%"
></div>
</div>
{#if canRunOcr && onTriggerOcr}
<details class="mt-6">
<summary
class="cursor-pointer font-sans text-xs font-medium text-ink-3 transition-colors hover:text-brand-navy"
</div>
<div class="p-4">
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="flex flex-col gap-3"
bind:this={listEl}
onpointermove={handlePointerMove}
onpointerup={handlePointerUp}
>
{#each sortedBlocks as block, i (block.id)}
{#if dropTargetIdx === i}
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
{/if}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
data-block-wrapper
onblur={handleBlur}
onpointerdown={(e) => handleGripDown(e, block.id)}
class="relative transition-all duration-150 {draggedBlockId === block.id ? 'z-10 rounded-lg shadow-xl ring-2 ring-turquoise/40' : ''}"
style={draggedBlockId === block.id ? `transform: translateY(${dragOffsetY}px) scale(1.02); opacity: 0.9;` : ''}
>
{m.ocr_rerun_label()}
</summary>
<div class="mt-3 max-w-xs">
<OcrTrigger
existingBlockCount={blocks.length}
storedScriptType={storedScriptType}
onTrigger={onTriggerOcr}
<TranscriptionBlock
blockId={block.id}
documentId={documentId}
blockNumber={i + 1}
text={block.text}
label={block.label}
active={activeBlockId === block.id}
reviewed={block.reviewed ?? false}
saveState={getSaveState(block.id)}
canComment={canComment}
currentUserId={currentUserId}
onTextChange={(text) => handleTextChange(block.id, text)}
onFocus={() => handleFocus(block.id)}
onDeleteClick={() => handleDelete(block.id)}
onRetry={() => handleRetry(block.id)}
onReviewToggle={() => onReviewToggle(block.id)}
onMoveUp={() => handleMoveUp(block.id)}
onMoveDown={() => handleMoveDown(block.id)}
isFirst={i === 0}
isLast={i === sortedBlocks.length - 1}
source={block.source}
/>
</div>
</details>
{/if}
{/each}
{#if dropTargetIdx === sortedBlocks.length}
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
{/if}
<!-- Next block CTA — dashed outline hint -->
<div
class="flex items-center justify-center rounded border border-dashed border-line px-4 py-5 text-center font-sans text-sm text-ink-3"
>
{m.transcription_next_block_cta({ number: sortedBlocks.length + 1 })}
</div>
{#if canRunOcr && onTriggerOcr}
<div class="mt-6">
<p class="mb-3 font-sans text-xs font-bold tracking-widest text-ink-3 uppercase">
{m.ocr_section_heading()}
</p>
<div class="max-w-xs">
<OcrTrigger
blockCount={blocks.length}
storedScriptType={storedScriptType}
onTrigger={onTriggerOcr}
/>
</div>
</div>
{/if}
</div>
</div>
{:else}
<div class="flex flex-1 flex-col items-center justify-center px-6 py-12 text-center">
@@ -352,25 +392,28 @@ $effect(() => {
/>
</svg>
{#if canRunOcr && onTriggerOcr}
<p class="mb-6 max-w-xs text-sm leading-relaxed text-ink-3">
{m.transcription_empty_title()}
</p>
<div class="w-full max-w-xs">
<OcrTrigger
existingBlockCount={0}
storedScriptType={storedScriptType}
onTrigger={onTriggerOcr}
/>
</div>
<p class="mt-4 text-xs text-ink-3">
{m.transcription_empty_desc()}
</p>
{:else}
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
{m.transcription_empty_cta()}
</p>
{/if}
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
{m.transcription_empty_draw_hint()}
</p>
</div>
{/if}
{#if canWrite}
<div class="border-t border-line px-4 py-3">
<p class="mb-2 font-sans text-xs font-medium text-ink-2">Für Training vormerken</p>
<div class="flex flex-wrap gap-2">
{#each [{ label: 'KURRENT_RECOGNITION', display: m.training_chip_kurrent() }, { label: 'KURRENT_SEGMENTATION', display: m.training_chip_segmentation() }] as chip (chip.label)}
<button
type="button"
onclick={() => handleLabelToggle(chip.label)}
class="rounded-full border px-3 py-1 font-sans text-xs font-medium transition-colors {localLabels.includes(chip.label)
? 'border-brand-mint bg-brand-mint text-brand-navy'
: 'border-line bg-surface text-ink-3 hover:border-brand-mint hover:text-brand-navy'}"
>
{chip.display}
</button>
{/each}
</div>
</div>
{/if}
</div>

View File

@@ -13,7 +13,9 @@ const block1 = {
text: 'Block eins',
label: null,
sortOrder: 0,
version: 0
version: 0,
source: 'MANUAL' as const,
reviewed: false
};
const block2 = {
id: 'b2',
@@ -22,7 +24,9 @@ const block2 = {
text: 'Block zwei',
label: null,
sortOrder: 1,
version: 0
version: 0,
source: 'OCR' as const,
reviewed: true
};
function renderView(overrides: Record<string, unknown> = {}, service = createConfirmService()) {
@@ -36,6 +40,7 @@ function renderView(overrides: Record<string, unknown> = {}, service = createCon
onBlockFocus: vi.fn(),
onSaveBlock: vi.fn(),
onDeleteBlock: vi.fn(),
onReviewToggle: vi.fn(),
...overrides
},
context: new Map([[CONFIRM_KEY, service]])
@@ -58,7 +63,7 @@ describe('TranscriptionEditView — rendering', () => {
it('shows empty state when no blocks', async () => {
renderView({ blocks: [] });
await expect.element(page.getByText(/Markiere einen Bereich/)).toBeInTheDocument();
await expect.element(page.getByText(/Zeichnen Sie Bereiche/)).toBeInTheDocument();
});
});
@@ -232,3 +237,23 @@ describe('TranscriptionEditView — delete block', () => {
expect(onDeleteBlock).not.toHaveBeenCalled();
});
});
// ─── Review progress counter ──────────────────────────────────────────────────
describe('TranscriptionEditView — review progress counter', () => {
it('shows reviewed count and total when blocks exist', async () => {
// block1: reviewed=false, block2: reviewed=true → "1 / 2 geprüft"
renderView();
await expect.element(page.getByText(/1 \/ 2 geprüft/)).toBeInTheDocument();
});
it('shows 0 reviewed when no blocks are reviewed', async () => {
renderView({ blocks: [block1] }); // block1.reviewed = false
await expect.element(page.getByText(/0 \/ 1 geprüft/)).toBeInTheDocument();
});
it('does not show progress counter when there are no blocks', async () => {
renderView({ blocks: [] });
await expect.element(page.getByText(/geprüft/)).not.toBeInTheDocument();
});
});

View File

@@ -18,7 +18,6 @@ export type ErrorCode =
| 'IMPORT_ALREADY_RUNNING'
| 'INVALID_RESET_TOKEN'
| 'ANNOTATION_NOT_FOUND'
| 'ANNOTATION_OVERLAP'
| 'TRANSCRIPTION_BLOCK_NOT_FOUND'
| 'TRANSCRIPTION_BLOCK_CONFLICT'
| 'COMMENT_NOT_FOUND'
@@ -26,6 +25,7 @@ export type ErrorCode =
| 'OCR_JOB_NOT_FOUND'
| 'OCR_DOCUMENT_NOT_UPLOADED'
| 'OCR_PROCESSING_FAILED'
| 'TRAINING_ALREADY_RUNNING'
| 'UNAUTHORIZED'
| 'FORBIDDEN'
| 'VALIDATION_ERROR'
@@ -81,8 +81,6 @@ export function getErrorMessage(code: ErrorCode | string | undefined): string {
return m.error_invalid_reset_token();
case 'ANNOTATION_NOT_FOUND':
return m.error_annotation_not_found();
case 'ANNOTATION_OVERLAP':
return m.error_annotation_overlap();
case 'TRANSCRIPTION_BLOCK_NOT_FOUND':
return m.error_transcription_block_not_found();
case 'TRANSCRIPTION_BLOCK_CONFLICT':
@@ -97,6 +95,8 @@ export function getErrorMessage(code: ErrorCode | string | undefined): string {
return m.error_ocr_document_not_uploaded();
case 'OCR_PROCESSING_FAILED':
return m.error_ocr_processing_failed();
case 'TRAINING_ALREADY_RUNNING':
return m.error_training_already_running();
case 'UNAUTHORIZED':
return m.error_unauthorized();
case 'FORBIDDEN':

View File

@@ -116,6 +116,22 @@ export interface paths {
patch?: never;
trace?: never;
};
"/api/documents/{documentId}/transcription-blocks/{blockId}/review": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put: operations["reviewBlock"];
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/documents/{documentId}/transcription-blocks/reorder": {
parameters: {
query?: never;
@@ -212,6 +228,54 @@ export interface paths {
patch?: never;
trace?: never;
};
"/api/ocr/train": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
post: operations["triggerTraining"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/ocr/segtrain": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
post: operations["triggerSegTraining"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/ocr/batch": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
post: operations["triggerBatch"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/notifications/read-all": {
parameters: {
query?: never;
@@ -308,6 +372,22 @@ export interface paths {
patch?: never;
trace?: never;
};
"/api/documents/{documentId}/ocr": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
post: operations["triggerOcr"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/documents/{documentId}/comments": {
parameters: {
query?: never;
@@ -516,6 +596,22 @@ export interface paths {
patch: operations["updateGroup"];
trace?: never;
};
"/api/documents/{id}/training-labels": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch: operations["patchTrainingLabel"];
trace?: never;
};
"/api/documents/{documentId}/comments/{commentId}": {
parameters: {
query?: never;
@@ -628,6 +724,86 @@ export interface paths {
patch?: never;
trace?: never;
};
"/api/ocr/training-info": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get: operations["getTrainingInfo"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/ocr/training-data/export": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get: operations["exportTrainingData"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/ocr/segmentation-training-data/export": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get: operations["exportSegmentationTrainingData"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/ocr/jobs/{jobId}": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get: operations["getJobStatus"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/ocr/jobs/{jobId}/progress": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get: operations["streamProgress"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/notifications": {
parameters: {
query?: never;
@@ -740,6 +916,22 @@ export interface paths {
patch?: never;
trace?: never;
};
"/api/documents/{documentId}/ocr-status": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get: operations["getDocumentOcrStatus"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/documents/search": {
parameters: {
query?: never;
@@ -940,6 +1132,7 @@ export interface components {
name: string;
};
PersonUpdateDTO: {
title?: string;
firstName?: string;
lastName?: string;
alias?: string;
@@ -978,6 +1171,8 @@ export interface components {
receiverIds?: string[];
tags?: string;
metadataComplete?: boolean;
/** @enum {string} */
scriptType?: "UNKNOWN" | "TYPEWRITER" | "HANDWRITING_LATIN" | "HANDWRITING_KURRENT";
};
Document: {
/** Format: uuid */
@@ -1002,9 +1197,12 @@ export interface components {
/** Format: date-time */
updatedAt: string;
metadataComplete: boolean;
/** @enum {string} */
scriptType: "UNKNOWN" | "TYPEWRITER" | "HANDWRITING_LATIN" | "HANDWRITING_KURRENT";
receivers?: components["schemas"]["Person"][];
sender?: components["schemas"]["Person"];
tags?: components["schemas"]["Tag"][];
trainingLabels?: ("KURRENT_RECOGNITION" | "KURRENT_SEGMENTATION")[];
};
UpdateTranscriptionBlockDTO: {
text?: string;
@@ -1017,10 +1215,13 @@ export interface components {
annotationId: string;
/** Format: uuid */
documentId: string;
text: string;
text?: string;
label?: string;
/** Format: int32 */
sortOrder: number;
/** @enum {string} */
source: "MANUAL" | "OCR";
reviewed: boolean;
/** Format: int32 */
version: number;
/** Format: uuid */
@@ -1068,6 +1269,35 @@ export interface components {
/** Format: date-time */
createdAt: string;
};
OcrTrainingRun: {
/** Format: uuid */
id: string;
/** @enum {string} */
status: "RUNNING" | "DONE" | "FAILED";
/** Format: int32 */
blockCount: number;
/** Format: int32 */
documentCount: number;
modelName: string;
/** Format: double */
cer?: number;
/** Format: double */
loss?: number;
/** Format: double */
accuracy?: number;
/** Format: int32 */
epochs?: number;
errorMessage?: string;
/** Format: uuid */
triggeredBy?: string;
/** Format: date-time */
createdAt: string;
/** Format: date-time */
completedAt?: string;
};
BatchOcrDTO: {
documentIds: string[];
};
GroupDTO: {
name?: string;
permissions?: string[];
@@ -1118,6 +1348,11 @@ export interface components {
firstName: string;
lastName: string;
};
TriggerOcrDTO: {
/** @enum {string} */
scriptType?: "UNKNOWN" | "TYPEWRITER" | "HANDWRITING_LATIN" | "HANDWRITING_KURRENT";
useExistingAnnotations?: boolean;
};
CreateAnnotationDTO: {
/** Format: int32 */
pageNumber?: number;
@@ -1130,6 +1365,7 @@ export interface components {
/** Format: double */
height?: number;
color?: string;
polygon?: number[][];
};
DocumentAnnotation: {
/** Format: uuid */
@@ -1148,6 +1384,7 @@ export interface components {
height: number;
color: string;
fileHash?: string;
polygon?: number[][];
/** Format: uuid */
createdBy?: string;
/** Format: date-time */
@@ -1199,6 +1436,10 @@ export interface components {
actorName?: string;
documentTitle?: string;
};
TrainingLabelRequest: {
label?: string;
enrolled?: boolean;
};
StatsDTO: {
/** Format: int64 */
totalPersons?: number;
@@ -1218,15 +1459,54 @@ export interface components {
deathYear?: number;
alias?: string;
notes?: string;
personType?: string;
/** Format: int64 */
documentCount?: number;
personType?: string;
};
TrainingInfoResponse: {
/** Format: int32 */
availableBlocks?: number;
/** Format: int32 */
totalOcrBlocks?: number;
/** Format: int32 */
availableDocuments?: number;
/** Format: int32 */
availableSegBlocks?: number;
ocrServiceAvailable?: boolean;
lastRun?: components["schemas"]["OcrTrainingRun"];
runs?: components["schemas"]["OcrTrainingRun"][];
};
StreamingResponseBody: unknown;
OcrJob: {
/** Format: uuid */
id: string;
/** @enum {string} */
status: "PENDING" | "RUNNING" | "DONE" | "FAILED";
/** Format: int32 */
totalDocuments: number;
/** Format: int32 */
processedDocuments: number;
/** Format: int32 */
errorCount: number;
/** Format: int32 */
skippedCount: number;
progressMessage?: string;
/** Format: uuid */
createdBy?: string;
/** Format: date-time */
createdAt: string;
/** Format: date-time */
updatedAt: string;
};
SseEmitter: {
/** Format: int64 */
timeout?: number;
};
PageNotificationDTO: {
/** Format: int64 */
totalElements?: number;
/** Format: int32 */
totalPages?: number;
/** Format: int64 */
totalElements?: number;
pageable?: components["schemas"]["PageableObject"];
/** Format: int32 */
size?: number;
@@ -1256,10 +1536,6 @@ export interface components {
empty?: boolean;
unsorted?: boolean;
};
SseEmitter: {
/** Format: int64 */
timeout?: number;
};
DocumentVersionSummary: {
/** Format: uuid */
id: string;
@@ -1292,6 +1568,15 @@ export interface components {
/** Format: date-time */
changedAt: string;
};
OcrStatusDTO: {
status?: string;
/** Format: uuid */
jobId?: string;
/** Format: int32 */
currentPage?: number;
/** Format: int32 */
totalPages?: number;
};
DocumentSearchResult: {
documents?: components["schemas"]["Document"][];
/** Format: int64 */
@@ -1702,6 +1987,29 @@ export interface operations {
};
};
};
reviewBlock: {
parameters: {
query?: never;
header?: never;
path: {
documentId: string;
blockId: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["TranscriptionBlock"];
};
};
};
};
reorderBlocks: {
parameters: {
query?: never;
@@ -1914,6 +2222,72 @@ export interface operations {
};
};
};
triggerTraining: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Created */
201: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["OcrTrainingRun"];
};
};
};
};
triggerSegTraining: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Created */
201: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["OcrTrainingRun"];
};
};
};
};
triggerBatch: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["BatchOcrDTO"];
};
};
responses: {
/** @description Accepted */
202: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": {
[key: string]: string;
};
};
};
};
};
markAllRead: {
parameters: {
query?: never;
@@ -2124,6 +2498,34 @@ export interface operations {
};
};
};
triggerOcr: {
parameters: {
query?: never;
header?: never;
path: {
documentId: string;
};
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["TriggerOcrDTO"];
};
};
responses: {
/** @description Accepted */
202: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": {
[key: string]: string;
};
};
};
};
};
getDocumentComments: {
parameters: {
query?: never;
@@ -2521,6 +2923,30 @@ export interface operations {
};
};
};
patchTrainingLabel: {
parameters: {
query?: never;
header?: never;
path: {
id: string;
};
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["TrainingLabelRequest"];
};
};
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content?: never;
};
};
};
deleteComment: {
parameters: {
query?: never;
@@ -2701,6 +3127,110 @@ export interface operations {
};
};
};
getTrainingInfo: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["TrainingInfoResponse"];
};
};
};
};
exportTrainingData: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["StreamingResponseBody"];
};
};
};
};
exportSegmentationTrainingData: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["StreamingResponseBody"];
};
};
};
};
getJobStatus: {
parameters: {
query?: never;
header?: never;
path: {
jobId: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["OcrJob"];
};
};
};
};
streamProgress: {
parameters: {
query?: never;
header?: never;
path: {
jobId: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content: {
"text/event-stream": components["schemas"]["SseEmitter"];
};
};
};
};
getNotifications: {
parameters: {
query?: {
@@ -2860,6 +3390,28 @@ export interface operations {
};
};
};
getDocumentOcrStatus: {
parameters: {
query?: never;
header?: never;
path: {
documentId: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description OK */
200: {
headers: {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["OcrStatusDTO"];
};
};
};
};
search_1: {
parameters: {
query?: {

View File

@@ -1,6 +1,13 @@
<script lang="ts">
import { onDestroy } from 'svelte';
import { m } from '$lib/paraglide/messages.js';
import OcrTrainingCard from '$lib/components/OcrTrainingCard.svelte';
import SegmentationTrainingCard from '$lib/components/SegmentationTrainingCard.svelte';
import type { components } from '$lib/generated/api.js';
type TrainingInfo = components['schemas']['TrainingInfoResponse'];
let trainingInfo: TrainingInfo | null = $state(null);
let backfillResult: number | null = $state(null);
let backfillLoading = $state(false);
@@ -51,8 +58,16 @@ async function triggerImport() {
}
}
async function fetchTrainingInfo() {
const res = await fetch('/api/ocr/training-info');
if (res.ok) {
trainingInfo = await res.json();
}
}
$effect(() => {
fetchImportStatus();
fetchTrainingInfo();
});
onDestroy(() => stopPolling());
@@ -88,6 +103,12 @@ async function backfillFileHashes() {
<div class="flex-1 overflow-y-auto p-6">
<div class="mx-auto max-w-2xl space-y-5">
<!-- OCR Recognition Training -->
<OcrTrainingCard trainingInfo={trainingInfo} />
<!-- OCR Segmentation Training -->
<SegmentationTrainingCard trainingInfo={trainingInfo} />
<!-- Backfill versions -->
<div class="rounded-sm border border-line bg-surface p-6 shadow-sm">
<h2 class="mb-1 font-sans text-sm font-bold text-ink">{m.admin_system_backfill_heading()}</h2>

View File

@@ -78,6 +78,8 @@ describe('Admin system page — mass import card', () => {
startedAt: null
})
})
// training info fetch → empty
.mockResolvedValueOnce({ ok: true, json: async () => ({}) })
// trigger POST → returns RUNNING immediately
.mockResolvedValueOnce({
ok: true,

View File

@@ -129,20 +129,29 @@ async function reviewToggle(blockId: string) {
transcriptionBlocks = transcriptionBlocks.map((b) => (b.id === blockId ? updated : b));
}
async function toggleTrainingLabel(label: string, enrolled: boolean) {
const res = await fetch(`/api/documents/${doc.id}/training-labels`, {
method: 'PATCH',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ label, enrolled })
});
if (!res.ok) throw new Error('Failed to update training label');
}
let ocrRunning = $state(false);
let ocrProgressMessage = $state('');
let ocrErrorMessage = $state('');
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
let ocrSkippedPages = $state(0);
async function triggerOcr(scriptType: string) {
async function triggerOcr(scriptType: string, useExistingAnnotations: boolean = false) {
ocrRunning = true;
ocrErrorMessage = '';
try {
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ scriptType })
body: JSON.stringify({ scriptType, useExistingAnnotations })
});
if (res.ok) {
const data = await res.json();
@@ -449,11 +458,14 @@ onMount(() => {
activeAnnotationId={activeAnnotationId}
storedScriptType={doc.scriptType ?? ''}
canRunOcr={canWrite && !!doc.filePath}
canWrite={canWrite}
trainingLabels={doc.trainingLabels ?? []}
onBlockFocus={handleBlockFocus}
onSaveBlock={saveBlock}
onDeleteBlock={deleteBlock}
onReviewToggle={reviewToggle}
onTriggerOcr={triggerOcr}
onToggleTrainingLabel={toggleTrainingLabel}
/>
{/if}
</div>

View File

@@ -23,4 +23,4 @@ COPY . .
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]

View File

@@ -47,7 +47,7 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
pred_it = rpred.rpred(_model, image, baseline_seg)
for record in pred_it:
polygon_pts = record.cuts if hasattr(record, "cuts") else []
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
if polygon_pts:
xs = [p[0] for p in polygon_pts]
@@ -55,8 +55,8 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
x1, y1 = min(xs), min(ys)
x2, y2 = max(xs), max(ys)
else:
xs = [p[0] for p in record.line]
ys = [p[1] for p in record.line]
xs = [p[0] for p in record.baseline]
ys = [p[1] for p in record.baseline]
x1, y1 = min(xs), min(ys) - 5
x2, y2 = max(xs), max(ys) + 5
@@ -79,6 +79,54 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
return blocks
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
    """Recognize text inside one normalized bounding box with Kraken.

    Guided-OCR helper: full-page layout analysis is skipped entirely.
    Instead of calling blla.segment() — which crashes on small crops — a
    single synthetic baseline spanning the whole crop width is handed to
    the recognizer. Box coordinates are fractions of the page size in [0, 1].

    Returns the recognized text of all predicted lines joined by spaces,
    or "" for a degenerate (zero-area) region.
    """
    from kraken import rpred
    from kraken.containers import BaselineLine, Segmentation

    if _model is None:
        raise RuntimeError("Kraken model is not loaded")

    page_w, page_h = image.size
    # Clamp the normalized box to the page bounds before converting to pixels.
    left = max(0, int(x * page_w))
    top = max(0, int(y * page_h))
    right = min(page_w, int((x + w) * page_w))
    bottom = min(page_h, int((y + h) * page_h))
    region = image.crop((left, top, right, bottom))

    crop_w, crop_h = region.size
    if crop_w == 0 or crop_h == 0:
        return ""

    # One synthetic line across the vertical middle of the crop. Kraken's
    # bounds check uses >= (not >), so every coordinate must stay strictly
    # below the image dimension — hence crop_w - 1 / crop_h - 1.
    baseline_y = crop_h // 2
    synthetic_line = BaselineLine(
        id="line0",
        baseline=[(0, baseline_y), (crop_w - 1, baseline_y)],
        boundary=[(0, 0), (crop_w - 1, 0), (crop_w - 1, crop_h - 1), (0, crop_h - 1)],
    )
    segmentation = Segmentation(
        type="baselines",
        imagename="",
        text_direction="horizontal-lr",
        script_detection=False,
        lines=[synthetic_line],
        regions={},
        line_orders=[],
    )

    predictions = rpred.rpred(_model, region, segmentation)
    return " ".join(record.prediction for record in predictions)
def extract_blocks(images: list, language: str = "de") -> list[dict]:
"""Run Kraken segmentation + recognition on a list of PIL images.

View File

@@ -81,6 +81,25 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
return blocks
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
    """Crop image to a normalized region and run Surya recognition on the crop.

    Used for guided OCR — skips full-page layout detection and only processes
    the given bounding box. Coordinates are normalized to [0, 1].

    Returns the recognized text of all detected lines joined by spaces, or ""
    for a degenerate (zero-area) region.
    """
    load_models()
    pw, ph = image.size
    # Clamp the normalized box to the page bounds before converting to pixels.
    x1 = max(0, int(x * pw))
    y1 = max(0, int(y * ph))
    x2 = min(pw, int((x + w) * pw))
    y2 = min(ph, int((y + h) * ph))
    crop = image.crop((x1, y1, x2, y2))
    # Guard against an empty crop (zero width or height) instead of handing
    # it to the predictor — mirrors the equivalent check in the Kraken engine.
    if crop.width == 0 or crop.height == 0:
        return ""
    predictions = _recognition_predictor([crop], det_predictor=_detection_predictor)
    return " ".join(line.text for line in predictions[0].text_lines)
def extract_blocks(images: list, language: str = "de") -> list[dict]:
"""Run Surya OCR on a list of PIL images (one per page).

View File

@@ -1,16 +1,23 @@
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
import asyncio
import glob
import io
import json
import logging
import os
import re
import shutil
import subprocess
import tempfile
import zipfile
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from urllib.parse import urlparse
import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
@@ -19,6 +26,9 @@ from engines import kraken as kraken_engine
from engines import surya as surya_engine
from models import OcrBlock, OcrRequest
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -100,7 +110,12 @@ async def run_ocr(request: OcrRequest):
@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page."""
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.
When request.regions is provided, runs in guided mode: each region is cropped
and recognized individually, skipping full-page layout detection. The response
blocks include the annotationId from the region.
"""
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
@@ -115,6 +130,81 @@ async def run_ocr_stream(request: OcrRequest):
detail="Kraken model not available — cannot process Kurrent script",
)
engine = kraken_engine if use_kraken else surya_engine
if request.regions:
# Guided mode: recognize only the user-drawn annotation regions
regions_by_page: dict[int, list] = {}
for region in request.regions:
regions_by_page.setdefault(region.pageNumber, []).append(region)
async def generate_guided():
total_pages = len(images)
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
total_blocks = 0
skipped_pages = 0
for page_idx, image in enumerate(images, start=1):
page_regions = regions_by_page.get(page_idx, [])
if not page_regions:
yield json.dumps({
"type": "page",
"pageNumber": page_idx,
"blocks": [],
}) + "\n"
del image
continue
try:
blocks = []
for region in page_regions:
text = await asyncio.to_thread(
engine.extract_region_text, image,
region.x, region.y, region.width, region.height,
)
blocks.append({
"pageNumber": page_idx,
"x": region.x,
"y": region.y,
"width": region.width,
"height": region.height,
"polygon": None,
"text": text,
"annotationId": region.annotationId,
})
total_blocks += len(blocks)
yield json.dumps({
"type": "page",
"pageNumber": page_idx,
"blocks": blocks,
}) + "\n"
except Exception:
logger.exception("Guided OCR failed on page %d", page_idx)
skipped_pages += 1
yield json.dumps({
"type": "error",
"pageNumber": page_idx,
"message": f"Guided OCR processing failed on page {page_idx}",
}) + "\n"
finally:
del image
yield json.dumps({
"type": "done",
"totalBlocks": total_blocks,
"skippedPages": skipped_pages,
}) + "\n"
return StreamingResponse(
generate_guided(),
media_type="application/x-ndjson",
headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
)
async def generate():
total_pages = len(images)
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
@@ -124,7 +214,6 @@ async def run_ocr_stream(request: OcrRequest):
for page_idx, image in enumerate(images, start=1):
try:
engine = kraken_engine if use_kraken else surya_engine
blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language
)
@@ -169,6 +258,262 @@ async def run_ocr_stream(request: OcrRequest):
)
def _check_training_token(x_training_token: str | None) -> None:
    """Validate training token — fails closed when TRAINING_TOKEN is not configured.

    Args:
        x_training_token: value of the ``X-Training-Token`` request header,
            or None when the header was not sent.

    Raises:
        HTTPException: 503 when no token is configured on this node (fail
            closed — training is simply disabled), 403 when the header is
            missing or does not match the configured token.
    """
    import hmac  # local import: constant-time comparison for the secret token

    if not TRAINING_TOKEN:
        raise HTTPException(status_code=503, detail="Training not configured on this node")
    # compare_digest runs in constant time; a plain `!=` short-circuits on the
    # first differing character and leaks token prefixes via timing.
    if x_training_token is None or not hmac.compare_digest(x_training_token, TRAINING_TOKEN):
        raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
def _validate_zip_entry(name: str, extract_dir: str) -> None:
"""Reject ZIP Slip attacks: path traversal and absolute paths."""
if os.path.isabs(name) or name.startswith(".."):
raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")
resolved = os.path.realpath(os.path.join(extract_dir, name))
if not resolved.startswith(os.path.realpath(extract_dir)):
raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
def _rotate_backups(model_path: str, keep: int = 3) -> None:
"""Keep only the last `keep` timestamped backups of the model."""
pattern = model_path + ".*.bak"
backups = sorted(glob.glob(pattern))
for old in backups[:-keep]:
try:
os.remove(old)
except OSError:
logger.warning("Could not remove old backup: %s", old)
def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
"""Parse checkpoint filenames to find the best validation metric and epoch count.
Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
Returns (best_accuracy, epoch_count).
"""
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
best_acc: float | None = None
max_epoch = 0
for fname in os.listdir(checkpoint_dir):
m = pattern.match(fname)
if m:
epoch = int(m.group(1))
acc = float(m.group(2))
max_epoch = max(max_epoch, epoch)
if best_acc is None or acc > best_acc:
best_acc = acc
return best_acc, max_epoch
def _find_best_model(checkpoint_dir: str) -> str | None:
"""Return the best final model file produced by ketos train.
With --weights-format coreml, ketos writes ``best_<score>.mlmodel``.
Falls back to any .mlmodel in the directory.
"""
# Prefer the named best file (e.g. best_0.8256.mlmodel or best_0.8256.safetensors)
best_pattern = re.compile(r"best_([0-9.]+)\.(mlmodel|safetensors)$")
best_acc: float | None = None
best_path: str | None = None
for fname in os.listdir(checkpoint_dir):
m = best_pattern.match(fname)
if m:
acc = float(m.group(1))
if best_acc is None or acc > best_acc:
best_acc = acc
best_path = os.path.join(checkpoint_dir, fname)
if best_path:
return best_path
# Fallback: any .mlmodel file
for fname in os.listdir(checkpoint_dir):
if fname.endswith(".mlmodel"):
return os.path.join(checkpoint_dir, fname)
return None
@app.post("/train")
async def train_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the Kurrent recognition model with uploaded training data.

    Accepts a ZIP archive of PAGE XML ground-truth files (plus the page
    images they reference) exported by the Java backend — the code globs
    ``*.xml`` and runs ``ketos train -f page``. Training mutates in-process
    model state (the model file is replaced and reloaded) — not safe if the
    service is replicated.

    Raises:
        HTTPException: 503/403 from the token check, 503 when models are not
            loaded yet, 422 when the ZIP contains no ground-truth files.
    """
    _check_training_token(x_training_token)
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Read the whole upload up front so the worker thread below does not
    # touch the (async) UploadFile object.
    zip_bytes = await file.read()
    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting training run %s", training_run_id)

    def _run_training() -> dict:
        # Runs in a worker thread (asyncio.to_thread) — keeps the event loop
        # responsive while ketos trains, which can take many minutes.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Extract ZIP with safety checks
            with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                for entry in zf.namelist():
                    _validate_zip_entry(entry, tmp_dir)
                zf.extractall(tmp_dir)
            log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)

            ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
            log.info("Training on %d ground-truth pairs", len(ground_truth))

            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)
            # CPU-only invocation: -q fixed / -N 10 cap the run at 10 epochs,
            # -B 1 keeps peak memory low, and --weights-format coreml makes
            # ketos emit best_<score>.mlmodel for _find_best_model().
            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "train",
                "-f", "page",
                "--weights-format", "coreml",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
                "-B", "1",
            ]
            # Fine-tune the existing model when present; otherwise train fresh.
            if os.path.exists(KRAKEN_MODEL_PATH):
                cmd += ["-i", KRAKEN_MODEL_PATH]
            cmd += ground_truth
            log.info("Running: %s", " ".join(cmd[:6]) + " ...")

            proc = subprocess.run(cmd, capture_output=True, text=True)
            # Only log the tail of the (potentially huge) ketos output.
            log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
            if proc.stderr:
                log.info("ketos train stderr: %s", proc.stderr[-2000:])
            if proc.returncode != 0:
                raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            # Character error rate is the complement of accuracy.
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None
            log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)

            # Find the best model file produced by training
            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Training produced no model file")

            # Backup existing model and replace
            if os.path.exists(KRAKEN_MODEL_PATH):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
                shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
                log.info("Backed up model to %s", backup_path)
                _rotate_backups(KRAKEN_MODEL_PATH, keep=3)
            shutil.copy2(best_model, KRAKEN_MODEL_PATH)
            log.info("Replaced model at %s", KRAKEN_MODEL_PATH)

            # Reload model in-process
            kraken_engine.load_models()
            log.info("Reloaded Kraken model in-process")

            return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    result = await asyncio.to_thread(_run_training)
    return result
@app.post("/segtrain")
async def segtrain_model(
    file: UploadFile,
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune the blla segmentation model with uploaded PAGE XML training data.

    Accepts a ZIP archive containing .png/.xml (PAGE XML) training pairs exported
    by the Java backend. Training mutates in-process model state — not safe
    if the service is replicated.

    Raises:
        HTTPException: 503/403 from the token check, 503 when models are not
            loaded yet, 422 when the ZIP contains no PAGE XML files.
    """
    _check_training_token(x_training_token)
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Read the whole upload up front so the worker thread below does not
    # touch the (async) UploadFile object.
    zip_bytes = await file.read()
    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    log.info("Starting segmentation training run %s", training_run_id)

    blla_model_path = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")

    def _run_segtrain() -> dict:
        # Runs in a worker thread (asyncio.to_thread) to keep the event loop free.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Extract ZIP with ZIP Slip safety checks.
            with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                for entry in zf.namelist():
                    _validate_zip_entry(entry, tmp_dir)
                zf.extractall(tmp_dir)
            log.info("Extracted %d ZIP entries for segmentation training", len(os.listdir(tmp_dir)))

            xml_files = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not xml_files:
                raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
            log.info("Training on %d PAGE XML files", len(xml_files))

            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)
            # CPU-only invocation; -q fixed / -N 10 cap the run at 10 epochs.
            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "segtrain",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
            ]
            if os.path.exists(blla_model_path):
                # Fine-tune the existing model; --resize both adapts the output
                # layer to the classes present in the training data.
                cmd += ["-i", blla_model_path, "--resize", "both"]
            else:
                # No pretrained model — train from scratch with reduced height (800px)
                # to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+)
                cmd += [
                    "-s",
                    "[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
                    "Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]",
                ]
            cmd += xml_files
            log.info("Running: %s", " ".join(cmd[:5]) + " ...")

            proc = subprocess.run(cmd, capture_output=True, text=True)
            # Only log the tail of the (potentially huge) ketos output.
            log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
            if proc.stderr:
                log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
            if proc.returncode != 0:
                raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")

            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)

            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Segmentation training produced no model file")

            # Backup the existing model (rotating old backups) and replace it.
            if os.path.exists(blla_model_path):
                timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
                backup_path = f"{blla_model_path}.{timestamp}.bak"
                shutil.copy2(blla_model_path, backup_path)
                _rotate_backups(blla_model_path, keep=3)
            shutil.copy2(best_model, blla_model_path)
            log.info("Replaced blla model at %s", blla_model_path)

            # NOTE(review): unlike /train, no in-process reload happens here —
            # verify the segmenter picks up the replaced file on next use.
            return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}

    result = await asyncio.to_thread(_run_segtrain)
    return result
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
_validate_url(url)

View File

@@ -1,12 +1,24 @@
from pydantic import BaseModel, ConfigDict
class OcrRegion(BaseModel):
    """One user-drawn annotation region for guided OCR.

    Coordinates are fractions of the page size in [0, 1] (see the engine
    crop helpers, which multiply them by the pixel dimensions).
    """

    model_config = ConfigDict(populate_by_name=True)

    # Backend annotation UUID — echoed back on result blocks so recognized
    # text can be linked to the originating annotation.
    annotationId: str
    # 1-based page index within the PDF.
    pageNumber: int
    x: float
    y: float
    width: float
    height: float
class OcrRequest(BaseModel):
    """Request body for the OCR endpoints."""

    model_config = ConfigDict(populate_by_name=True)

    # Presigned URL of the source PDF to download and render.
    pdfUrl: str
    scriptType: str = "UNKNOWN"
    language: str = "de"
    # When set, the service runs in guided mode: each region is cropped and
    # recognized individually instead of full-page layout detection.
    regions: list[OcrRegion] | None = None
class OcrBlock(BaseModel):
@@ -19,3 +31,4 @@ class OcrBlock(BaseModel):
height: float
polygon: list[list[float]] | None = None
text: str
annotationId: str | None = None

View File

@@ -0,0 +1,69 @@
"""Tests for /train and /segtrain endpoint authentication."""
import io
import zipfile
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from main import app
def _minimal_zip() -> bytes:
"""Return a ZIP with one .xml file so endpoint validation passes."""
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
zf.writestr("page_01.xml", "<PcGts/>")
return buf.getvalue()
# ─── Missing TRAINING_TOKEN → fail closed ─────────────────────────────────────
@pytest.mark.asyncio
async def test_train_returns_503_when_training_token_not_configured():
    """POST /train must fail closed (503) when TRAINING_TOKEN is empty.

    An empty token means the operator started the service without training
    configured; letting the request through would expose an unauthenticated
    training endpoint.
    """
    # Mark models ready so the request reaches the auth check instead of
    # short-circuiting on model state.
    with patch("main.TRAINING_TOKEN", ""), patch("main._models_ready", True):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            upload = {"file": ("training.zip", _minimal_zip(), "application/zip")}
            response = await client.post("/train", files=upload)
    assert response.status_code == 503
@pytest.mark.asyncio
async def test_segtrain_returns_503_when_training_token_not_configured():
    """POST /segtrain must return 503 when TRAINING_TOKEN env var is empty."""
    # Same fail-closed contract as /train: no configured token → 503.
    with patch("main.TRAINING_TOKEN", ""), patch("main._models_ready", True):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            upload = {"file": ("training.zip", _minimal_zip(), "application/zip")}
            response = await client.post("/segtrain", files=upload)
    assert response.status_code == 503
@pytest.mark.asyncio
async def test_train_returns_403_when_token_wrong():
"""POST /train must return 403 when TRAINING_TOKEN is set but header is wrong."""
with patch("main.TRAINING_TOKEN", "secret-token"), \
patch("main._models_ready", True):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/train",
files={"file": ("training.zip", _minimal_zip(), "application/zip")},
headers={"X-Training-Token": "wrong-token"},
)
assert response.status_code == 403