feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232
36
.env.example
Normal file
36
.env.example
Normal file
@@ -0,0 +1,36 @@
|
||||
# Datenbank (PostgreSQL)
|
||||
POSTGRES_USER=archive_user
|
||||
POSTGRES_PASSWORD=change-me
|
||||
POSTGRES_DB=family_archive_db
|
||||
|
||||
# Object Storage (MinIO)
|
||||
MINIO_ROOT_USER=minio_admin
|
||||
MINIO_ROOT_PASSWORD=change-me
|
||||
MINIO_DEFAULT_BUCKETS=archive-documents
|
||||
|
||||
# Ports (für Zugriff vom Host/NAS)
|
||||
PORT_DB=5432
|
||||
PORT_MINIO_API=9000
|
||||
PORT_MINIO_CONSOLE=9001
|
||||
PORT_BACKEND=8080
|
||||
PORT_FRONTEND=5173
|
||||
|
||||
# Mailpit — local mail catcher (dev only, included in docker-compose)
|
||||
# Web UI: http://localhost:8025
|
||||
# SMTP: localhost:1025 (used automatically by the backend container)
|
||||
PORT_MAILPIT_UI=8100
|
||||
PORT_MAILPIT_SMTP=1025
|
||||
|
||||
# OCR Training — set a secret token to protect the /train and /segtrain endpoints on the
|
||||
# Python OCR microservice. Leave empty to disable token authentication (development only).
|
||||
# OCR_TRAINING_TOKEN=change-me-in-production
|
||||
|
||||
# Production SMTP — uncomment and fill in to send real emails instead of catching them
|
||||
# APP_BASE_URL=https://your-domain.example.com
|
||||
# MAIL_HOST=smtp.example.com
|
||||
# MAIL_PORT=587
|
||||
# MAIL_USERNAME=your-smtp-user
|
||||
# MAIL_PASSWORD=your-smtp-password
|
||||
# MAIL_SMTP_AUTH=true
|
||||
# MAIL_STARTTLS_ENABLE=true
|
||||
# APP_MAIL_FROM=noreply@your-domain.example.com
|
||||
@@ -152,6 +152,13 @@
|
||||
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
|
||||
<version>3.0.2</version>
|
||||
</dependency>
|
||||
|
||||
<!-- PDF rendering for training data export -->
|
||||
<dependency>
|
||||
<groupId>org.apache.pdfbox</groupId>
|
||||
<artifactId>pdfbox</artifactId>
|
||||
<version>3.0.4</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import java.util.UUID;
|
||||
|
||||
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||
import org.raddatz.familienarchiv.dto.DocumentSearchResult;
|
||||
import org.raddatz.familienarchiv.dto.DocumentUpdateDTO;
|
||||
import org.raddatz.familienarchiv.dto.DocumentVersionSummary;
|
||||
@@ -21,6 +22,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.dto.DocumentSort;
|
||||
import org.raddatz.familienarchiv.model.DocumentStatus;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.DocumentVersion;
|
||||
import org.raddatz.familienarchiv.security.Permission;
|
||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||
@@ -35,9 +37,11 @@ import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.ModelAttribute;
|
||||
import org.springframework.web.bind.annotation.PatchMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RequestPart;
|
||||
@@ -208,6 +212,30 @@ public class DocumentController {
|
||||
return ResponseEntity.ok(DocumentSearchResult.of(results));
|
||||
}
|
||||
|
||||
// --- TRAINING LABELS ---
|
||||
|
||||
public record TrainingLabelRequest(String label, boolean enrolled) {}
|
||||
|
||||
@PatchMapping("/{id}/training-labels")
|
||||
@RequirePermission(Permission.WRITE_ALL)
|
||||
@ApiResponse(responseCode = "204")
|
||||
public ResponseEntity<Void> patchTrainingLabel(
|
||||
@PathVariable UUID id,
|
||||
@RequestBody TrainingLabelRequest req) {
|
||||
TrainingLabel label;
|
||||
try {
|
||||
label = TrainingLabel.valueOf(req.label());
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown training label: " + req.label());
|
||||
}
|
||||
if (req.enrolled()) {
|
||||
documentService.addTrainingLabel(id, label);
|
||||
} else {
|
||||
documentService.removeTrainingLabel(id, label);
|
||||
}
|
||||
return ResponseEntity.noContent().build();
|
||||
}
|
||||
|
||||
// --- VERSIONS ---
|
||||
|
||||
@GetMapping("/{id}/versions")
|
||||
|
||||
@@ -7,16 +7,23 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||
import org.raddatz.familienarchiv.model.AppUser;
|
||||
import org.raddatz.familienarchiv.model.OcrJob;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.security.Permission;
|
||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||
import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||
import org.raddatz.familienarchiv.service.OcrProgressService;
|
||||
import org.raddatz.familienarchiv.service.OcrService;
|
||||
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
||||
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
|
||||
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
||||
import org.raddatz.familienarchiv.service.UserService;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import jakarta.validation.Valid;
|
||||
@@ -32,6 +39,9 @@ public class OcrController {
|
||||
private final OcrBatchService ocrBatchService;
|
||||
private final OcrProgressService ocrProgressService;
|
||||
private final UserService userService;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final SegmentationTrainingExportService segmentationTrainingExportService;
|
||||
private final OcrTrainingService ocrTrainingService;
|
||||
|
||||
@PostMapping("/api/documents/{documentId}/ocr")
|
||||
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||
@@ -41,7 +51,8 @@ public class OcrController {
|
||||
@RequestBody TriggerOcrDTO dto,
|
||||
Authentication authentication) {
|
||||
UUID userId = resolveUserId(authentication);
|
||||
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
|
||||
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId,
|
||||
Boolean.TRUE.equals(dto.getUseExistingAnnotations()));
|
||||
return Map.of("jobId", jobId);
|
||||
}
|
||||
|
||||
@@ -75,6 +86,54 @@ public class OcrController {
|
||||
return ocrService.getDocumentOcrStatus(documentId);
|
||||
}
|
||||
|
||||
    /**
     * Streams the recognition training data as a ZIP download (admin only).
     * Returns 204 No Content when there are no eligible blocks to export.
     */
    @GetMapping("/api/ocr/training-data/export")
    @RequirePermission(Permission.ADMIN)
    public ResponseEntity<StreamingResponseBody> exportTrainingData() {
        // NOTE(review): eligibility is checked here, but the ZIP is produced lazily by
        // the streaming body — if the data changes between check and stream, the ZIP
        // contents may differ from this snapshot. Confirm this is acceptable.
        if (trainingDataExportService.queryEligibleBlocks().isEmpty()) {
            return ResponseEntity.noContent().build();
        }
        StreamingResponseBody body = trainingDataExportService.exportToZip();
        return ResponseEntity.ok()
                .contentType(MediaType.parseMediaType("application/zip"))
                .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"training-data.zip\"")
                .body(body);
    }
||||
|
||||
    /**
     * Streams the segmentation training data as a ZIP download (admin only).
     * Returns 204 No Content when there are no segmentation blocks to export.
     */
    @GetMapping("/api/ocr/segmentation-training-data/export")
    @RequirePermission(Permission.ADMIN)
    public ResponseEntity<StreamingResponseBody> exportSegmentationTrainingData() {
        // Same check-then-stream pattern as exportTrainingData above.
        if (segmentationTrainingExportService.querySegmentationBlocks().isEmpty()) {
            return ResponseEntity.noContent().build();
        }
        StreamingResponseBody body = segmentationTrainingExportService.exportToZip();
        return ResponseEntity.ok()
                .contentType(MediaType.parseMediaType("application/zip"))
                .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"segmentation-data.zip\"")
                .body(body);
    }
|
||||
|
||||
    /**
     * Starts a synchronous recognition-model training run (admin only).
     * Returns 201 Created with the resulting {@link OcrTrainingRun} record.
     * The run is attributed to the authenticated user via {@code triggeredBy}.
     */
    @PostMapping("/api/ocr/train")
    @ResponseStatus(HttpStatus.CREATED)
    @RequirePermission(Permission.ADMIN)
    public OcrTrainingRun triggerTraining(Authentication authentication) {
        UUID userId = resolveUserId(authentication);
        return ocrTrainingService.triggerTraining(userId);
    }
|
||||
|
||||
    /**
     * Starts a synchronous segmentation-model training run (admin only).
     * Returns 201 Created with the resulting {@link OcrTrainingRun} record.
     */
    @PostMapping("/api/ocr/segtrain")
    @ResponseStatus(HttpStatus.CREATED)
    @RequirePermission(Permission.ADMIN)
    public OcrTrainingRun triggerSegTraining(Authentication authentication) {
        UUID userId = resolveUserId(authentication);
        return ocrTrainingService.triggerSegTraining(userId);
    }
|
||||
|
||||
    /**
     * Returns a snapshot of training readiness and recent run history (admin only).
     */
    @GetMapping("/api/ocr/training-info")
    @RequirePermission(Permission.ADMIN)
    public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
        return ocrTrainingService.getTrainingInfo();
    }
|
||||
|
||||
private UUID resolveUserId(Authentication authentication) {
|
||||
if (authentication == null || !authentication.isAuthenticated()) return null;
|
||||
try {
|
||||
|
||||
@@ -10,4 +10,5 @@ import org.raddatz.familienarchiv.model.ScriptType;
|
||||
@AllArgsConstructor
|
||||
public class TriggerOcrDTO {
|
||||
private ScriptType scriptType;
|
||||
private Boolean useExistingAnnotations = false;
|
||||
}
|
||||
|
||||
@@ -49,8 +49,6 @@ public enum ErrorCode {
|
||||
// --- Annotations ---
|
||||
/** The annotation with the given ID does not exist. 404 */
|
||||
ANNOTATION_NOT_FOUND,
|
||||
/** The new annotation overlaps an existing one on the same page. 409 */
|
||||
ANNOTATION_OVERLAP,
|
||||
|
||||
// --- Transcription Blocks ---
|
||||
/** The transcription block with the given ID does not exist. 404 */
|
||||
@@ -75,6 +73,8 @@ public enum ErrorCode {
|
||||
OCR_DOCUMENT_NOT_UPLOADED,
|
||||
/** OCR processing failed for the document. 500 */
|
||||
OCR_PROCESSING_FAILED,
|
||||
/** A training run is already in progress. 409 */
|
||||
TRAINING_ALREADY_RUNNING,
|
||||
|
||||
// --- Generic ---
|
||||
/** Request validation failed (missing or malformed fields). 400 */
|
||||
|
||||
@@ -110,4 +110,11 @@ public class Document {
|
||||
@JoinTable(name = "document_tags", joinColumns = @JoinColumn(name = "document_id"), inverseJoinColumns = @JoinColumn(name = "tag_id"))
|
||||
@Builder.Default
|
||||
private Set<Tag> tags = new HashSet<>();
|
||||
|
||||
@ElementCollection(fetch = FetchType.EAGER)
|
||||
@CollectionTable(name = "document_training_labels", joinColumns = @JoinColumn(name = "document_id"))
|
||||
@Column(name = "label")
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Builder.Default
|
||||
private Set<TrainingLabel> trainingLabels = new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package org.raddatz.familienarchiv.model;

import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;

import java.time.Instant;
import java.util.UUID;

/**
 * Persistent record of one OCR model training run (recognition or segmentation),
 * including the training corpus size, resulting metrics, and lifecycle state.
 */
@Entity
@Table(name = "ocr_training_runs")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrTrainingRun {

    // Primary key, generated as a UUID by the persistence provider.
    @Id
    @GeneratedValue(strategy = GenerationType.UUID)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID id;

    // Lifecycle state of the run (RUNNING / DONE / FAILED), stored as a string.
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private TrainingStatus status;

    // Number of transcription blocks included in the training export.
    @Column(name = "block_count", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private int blockCount;

    // Number of distinct documents those blocks came from.
    @Column(name = "document_count", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private int documentCount;

    // Model this run trains (e.g. "german_kurrent" or "blla" — see OcrTrainingService).
    @Column(name = "model_name", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private String modelName;

    // Character error rate reported by the OCR service; null until the run completes.
    @Column(name = "cer")
    private Double cer;

    // Final training loss reported by the OCR service; null until the run completes.
    @Column(name = "loss")
    private Double loss;

    // Accuracy metric reported by the OCR service; null until the run completes.
    @Column(name = "accuracy")
    private Double accuracy;

    // Number of training epochs actually executed; null until the run completes.
    @Column(name = "epochs")
    private Integer epochs;

    // Exception message captured when the run FAILED; null otherwise.
    @Column(name = "error_message")
    private String errorMessage;

    // UUID of the user who triggered the run; nullable (e.g. unauthenticated trigger).
    @Column(name = "triggered_by")
    private UUID triggeredBy;

    // Set automatically by Hibernate on insert; never updated afterwards.
    @CreationTimestamp
    @Column(name = "created_at", nullable = false, updatable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private Instant createdAt;

    // Set when the run reaches DONE or FAILED; null while RUNNING.
    @Column(name = "completed_at")
    private Instant completedAt;
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package org.raddatz.familienarchiv.model;

/**
 * Training-data pools a document can be enrolled in. Stored per document as an
 * element collection (see Document.trainingLabels) and used by the export
 * queries in TranscriptionBlockRepository to select training material.
 */
public enum TrainingLabel {
    // Document's reviewed blocks feed the Kurrent text-recognition training set.
    KURRENT_RECOGNITION,
    // Document's manual blocks feed the layout-segmentation training set.
    KURRENT_SEGMENTATION
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package org.raddatz.familienarchiv.model;

/**
 * Lifecycle state of an {@link OcrTrainingRun}. Persisted as a string
 * (EnumType.STRING), so renaming constants requires a data migration.
 */
public enum TrainingStatus {
    // Run has been created and training is in progress.
    RUNNING,
    // Training finished successfully; metrics fields are populated.
    DONE,
    // Training aborted with an error; errorMessage is populated.
    FAILED
}
|
||||
@@ -30,8 +30,7 @@ public class TranscriptionBlock {
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private UUID documentId;
|
||||
|
||||
@Column(nullable = false, columnDefinition = "TEXT")
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@Column(columnDefinition = "TEXT")
|
||||
private String text;
|
||||
|
||||
@Column(length = 200)
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package org.raddatz.familienarchiv.repository;

import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.springframework.data.jpa.repository.JpaRepository;

import java.util.List;
import java.util.Optional;
import java.util.UUID;

/** Spring Data repository for OCR training run records. */
public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun, UUID> {

    // Used as a concurrency guard: a present RUNNING row blocks new training runs.
    Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);

    // Most recent five runs, for the training-info history view.
    List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
}
|
||||
@@ -2,6 +2,7 @@ package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.data.jpa.repository.Query;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -13,5 +14,27 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
|
||||
|
||||
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
|
||||
|
||||
Optional<TranscriptionBlock> findByAnnotationId(UUID annotationId);
|
||||
|
||||
void deleteByAnnotationId(UUID annotationId);
|
||||
|
||||
int countByDocumentId(UUID documentId);
|
||||
|
||||
    /**
     * Blocks eligible for Kurrent recognition training: reviewed transcription
     * blocks belonging to documents enrolled under the KURRENT_RECOGNITION label.
     */
    @Query("""
        SELECT b FROM TranscriptionBlock b
        JOIN DocumentAnnotation a ON a.id = b.annotationId
        JOIN Document d ON d.id = b.documentId
        WHERE b.reviewed = true
        AND 'KURRENT_RECOGNITION' MEMBER OF d.trainingLabels
        """)
    List<TranscriptionBlock> findEligibleKurrentBlocks();

    /**
     * Blocks eligible for segmentation training: manually created blocks
     * belonging to documents enrolled under the KURRENT_SEGMENTATION label.
     */
    @Query("""
        SELECT b FROM TranscriptionBlock b
        JOIN DocumentAnnotation a ON a.id = b.annotationId
        JOIN Document d ON d.id = b.documentId
        WHERE b.source = 'MANUAL'
        AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels
        """)
    List<TranscriptionBlock> findSegmentationBlocks();
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@@ -17,6 +18,7 @@ import java.util.UUID;
|
||||
public class AnnotationService {
|
||||
|
||||
private final AnnotationRepository annotationRepository;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
|
||||
public List<DocumentAnnotation> listAnnotations(UUID documentId) {
|
||||
return annotationRepository.findByDocumentId(documentId);
|
||||
@@ -24,15 +26,6 @@ public class AnnotationService {
|
||||
|
||||
@Transactional
|
||||
public DocumentAnnotation createAnnotation(UUID documentId, CreateAnnotationDTO dto, UUID userId, String fileHash) {
|
||||
List<DocumentAnnotation> existing =
|
||||
annotationRepository.findByDocumentIdAndPageNumber(documentId, dto.getPageNumber());
|
||||
|
||||
boolean overlaps = existing.stream().anyMatch(a -> overlaps(a, dto));
|
||||
if (overlaps) {
|
||||
throw DomainException.conflict(
|
||||
ErrorCode.ANNOTATION_OVERLAP, "Annotation overlaps an existing one on this page");
|
||||
}
|
||||
|
||||
DocumentAnnotation annotation = DocumentAnnotation.builder()
|
||||
.documentId(documentId)
|
||||
.pageNumber(dto.getPageNumber())
|
||||
@@ -79,6 +72,7 @@ public class AnnotationService {
|
||||
throw DomainException.forbidden("Only the annotation author can delete it");
|
||||
}
|
||||
|
||||
blockRepository.deleteByAnnotationId(annotationId);
|
||||
annotationRepository.delete(annotation);
|
||||
}
|
||||
|
||||
@@ -90,14 +84,4 @@ public class AnnotationService {
|
||||
});
|
||||
}
|
||||
|
||||
// ─── private helpers ──────────────────────────────────────────────────────
|
||||
|
||||
private boolean overlaps(DocumentAnnotation existing, CreateAnnotationDTO dto) {
|
||||
double ex2 = existing.getX() + existing.getWidth();
|
||||
double ey2 = existing.getY() + existing.getHeight();
|
||||
double dx2 = dto.getX() + dto.getWidth();
|
||||
double dy2 = dto.getY() + dto.getHeight();
|
||||
return existing.getX() < dx2 && ex2 > dto.getX()
|
||||
&& existing.getY() < dy2 && ey2 > dto.getY();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.dto.DocumentSort;
|
||||
import org.raddatz.familienarchiv.model.DocumentStatus;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.Person;
|
||||
import org.raddatz.familienarchiv.model.Tag;
|
||||
import org.raddatz.familienarchiv.repository.DocumentRepository;
|
||||
@@ -385,6 +386,20 @@ public class DocumentService {
|
||||
documentRepository.save(doc);
|
||||
}
|
||||
|
||||
    /**
     * Enrolls the document in the given training-label pool.
     * Idempotent: trainingLabels is a Set, so re-adding is a no-op.
     *
     * @throws org.raddatz.familienarchiv.exception.DomainException if the document does not exist
     */
    @Transactional
    public void addTrainingLabel(UUID documentId, TrainingLabel label) {
        Document doc = getDocumentById(documentId);
        doc.getTrainingLabels().add(label);
        documentRepository.save(doc);
    }

    /**
     * Removes the given training label from the document.
     * No-op if the label was not present.
     *
     * @throws org.raddatz.familienarchiv.exception.DomainException if the document does not exist
     */
    @Transactional
    public void removeTrainingLabel(UUID documentId, TrainingLabel label) {
        Document doc = getDocumentById(documentId);
        doc.getTrainingLabels().remove(label);
        documentRepository.save(doc);
    }
|
||||
|
||||
public Document getDocumentById(UUID id) {
|
||||
return documentRepository.findById(id)
|
||||
.orElseThrow(() -> DomainException.notFound(ErrorCode.DOCUMENT_NOT_FOUND, "Document not found: " + id));
|
||||
|
||||
@@ -114,7 +114,8 @@ public class FileService {
|
||||
|
||||
/**
|
||||
* Generates a presigned URL for downloading an object from S3/MinIO.
|
||||
* Valid for 15 minutes — enough for OCR processing on CPU.
|
||||
* Valid for 1 hour — covers multi-page documents on CPU-only OCR hardware
|
||||
* (a 100-page document at 10 s/page takes ~17 min; 1 h gives ample headroom).
|
||||
*/
|
||||
public String generatePresignedUrl(String s3Key) {
|
||||
GetObjectRequest getObjectRequest = GetObjectRequest.builder()
|
||||
@@ -123,7 +124,7 @@ public class FileService {
|
||||
.build();
|
||||
|
||||
GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
|
||||
.signatureDuration(Duration.ofMinutes(15))
|
||||
.signatureDuration(Duration.ofHours(1))
|
||||
.getObjectRequest(getObjectRequest)
|
||||
.build();
|
||||
|
||||
|
||||
@@ -32,6 +32,11 @@ public class OcrAsyncRunner {
|
||||
|
||||
@Async
|
||||
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
|
||||
runSingleDocument(jobId, documentId, userId, false);
|
||||
}
|
||||
|
||||
@Async
|
||||
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId, boolean useExistingAnnotations) {
|
||||
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
|
||||
if (job == null) return;
|
||||
|
||||
@@ -49,7 +54,18 @@ public class OcrAsyncRunner {
|
||||
|
||||
try {
|
||||
updateProgress(job, "LOADING");
|
||||
clearExistingBlocks(documentId);
|
||||
|
||||
List<OcrClient.OcrRegion> regions = null;
|
||||
if (useExistingAnnotations) {
|
||||
regions = annotationService.listAnnotations(documentId).stream()
|
||||
.map(a -> new OcrClient.OcrRegion(
|
||||
a.getId().toString(), a.getPageNumber(),
|
||||
a.getX(), a.getY(), a.getWidth(), a.getHeight()))
|
||||
.toList();
|
||||
} else {
|
||||
clearExistingBlocks(documentId);
|
||||
}
|
||||
|
||||
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
||||
|
||||
AtomicInteger blockCounter = new AtomicInteger(0);
|
||||
@@ -57,7 +73,7 @@ public class OcrAsyncRunner {
|
||||
AtomicInteger skippedPages = new AtomicInteger(0);
|
||||
AtomicInteger totalPages = new AtomicInteger(0);
|
||||
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
|
||||
switch (event) {
|
||||
case OcrStreamEvent.Start start -> {
|
||||
totalPages.set(start.totalPages());
|
||||
@@ -204,14 +220,21 @@ public class OcrAsyncRunner {
|
||||
|
||||
void createSingleBlock(UUID documentId, OcrBlockResult block,
|
||||
UUID userId, String fileHash, int sortOrder) {
|
||||
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
|
||||
block.pageNumber(), block.x(), block.y(),
|
||||
block.width(), block.height(), OCR_ANNOTATION_COLOR);
|
||||
if (block.annotationId() != null) {
|
||||
// Guided mode — annotation already exists; upsert the text block only
|
||||
transcriptionService.upsertGuidedBlock(
|
||||
documentId, UUID.fromString(block.annotationId()), block.text(), userId);
|
||||
} else {
|
||||
// Normal mode — create a new annotation and a new OCR block
|
||||
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
|
||||
block.pageNumber(), block.x(), block.y(),
|
||||
block.width(), block.height(), OCR_ANNOTATION_COLOR);
|
||||
|
||||
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
|
||||
documentId, annotationDTO, userId, fileHash, block.polygon());
|
||||
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
|
||||
documentId, annotationDTO, userId, fileHash, block.polygon());
|
||||
|
||||
transcriptionService.createOcrBlock(documentId, annotation.getId(),
|
||||
block.text(), sortOrder, userId);
|
||||
transcriptionService.createOcrBlock(documentId, annotation.getId(),
|
||||
block.text(), sortOrder, userId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,5 +12,6 @@ public record OcrBlockResult(
|
||||
double width,
|
||||
double height,
|
||||
List<List<Double>> polygon,
|
||||
String text
|
||||
String text,
|
||||
String annotationId // null in normal mode; set in guided mode to link back to existing annotation
|
||||
) {}
|
||||
|
||||
@@ -10,12 +10,42 @@ import java.util.function.Consumer;
|
||||
public interface OcrClient {
|
||||
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
|
||||
|
||||
/**
|
||||
* A pre-drawn annotation region to use as guidance for OCR.
|
||||
* When regions are provided, the OCR engine crops to each region and
|
||||
* runs recognition only within that area, skipping full-page layout detection.
|
||||
*/
|
||||
record OcrRegion(String annotationId, int pageNumber,
|
||||
double x, double y, double width, double height) {}
|
||||
|
||||
/**
|
||||
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
|
||||
*
|
||||
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService
|
||||
* @return training result metrics (loss, accuracy, epochs)
|
||||
*/
|
||||
TrainingResult trainModel(byte[] trainingDataZip);
|
||||
|
||||
record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||
|
||||
/**
|
||||
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.
|
||||
*
|
||||
* @param trainingDataZip raw ZIP bytes produced by SegmentationTrainingExportService
|
||||
* @return training result metrics
|
||||
*/
|
||||
TrainingResult segtrainModel(byte[] trainingDataZip);
|
||||
|
||||
/**
|
||||
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
||||
* this method. The default exists only for backward compatibility during migration
|
||||
* — it calls extractBlocks() and synthesizes events from the collected result.
|
||||
*
|
||||
* @param regions optional list of pre-drawn annotation regions; when non-null,
|
||||
* the OCR service runs in guided mode (crop + recognize per region)
|
||||
*/
|
||||
default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
|
||||
default void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
|
||||
|
||||
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
|
||||
|
||||
@@ -52,6 +52,11 @@ public class OcrService {
|
||||
}
|
||||
|
||||
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
|
||||
return startOcr(documentId, scriptTypeOverride, userId, false);
|
||||
}
|
||||
|
||||
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId,
|
||||
boolean useExistingAnnotations) {
|
||||
Document doc = documentService.getDocumentById(documentId);
|
||||
|
||||
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
|
||||
@@ -66,6 +71,9 @@ public class OcrService {
|
||||
|
||||
if (scriptTypeOverride != null) {
|
||||
documentService.updateScriptType(documentId, scriptTypeOverride);
|
||||
if (scriptTypeOverride == ScriptType.HANDWRITING_KURRENT) {
|
||||
documentService.addTrainingLabel(documentId, TrainingLabel.KURRENT_RECOGNITION);
|
||||
}
|
||||
}
|
||||
|
||||
OcrJob job = OcrJob.builder()
|
||||
@@ -82,7 +90,7 @@ public class OcrService {
|
||||
.build();
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
|
||||
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId, useExistingAnnotations);
|
||||
return job.getId();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.slf4j.MDC;
|
||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class OcrTrainingService {
|
||||
|
||||
private final OcrTrainingRunRepository trainingRunRepository;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final SegmentationTrainingExportService segmentationTrainingExportService;
|
||||
private final OcrClient ocrClient;
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final TransactionTemplate txTemplate;
|
||||
|
||||
    /**
     * Snapshot of training readiness and history, returned by the
     * /api/ocr/training-info endpoint.
     */
    public record TrainingInfoResponse(
            // Blocks currently eligible for recognition training.
            int availableBlocks,
            // Total OCR-produced blocks (eligible or not).
            int totalOcrBlocks,
            // Distinct documents contributing eligible blocks.
            int availableDocuments,
            // Blocks currently eligible for segmentation training.
            int availableSegBlocks,
            // Whether the Python OCR microservice responded to a health check.
            boolean ocrServiceAvailable,
            // The most recent training run, or null if none exist.
            OcrTrainingRun lastRun,
            // Recent run history (see OcrTrainingRunRepository.findTop5ByOrderByCreatedAtDesc).
            List<OcrTrainingRun> runs
    ) {}
|
||||
|
||||
    /**
     * Runs one synchronous recognition-model training cycle:
     * (1) in a short transaction, verify no run is RUNNING and persist a new RUNNING row;
     * (2) outside any transaction, export the training ZIP and POST it to the OCR service
     *     (can take minutes); (3) in a second short transaction, persist DONE with metrics
     *     or FAILED with the error message. Failures are recorded and returned, not rethrown.
     *
     * @param triggeredBy UUID of the initiating user; may be null
     * @return the persisted run in its terminal state (DONE or FAILED)
     * @throws org.raddatz.familienarchiv.exception.DomainException conflict if a run is
     *         already RUNNING, or bad request if fewer than 5 eligible blocks exist
     */
    // Not safe for horizontal scaling: training reloads the Kraken model in-process on the
    // Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
    // index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
    // from diverging on model state. Deploy as a single instance only. See ADR-001.
    public OcrTrainingRun triggerTraining(UUID triggeredBy) {
        // Short transaction: guard check + create RUNNING row, then commit immediately.
        // The DB connection is released before the OCR HTTP call, which can take several minutes.
        OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
            if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
                throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
                        "A training run is already in progress");
            }

            // Minimum corpus size guard: refuse to train on fewer than 5 blocks.
            var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
            if (eligibleBlocks.size() < 5) {
                throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
                        "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
            }

            long documentCount = eligibleBlocks.stream()
                    .map(b -> b.getDocumentId())
                    .distinct()
                    .count();

            OcrTrainingRun newRun = OcrTrainingRun.builder()
                    .status(TrainingStatus.RUNNING)
                    .blockCount(eligibleBlocks.size())
                    .documentCount((int) documentCount)
                    .modelName("german_kurrent")
                    .triggeredBy(triggeredBy)
                    .build();
            return trainingRunRepository.save(newRun);
        }));

        // Tag all log lines from this point with the run id via MDC.
        String runId = run.getId().toString();
        MDC.put("trainingRunId", runId);
        log.info("Started training run {} with {} blocks from {} documents",
                runId, run.getBlockCount(), run.getDocumentCount());

        try {
            // Materialize the export ZIP in memory before the HTTP call.
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            trainingDataExportService.exportToZip().writeTo(baos);
            byte[] zipBytes = baos.toByteArray();

            log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
            // Blocking call to the Python OCR service; no DB connection is held here.
            OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);

            // Second short transaction: record success and metrics.
            return Objects.requireNonNull(txTemplate.execute(status -> {
                run.setStatus(TrainingStatus.DONE);
                run.setCompletedAt(Instant.now());
                run.setCer(result.cer());
                run.setLoss(result.loss());
                run.setAccuracy(result.accuracy());
                run.setEpochs(result.epochs());
                OcrTrainingRun updated = trainingRunRepository.save(run);
                log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
                return updated;
            }));
        } catch (Exception e) {
            // Record the failure so the RUNNING guard releases; return FAILED instead of rethrowing.
            return Objects.requireNonNull(txTemplate.execute(status -> {
                run.setStatus(TrainingStatus.FAILED);
                run.setErrorMessage(e.getMessage());
                run.setCompletedAt(Instant.now());
                OcrTrainingRun failed = trainingRunRepository.save(run);
                log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
                return failed;
            }));
        } finally {
            // Always clear the MDC key so it cannot leak onto unrelated requests on this thread.
            MDC.remove("trainingRunId");
        }
    }
|
||||
|
||||
/**
 * Starts a segmentation (layout) training run against the OCR microservice.
 *
 * Transaction handling mirrors triggerTraining: each DB write (run creation,
 * terminal-state update) runs in its own short TransactionTemplate block so
 * that no transaction stays open during the long-running HTTP training call.
 *
 * @param triggeredBy id of the user who started the run (persisted on the record)
 * @return the saved run in its terminal state (DONE or FAILED)
 */
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
    // Same pattern as triggerTraining: narrow transactions around DB writes only.
    OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
        // Application-level guard; the partial unique index on status='RUNNING'
        // is the authoritative DB-level guard against concurrent runs.
        if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
            throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
                    "A training run is already in progress");
        }

        var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
        if (segBlocks.size() < 5) {
            throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
                    "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
        }

        long documentCount = segBlocks.stream()
                .map(b -> b.getDocumentId())
                .distinct()
                .count();

        OcrTrainingRun newRun = OcrTrainingRun.builder()
                .status(TrainingStatus.RUNNING)
                .blockCount(segBlocks.size())
                .documentCount((int) documentCount)
                .modelName("blla") // presumably Kraken's baseline layout-analysis model — TODO confirm
                .triggeredBy(triggeredBy)
                .build();
        return trainingRunRepository.save(newRun);
    }));

    String runId = run.getId().toString();
    // Tag all log lines of this (potentially long) run with its id.
    MDC.put("trainingRunId", runId);
    log.info("Started segmentation training run {} with {} segments from {} documents",
            runId, run.getBlockCount(), run.getDocumentCount());

    try {
        // Materialize the export ZIP fully in memory before the HTTP call.
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        segmentationTrainingExportService.exportToZip().writeTo(baos);
        byte[] zipBytes = baos.toByteArray();

        log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
        // Blocking call — may take minutes; no DB transaction is open here.
        OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);

        // Persist the terminal DONE state in a fresh, short transaction.
        return Objects.requireNonNull(txTemplate.execute(status -> {
            run.setStatus(TrainingStatus.DONE);
            run.setCompletedAt(Instant.now());
            // NOTE(review): unlike triggerTraining, cer is not recorded here —
            // presumably /segtrain does not report it; confirm against the service.
            run.setLoss(result.loss());
            run.setAccuracy(result.accuracy());
            run.setEpochs(result.epochs());
            OcrTrainingRun updated = trainingRunRepository.save(run);
            log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
            return updated;
        }));
    } catch (Exception e) {
        // Any failure (export, HTTP, mapping) marks the run FAILED with the message.
        return Objects.requireNonNull(txTemplate.execute(status -> {
            run.setStatus(TrainingStatus.FAILED);
            run.setErrorMessage(e.getMessage());
            run.setCompletedAt(Instant.now());
            OcrTrainingRun failed = trainingRunRepository.save(run);
            log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
            return failed;
        }));
    } finally {
        MDC.remove("trainingRunId");
    }
}
|
||||
|
||||
public TrainingInfoResponse getTrainingInfo() {
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
int availableDocuments = (int) eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
int totalOcrBlocks = (int) blockRepository.count();
|
||||
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
|
||||
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
|
||||
|
||||
return new TrainingInfoResponse(
|
||||
eligibleBlocks.size(),
|
||||
totalOcrBlocks,
|
||||
availableDocuments,
|
||||
availableSegBlocks,
|
||||
ocrHealthClient.isHealthy(),
|
||||
lastRun,
|
||||
recentRuns
|
||||
);
|
||||
}
|
||||
|
||||
@EventListener(ApplicationReadyEvent.class)
|
||||
@Transactional
|
||||
public void recoverOrphanedRuns() {
|
||||
var cutoff = Instant.now().minusSeconds(3600);
|
||||
trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).ifPresent(run -> {
|
||||
if (run.getCreatedAt().isBefore(cutoff)) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage("Abgebrochen: Dienst wurde neugestartet");
|
||||
run.setCompletedAt(Instant.now());
|
||||
trainingRunRepository.save(run);
|
||||
log.warn("Recovered orphaned training run {} (marked FAILED on startup)", run.getId());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
|
||||
return Map.of(
|
||||
"availableBlocks", info.availableBlocks(),
|
||||
"totalOcrBlocks", info.totalOcrBlocks(),
|
||||
"availableDocuments", info.availableDocuments(),
|
||||
"availableSegBlocks", info.availableSegBlocks(),
|
||||
"ocrServiceAvailable", info.ocrServiceAvailable(),
|
||||
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
|
||||
"runs", info.runs()
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -9,9 +9,14 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.io.ByteArrayResource;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.client.JdkClientHttpRequestFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
@@ -36,11 +41,16 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
|
||||
|
||||
private final RestClient restClient;
|
||||
private final RestClient trainingRestClient;
|
||||
private final HttpClient streamingHttpClient;
|
||||
private final String baseUrl;
|
||||
private final String trainingToken;
|
||||
|
||||
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) {
|
||||
public RestClientOcrClient(
|
||||
@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl,
|
||||
@Value("${app.ocr.training-token:}") String trainingToken) {
|
||||
this.baseUrl = baseUrl;
|
||||
this.trainingToken = trainingToken;
|
||||
|
||||
HttpClient httpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
@@ -54,6 +64,17 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.requestFactory(requestFactory)
|
||||
.build();
|
||||
|
||||
HttpClient trainingHttpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.connectTimeout(Duration.ofSeconds(10))
|
||||
.build();
|
||||
JdkClientHttpRequestFactory trainingRequestFactory = new JdkClientHttpRequestFactory(trainingHttpClient);
|
||||
trainingRequestFactory.setReadTimeout(Duration.ofMinutes(10));
|
||||
this.trainingRestClient = RestClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.requestFactory(trainingRequestFactory)
|
||||
.build();
|
||||
|
||||
this.streamingHttpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.connectTimeout(Duration.ofSeconds(10))
|
||||
@@ -81,6 +102,64 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
|
||||
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
|
||||
@Override
|
||||
public String getFilename() { return "training-data.zip"; }
|
||||
};
|
||||
|
||||
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||
HttpHeaders partHeaders = new HttpHeaders();
|
||||
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||
|
||||
var spec = trainingRestClient.post()
|
||||
.uri("/train")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||
|
||||
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||
spec = spec.header("X-Training-Token", trainingToken);
|
||||
}
|
||||
|
||||
TrainingResultJson result = spec
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
public OcrClient.TrainingResult segtrainModel(byte[] trainingDataZip) {
|
||||
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
|
||||
@Override
|
||||
public String getFilename() { return "segmentation-data.zip"; }
|
||||
};
|
||||
|
||||
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||
HttpHeaders partHeaders = new HttpHeaders();
|
||||
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||
|
||||
var spec = trainingRestClient.post()
|
||||
.uri("/segtrain")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||
|
||||
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||
spec = spec.header("X-Training-Token", trainingToken);
|
||||
}
|
||||
|
||||
TrainingResultJson result = spec
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isHealthy() {
|
||||
try {
|
||||
@@ -96,13 +175,18 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
|
||||
public void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||
String body;
|
||||
try {
|
||||
body = NDJSON_MAPPER.writeValueAsString(Map.of(
|
||||
"pdfUrl", pdfUrl,
|
||||
"scriptType", scriptType.name(),
|
||||
"language", "de"));
|
||||
var requestMap = new java.util.LinkedHashMap<String, Object>();
|
||||
requestMap.put("pdfUrl", pdfUrl);
|
||||
requestMap.put("scriptType", scriptType.name());
|
||||
requestMap.put("language", "de");
|
||||
if (regions != null && !regions.isEmpty()) {
|
||||
requestMap.put("regions", regions);
|
||||
}
|
||||
body = NDJSON_MAPPER.writeValueAsString(requestMap);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to serialize OCR request", e);
|
||||
}
|
||||
@@ -120,7 +204,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
|
||||
if (response.statusCode() == 404) {
|
||||
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
|
||||
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
|
||||
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -171,6 +255,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
}
|
||||
}
|
||||
|
||||
record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||
|
||||
record OcrBlockJson(
|
||||
@JsonProperty("pageNumber") int pageNumber,
|
||||
double x,
|
||||
@@ -178,10 +264,11 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
double width,
|
||||
double height,
|
||||
List<List<Double>> polygon,
|
||||
String text
|
||||
String text,
|
||||
String annotationId
|
||||
) {
|
||||
OcrBlockResult toResult() {
|
||||
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text);
|
||||
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text, annotationId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.pdfbox.Loader;
|
||||
import org.apache.pdfbox.pdmodel.PDDocument;
|
||||
import org.apache.pdfbox.rendering.PDFRenderer;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.DocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
|
||||
@Service
@RequiredArgsConstructor
@Slf4j
public class SegmentationTrainingExportService {

    // Read-only collaborators: the export is a pure query + render pipeline.
    private final TranscriptionBlockRepository blockRepository;
    private final AnnotationRepository annotationRepository;
    private final DocumentRepository documentRepository;
    private final FileService fileService;

    /** Blocks eligible for segmentation training, as selected by the repository query. */
    public List<TranscriptionBlock> querySegmentationBlocks() {
        return blockRepository.findSegmentationBlocks();
    }

    /**
     * Builds a streaming ZIP containing, per document page, a 300-DPI PNG render
     * and a PAGE-XML sidecar listing one TextRegion polygon per annotation.
     *
     * All database reads happen eagerly before the StreamingResponseBody lambda
     * is returned, so no transaction/session is open while the response streams.
     *
     * @return a body that writes the ZIP; writes nothing when no blocks exist
     *         (caller is expected to check the block count separately)
     */
    public StreamingResponseBody exportToZip() {
        List<TranscriptionBlock> blocks = querySegmentationBlocks();
        if (blocks.isEmpty()) {
            return out -> {};
        }

        // Group by documentId so we download each PDF only once
        Map<UUID, List<TranscriptionBlock>> byDoc = new LinkedHashMap<>();
        for (TranscriptionBlock b : blocks) {
            byDoc.computeIfAbsent(b.getDocumentId(), k -> new ArrayList<>()).add(b);
        }

        // Pre-fetch annotations keyed by id
        Map<UUID, DocumentAnnotation> annotations = new HashMap<>();
        for (TranscriptionBlock b : blocks) {
            annotationRepository.findById(b.getAnnotationId())
                    .ifPresent(a -> annotations.put(a.getId(), a));
        }

        // Pre-fetch documents keyed by id
        Map<UUID, Document> documents = new HashMap<>();
        for (UUID docId : byDoc.keySet()) {
            documentRepository.findById(docId).ifPresent(d -> documents.put(d.getId(), d));
        }

        return out -> {
            try (ZipOutputStream zip = new ZipOutputStream(out)) {
                for (Map.Entry<UUID, List<TranscriptionBlock>> entry : byDoc.entrySet()) {
                    UUID docId = entry.getKey();
                    Document doc = documents.get(docId);
                    if (doc == null || doc.getFilePath() == null) {
                        log.warn("Skipping document {} — no file path", docId);
                        continue;
                    }

                    byte[] pdfBytes;
                    try {
                        pdfBytes = fileService.downloadFileBytes(doc.getFilePath());
                    } catch (FileService.StorageFileNotFoundException | IOException e) {
                        // Best effort: a missing file skips one document, not the whole export.
                        log.warn("Skipping document {} — S3 download failed: {}", docId, e.getMessage());
                        continue;
                    }

                    // Group blocks by page number for this document
                    Map<Integer, List<TranscriptionBlock>> byPage = new LinkedHashMap<>();
                    for (TranscriptionBlock b : entry.getValue()) {
                        DocumentAnnotation ann = annotations.get(b.getAnnotationId());
                        if (ann != null) {
                            byPage.computeIfAbsent(ann.getPageNumber(), k -> new ArrayList<>()).add(b);
                        }
                    }

                    try (PDDocument pdf = Loader.loadPDF(pdfBytes)) {
                        PDFRenderer renderer = new PDFRenderer(pdf);
                        for (Map.Entry<Integer, List<TranscriptionBlock>> pageEntry : byPage.entrySet()) {
                            int pageNumber = pageEntry.getKey();
                            int pageIdx = pageNumber - 1; // pageNumber is 1-based
                            if (pageIdx < 0 || pageIdx >= pdf.getNumberOfPages()) continue;

                            BufferedImage pageImage = renderer.renderImageWithDPI(pageIdx, 300);
                            String basename = "page-" + docId + "-" + pageNumber;

                            // Collect annotations for this page
                            List<DocumentAnnotation> pageAnnotations = new ArrayList<>();
                            for (TranscriptionBlock b : pageEntry.getValue()) {
                                DocumentAnnotation ann = annotations.get(b.getAnnotationId());
                                if (ann != null) pageAnnotations.add(ann);
                            }

                            writePngEntry(zip, basename, pageImage);
                            writePageXmlEntry(zip, basename, pageImage, pageAnnotations);
                        }
                    } catch (Exception e) {
                        // One unreadable PDF must not abort the whole export.
                        log.warn("Skipping document {} — rendering failed: {}", docId, e.getMessage());
                    }
                }
            }
        };
    }

    // Encodes one page render directly into the ZIP stream as a PNG entry.
    private void writePngEntry(ZipOutputStream zip, String basename, BufferedImage image) throws IOException {
        zip.putNextEntry(new ZipEntry(basename + ".png"));
        ImageIO.write(image, "PNG", zip);
        zip.closeEntry();
    }

    // Writes the PAGE-XML sidecar for one page. Region ids are annotation UUIDs
    // and coords are integer pixel pairs, so no XML escaping is needed here.
    private void writePageXmlEntry(ZipOutputStream zip, String basename,
                                   BufferedImage pageImage,
                                   List<DocumentAnnotation> annotations) throws IOException {
        int imgW = pageImage.getWidth();
        int imgH = pageImage.getHeight();

        StringBuilder regions = new StringBuilder();
        for (DocumentAnnotation ann : annotations) {
            String coords = buildPolygonCoords(ann, imgW, imgH);
            String regionId = ann.getId().toString();
            regions.append(" <TextRegion id=\"").append(regionId).append("\">\n");
            regions.append(" <Coords points=\"").append(coords).append("\"/>\n");
            regions.append(" </TextRegion>\n");
        }

        String xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
                + "<PcGts xmlns=\"http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15\">\n"
                + " <Page imageFilename=\"" + basename + ".png\""
                + " imageWidth=\"" + imgW + "\""
                + " imageHeight=\"" + imgH + "\">\n"
                + regions
                + " </Page>\n"
                + "</PcGts>\n";

        zip.putNextEntry(new ZipEntry(basename + ".xml"));
        zip.write(xml.getBytes(StandardCharsets.UTF_8));
        zip.closeEntry();
    }

    /**
     * Converts an annotation's normalized (0..1) geometry to PAGE-XML
     * "x,y x,y ..." pixel coordinates. Prefers the stored polygon; falls back
     * to the rectangular bounding box when no polygon was drawn.
     * Package-private for testability.
     */
    String buildPolygonCoords(DocumentAnnotation ann, int imgW, int imgH) {
        List<List<Double>> polygon = ann.getPolygon();
        if (polygon != null && !polygon.isEmpty()) {
            // Use explicit polygon — de-normalize to pixel coordinates
            StringBuilder sb = new StringBuilder();
            for (List<Double> pt : polygon) {
                if (sb.length() > 0) sb.append(' ');
                int px = (int) (pt.get(0) * imgW);
                int py = (int) (pt.get(1) * imgH);
                sb.append(px).append(',').append(py);
            }
            return sb.toString();
        }
        // Fall back to bounding box from x/y/width/height
        int x = (int) (ann.getX() * imgW);
        int y = (int) (ann.getY() * imgH);
        int w = (int) (ann.getWidth() * imgW);
        int h = (int) (ann.getHeight() * imgH);
        return x + "," + y + " " + (x + w) + "," + y + " " + (x + w) + "," + (y + h) + " " + x + "," + (y + h);
    }
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.pdfbox.Loader;
|
||||
import org.apache.pdfbox.pdmodel.PDDocument;
|
||||
import org.apache.pdfbox.rendering.PDFRenderer;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.DocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class TrainingDataExportService {
|
||||
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final AnnotationRepository annotationRepository;
|
||||
private final DocumentRepository documentRepository;
|
||||
private final FileService fileService;
|
||||
|
||||
public List<TranscriptionBlock> queryEligibleBlocks() {
|
||||
return blockRepository.findEligibleKurrentBlocks();
|
||||
}
|
||||
|
||||
public StreamingResponseBody exportToZip() {
|
||||
// Collect all data before entering the lambda — no open DB txn during streaming
|
||||
List<TranscriptionBlock> blocks = queryEligibleBlocks();
|
||||
if (blocks.isEmpty()) {
|
||||
return out -> {}; // caller checks isEmpty() for 204 response
|
||||
}
|
||||
|
||||
// Group blocks by documentId so we only download each PDF once
|
||||
Map<UUID, List<TranscriptionBlock>> byDoc = new LinkedHashMap<>();
|
||||
for (TranscriptionBlock b : blocks) {
|
||||
byDoc.computeIfAbsent(b.getDocumentId(), k -> new ArrayList<>()).add(b);
|
||||
}
|
||||
|
||||
// Pre-fetch annotations keyed by id
|
||||
Map<UUID, DocumentAnnotation> annotations = new HashMap<>();
|
||||
for (TranscriptionBlock b : blocks) {
|
||||
annotationRepository.findById(b.getAnnotationId())
|
||||
.ifPresent(a -> annotations.put(a.getId(), a));
|
||||
}
|
||||
|
||||
// Pre-fetch documents keyed by id
|
||||
Map<UUID, Document> documents = new HashMap<>();
|
||||
for (UUID docId : byDoc.keySet()) {
|
||||
documentRepository.findById(docId).ifPresent(d -> documents.put(d.getId(), d));
|
||||
}
|
||||
|
||||
return out -> {
|
||||
try (ZipOutputStream zip = new ZipOutputStream(out)) {
|
||||
for (Map.Entry<UUID, List<TranscriptionBlock>> entry : byDoc.entrySet()) {
|
||||
UUID docId = entry.getKey();
|
||||
Document doc = documents.get(docId);
|
||||
if (doc == null || doc.getFilePath() == null) {
|
||||
log.warn("Skipping document {} — no file path", docId);
|
||||
continue;
|
||||
}
|
||||
|
||||
byte[] pdfBytes;
|
||||
try {
|
||||
pdfBytes = fileService.downloadFileBytes(doc.getFilePath());
|
||||
} catch (FileService.StorageFileNotFoundException | IOException e) {
|
||||
log.warn("Skipping document {} — S3 download failed: {}", docId, e.getMessage());
|
||||
continue;
|
||||
}
|
||||
|
||||
try (PDDocument pdf = Loader.loadPDF(pdfBytes)) {
|
||||
PDFRenderer renderer = new PDFRenderer(pdf);
|
||||
for (TranscriptionBlock block : entry.getValue()) {
|
||||
DocumentAnnotation ann = annotations.get(block.getAnnotationId());
|
||||
if (ann == null) continue;
|
||||
|
||||
int pageIdx = ann.getPageNumber() - 1; // pageNumber is 1-based
|
||||
if (pageIdx < 0 || pageIdx >= pdf.getNumberOfPages()) continue;
|
||||
|
||||
BufferedImage pageImage = renderPageImage(renderer, pageIdx);
|
||||
BufferedImage cropped = cropBlockImage(pageImage, ann);
|
||||
|
||||
writeTrainingPair(zip, block.getId(), cropped, block.getText());
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("Skipping document {} — rendering failed: {}", docId, e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
BufferedImage renderPageImage(PDFRenderer renderer, int pageIdx) throws IOException {
|
||||
return renderer.renderImageWithDPI(pageIdx, 300);
|
||||
}
|
||||
|
||||
BufferedImage cropBlockImage(BufferedImage page, DocumentAnnotation ann) {
|
||||
int imgW = page.getWidth();
|
||||
int imgH = page.getHeight();
|
||||
|
||||
int x = (int) (ann.getX() * imgW);
|
||||
int y = (int) (ann.getY() * imgH);
|
||||
int w = (int) (ann.getWidth() * imgW);
|
||||
int h = (int) (ann.getHeight() * imgH);
|
||||
|
||||
// Clamp to image bounds
|
||||
x = Math.max(0, Math.min(x, imgW - 1));
|
||||
y = Math.max(0, Math.min(y, imgH - 1));
|
||||
w = Math.max(1, Math.min(w, imgW - x));
|
||||
h = Math.max(1, Math.min(h, imgH - y));
|
||||
|
||||
return page.getSubimage(x, y, w, h);
|
||||
}
|
||||
|
||||
void writeTrainingPair(ZipOutputStream zip, UUID blockId, BufferedImage image, String text) throws IOException {
|
||||
String base = blockId.toString();
|
||||
int w = image.getWidth();
|
||||
int h = image.getHeight();
|
||||
// Baseline at 75 % height — typical text baseline position in a cropped line image
|
||||
int baselineY = (h * 3) / 4;
|
||||
|
||||
// Write PNG
|
||||
zip.putNextEntry(new ZipEntry(base + ".png"));
|
||||
ImageIO.write(image, "PNG", zip);
|
||||
zip.closeEntry();
|
||||
|
||||
// Write PAGE XML (Kraken 7+ dropped the legacy "path" format)
|
||||
String safeText = escapeXml(text != null ? text : "");
|
||||
String xml = String.format(
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" +
|
||||
"<PcGts xmlns=\"http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15\">\n" +
|
||||
" <Metadata><Creator>familienarchiv</Creator></Metadata>\n" +
|
||||
" <Page imageFilename=\"%s.png\" imageWidth=\"%d\" imageHeight=\"%d\">\n" +
|
||||
" <TextRegion id=\"r0\" type=\"paragraph\">\n" +
|
||||
" <Coords points=\"0,0 %d,0 %d,%d 0,%d\"/>\n" +
|
||||
" <TextLine id=\"l0\">\n" +
|
||||
" <Coords points=\"0,0 %d,0 %d,%d 0,%d\"/>\n" +
|
||||
" <Baseline points=\"0,%d %d,%d\"/>\n" +
|
||||
" <TextEquiv><Unicode>%s</Unicode></TextEquiv>\n" +
|
||||
" </TextLine>\n" +
|
||||
" </TextRegion>\n" +
|
||||
" </Page>\n" +
|
||||
"</PcGts>\n",
|
||||
base, w, h,
|
||||
w - 1, w - 1, h - 1, h - 1,
|
||||
w - 1, w - 1, h - 1, h - 1,
|
||||
baselineY, w - 1, baselineY,
|
||||
safeText);
|
||||
|
||||
zip.putNextEntry(new ZipEntry(base + ".xml"));
|
||||
zip.write(xml.getBytes(StandardCharsets.UTF_8));
|
||||
zip.closeEntry();
|
||||
}
|
||||
|
||||
private static String escapeXml(String text) {
|
||||
return text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">");
|
||||
}
|
||||
}
|
||||
@@ -94,6 +94,27 @@ public class TranscriptionService {
|
||||
return saved;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upsert an OCR transcription block for a pre-existing annotation (guided OCR mode).
|
||||
* If the annotation already has a MANUAL block, it is left unchanged.
|
||||
* If it has an OCR block, the text is updated in-place.
|
||||
* If it has no block yet, a new OCR block is created.
|
||||
*/
|
||||
@Transactional
|
||||
public TranscriptionBlock upsertGuidedBlock(UUID documentId, UUID annotationId,
|
||||
String text, UUID userId) {
|
||||
return blockRepository.findByAnnotationId(annotationId).map(existing -> {
|
||||
if (existing.getSource() == BlockSource.MANUAL && !existing.getText().isBlank()) {
|
||||
return existing; // never overwrite non-empty manual transcription
|
||||
}
|
||||
existing.setText(sanitizeText(text));
|
||||
existing.setUpdatedBy(userId);
|
||||
TranscriptionBlock saved = blockRepository.save(existing);
|
||||
saveVersion(saved, userId);
|
||||
return saved;
|
||||
}).orElseGet(() -> createOcrBlock(documentId, annotationId, text, 0, userId));
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
|
||||
UpdateTranscriptionBlockDTO dto, UUID userId) {
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Opt-in enrollment of documents into training corpora. A document may carry
-- several labels at once; the composite primary key makes each
-- (document, label) pair unique, and rows are removed automatically when the
-- parent document is deleted.
CREATE TABLE document_training_labels (
    document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
    label VARCHAR(50) NOT NULL,
    PRIMARY KEY (document_id, label)
);
|
||||
@@ -0,0 +1,16 @@
|
||||
-- One row per OCR training run (recognition or segmentation).
CREATE TABLE ocr_training_runs (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    status VARCHAR(20) NOT NULL DEFAULT 'RUNNING',  -- RUNNING / DONE / FAILED
    block_count INT NOT NULL,                       -- training samples in this run
    document_count INT NOT NULL,                    -- distinct source documents
    model_name VARCHAR(100) NOT NULL,
    error_message TEXT,                             -- set only when the run failed
    triggered_by UUID REFERENCES users(id) ON DELETE SET NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    completed_at TIMESTAMPTZ                        -- null while the run is active
);

-- Enforce single active run at the DB layer (application check is the UX layer)
CREATE UNIQUE INDEX idx_ocr_training_runs_one_running
    ON ocr_training_runs (status)
    WHERE status = 'RUNNING';
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Intentional: segmentation-only blocks have no text.
-- This migration is irreversible without a data cleanup step
-- (cannot re-add NOT NULL if null rows exist).
-- New rows that omit text get '' rather than NULL via the default.
ALTER TABLE transcription_blocks ALTER COLUMN text DROP NOT NULL;
ALTER TABLE transcription_blocks ALTER COLUMN text SET DEFAULT '';
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Metrics reported back by the OCR microservice after a completed run.
-- All columns are nullable: failed runs and pre-existing rows have no values,
-- and segmentation runs do not record cer.
ALTER TABLE ocr_training_runs
    ADD COLUMN cer DOUBLE PRECISION,      -- character error rate
    ADD COLUMN loss DOUBLE PRECISION,
    ADD COLUMN accuracy DOUBLE PRECISION,
    ADD COLUMN epochs INT;
|
||||
@@ -123,15 +123,19 @@ class AnnotationControllerTest {
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ANNOTATE_ALL")
|
||||
void createAnnotation_returns409_whenOverlap() throws Exception {
|
||||
void createAnnotation_returns201_whenAnnotationsOverlap() throws Exception {
|
||||
// Overlapping annotations are allowed — historical letter lines often intersect
|
||||
UUID docId = UUID.randomUUID();
|
||||
DocumentAnnotation saved = DocumentAnnotation.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
|
||||
.x(0.1).y(0.1).width(0.3).height(0.3).color("#ff0000").build();
|
||||
when(documentService.getDocumentById(any())).thenReturn(Document.builder().build());
|
||||
when(annotationService.createAnnotation(any(), any(), any(), any()))
|
||||
.thenThrow(DomainException.conflict(ErrorCode.ANNOTATION_OVERLAP, "Overlap"));
|
||||
when(annotationService.createAnnotation(any(), any(), any(), any())).thenReturn(saved);
|
||||
|
||||
mockMvc.perform(post("/api/documents/" + UUID.randomUUID() + "/annotations")
|
||||
mockMvc.perform(post("/api/documents/" + docId + "/annotations")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content(ANNOTATION_JSON))
|
||||
.andExpect(status().isConflict());
|
||||
.andExpect(status().isCreated());
|
||||
}
|
||||
|
||||
// ─── DELETE /api/documents/{documentId}/annotations/{annotationId} ─────────
|
||||
|
||||
@@ -14,6 +14,7 @@ import org.raddatz.familienarchiv.service.FileService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.webmvc.test.autoconfigure.WebMvcTest;
|
||||
import org.raddatz.familienarchiv.config.SecurityConfig;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.boot.autoconfigure.aop.AopAutoConfiguration;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.security.test.context.support.WithMockUser;
|
||||
@@ -33,6 +34,7 @@ import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.multipart;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch;
|
||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
|
||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
|
||||
|
||||
@@ -485,6 +487,58 @@ class DocumentControllerTest {
|
||||
.andExpect(jsonPath("$[0].editorName").value("Emma Müller"));
|
||||
}
|
||||
|
||||
// ─── PATCH /api/documents/{id}/training-labels ───────────────────────────
|
||||
|
||||
@Test
|
||||
void patchTrainingLabels_returns401_whenUnauthenticated() throws Exception {
|
||||
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
|
||||
.andExpect(status().isUnauthorized());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser
|
||||
void patchTrainingLabels_returns403_whenMissingWritePermission() throws Exception {
|
||||
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void patchTrainingLabels_returns204_whenAddingLabel() throws Exception {
|
||||
UUID id = UUID.randomUUID();
|
||||
mockMvc.perform(patch("/api/documents/" + id + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
|
||||
.andExpect(status().isNoContent());
|
||||
|
||||
verify(documentService).addTrainingLabel(eq(id), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void patchTrainingLabels_returns204_whenRemovingLabel() throws Exception {
|
||||
UUID id = UUID.randomUUID();
|
||||
mockMvc.perform(patch("/api/documents/" + id + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_SEGMENTATION\",\"enrolled\":false}"))
|
||||
.andExpect(status().isNoContent());
|
||||
|
||||
verify(documentService).removeTrainingLabel(eq(id), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void patchTrainingLabels_returns400_whenUnknownLabel() throws Exception {
|
||||
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"UNKNOWN_GARBAGE\",\"enrolled\":true}"))
|
||||
.andExpect(status().isBadRequest());
|
||||
}
|
||||
|
||||
// ─── GET /api/documents/{id}/versions/{versionId} ────────────────────────
|
||||
|
||||
@Test
|
||||
|
||||
@@ -24,6 +24,7 @@ import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
@@ -43,15 +44,18 @@ class OcrControllerTest {
|
||||
@MockitoBean OcrProgressService ocrProgressService;
|
||||
@MockitoBean UserService userService;
|
||||
@MockitoBean CustomUserDetailsService customUserDetailsService;
|
||||
@MockitoBean TrainingDataExportService trainingDataExportService;
|
||||
@MockitoBean SegmentationTrainingExportService segmentationTrainingExportService;
|
||||
@MockitoBean OcrTrainingService ocrTrainingService;
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void triggerOcr_returns202_withJobId() throws Exception {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID jobId = UUID.randomUUID();
|
||||
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
|
||||
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER, false);
|
||||
|
||||
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
|
||||
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any(), anyBoolean())).thenReturn(jobId);
|
||||
|
||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
@@ -64,7 +68,7 @@ class OcrControllerTest {
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
|
||||
UUID docId = UUID.randomUUID();
|
||||
when(ocrService.startOcr(eq(docId), any(), any()))
|
||||
when(ocrService.startOcr(eq(docId), any(), any(), anyBoolean()))
|
||||
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
|
||||
|
||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||
@@ -121,6 +125,109 @@ class OcrControllerTest {
|
||||
.andExpect(jsonPath("$.jobId").value(jobId.toString()));
|
||||
}
|
||||
|
||||
// ─── GET /api/ocr/training-data/export ───────────────────────────────────
|
||||
|
||||
@Test
|
||||
void exportTrainingData_returns401_whenUnauthenticated() throws Exception {
|
||||
mockMvc.perform(get("/api/ocr/training-data/export"))
|
||||
.andExpect(status().isUnauthorized());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void exportTrainingData_returns403_whenNotAdmin() throws Exception {
|
||||
mockMvc.perform(get("/api/ocr/training-data/export"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void exportTrainingData_returns204_whenNoEligibleBlocks() throws Exception {
|
||||
when(trainingDataExportService.queryEligibleBlocks()).thenReturn(List.of());
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-data/export"))
|
||||
.andExpect(status().isNoContent());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void exportTrainingData_returns200_withZipContentType_whenBlocksExist() throws Exception {
|
||||
org.raddatz.familienarchiv.model.TranscriptionBlock block =
|
||||
org.raddatz.familienarchiv.model.TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID()).documentId(UUID.randomUUID())
|
||||
.annotationId(UUID.randomUUID()).text("x").sortOrder(0).build();
|
||||
when(trainingDataExportService.queryEligibleBlocks()).thenReturn(List.of(block));
|
||||
when(trainingDataExportService.exportToZip()).thenReturn(out -> {});
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-data/export"))
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(result ->
|
||||
org.assertj.core.api.Assertions.assertThat(
|
||||
result.getResponse().getContentType()).contains("application/zip"));
|
||||
}
|
||||
|
||||
// ─── POST /api/ocr/train ───────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_returns401_whenUnauthenticated() throws Exception {
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isUnauthorized());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void triggerTraining_returns403_whenNotAdmin() throws Exception {
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void triggerTraining_returns409_whenRunAlreadyRunning() throws Exception {
|
||||
when(ocrTrainingService.triggerTraining(any()))
|
||||
.thenThrow(DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, "Already running"));
|
||||
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isConflict());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void triggerTraining_returns201_withRunInfo() throws Exception {
|
||||
UUID runId = UUID.randomUUID();
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.id(runId).status(TrainingStatus.DONE)
|
||||
.blockCount(10).documentCount(3).modelName("german_kurrent").build();
|
||||
when(ocrTrainingService.triggerTraining(any())).thenReturn(run);
|
||||
|
||||
mockMvc.perform(post("/api/ocr/train"))
|
||||
.andExpect(status().isCreated())
|
||||
.andExpect(jsonPath("$.status").value("DONE"))
|
||||
.andExpect(jsonPath("$.blockCount").value(10));
|
||||
}
|
||||
|
||||
// ─── GET /api/ocr/training-info ───────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void getTrainingInfo_returns403_whenNotAdmin() throws Exception {
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void getTrainingInfo_returns200_withInfo() throws Exception {
|
||||
OcrTrainingService.TrainingInfoResponse info =
|
||||
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of());
|
||||
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(jsonPath("$.availableBlocks").value(5))
|
||||
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import jakarta.persistence.EntityManager;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.PostgresContainerConfig;
|
||||
import org.raddatz.familienarchiv.config.FlywayConfig;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentStatus;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.data.jpa.test.autoconfigure.DataJpaTest;
|
||||
import org.springframework.boot.jdbc.test.autoconfigure.AutoConfigureTestDatabase;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.dao.DataIntegrityViolationException;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.transaction.annotation.Propagation;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* Integration tests that verify DB-level constraints introduced in the OCR pipeline migrations
|
||||
* are actually enforced by PostgreSQL. These tests exercise constraints that cannot be verified
|
||||
* by unit tests alone.
|
||||
*/
|
||||
@DataJpaTest
|
||||
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
|
||||
@Import({PostgresContainerConfig.class, FlywayConfig.class})
|
||||
class MigrationIntegrationTest {
|
||||
|
||||
@Autowired JdbcTemplate jdbc;
|
||||
@Autowired DocumentRepository documentRepository;
|
||||
@Autowired EntityManager em;
|
||||
|
||||
// ─── V23: chk_annotation_polygon_quad CHECK constraint ───────────────────
|
||||
|
||||
@Test
|
||||
void v23_polygonCheckConstraint_rejectsNonQuadrilateral() {
|
||||
UUID docId = createDocument();
|
||||
|
||||
// A 3-point polygon violates chk_annotation_polygon_quad (must be exactly 4 points or NULL)
|
||||
assertThatThrownBy(() ->
|
||||
jdbc.update(
|
||||
"""
|
||||
INSERT INTO document_annotations
|
||||
(id, document_id, page_number, x, y, width, height, color, polygon)
|
||||
VALUES (gen_random_uuid(), ?, 1, 0.1, 0.1, 0.5, 0.1, '#00C7B1',
|
||||
'[[0.1,0.1],[0.9,0.1],[0.9,0.2]]'::jsonb)
|
||||
""",
|
||||
docId)
|
||||
).isInstanceOf(DataIntegrityViolationException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void v23_polygonCheckConstraint_allowsNullPolygon() {
|
||||
UUID docId = createDocument();
|
||||
|
||||
int rows = jdbc.update(
|
||||
"""
|
||||
INSERT INTO document_annotations
|
||||
(id, document_id, page_number, x, y, width, height, color, polygon)
|
||||
VALUES (gen_random_uuid(), ?, 1, 0.1, 0.1, 0.5, 0.1, '#00C7B1', NULL)
|
||||
""",
|
||||
docId);
|
||||
|
||||
assertThat(rows).isEqualTo(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void v23_polygonCheckConstraint_allowsQuadrilateral() {
|
||||
UUID docId = createDocument();
|
||||
|
||||
int rows = jdbc.update(
|
||||
"""
|
||||
INSERT INTO document_annotations
|
||||
(id, document_id, page_number, x, y, width, height, color, polygon)
|
||||
VALUES (gen_random_uuid(), ?, 1, 0.1, 0.1, 0.5, 0.1, '#00C7B1',
|
||||
'[[0.1,0.1],[0.9,0.1],[0.9,0.2],[0.1,0.2]]'::jsonb)
|
||||
""",
|
||||
docId);
|
||||
|
||||
assertThat(rows).isEqualTo(1);
|
||||
}
|
||||
|
||||
// ─── V30: idx_ocr_training_runs_one_running partial unique index ──────────
|
||||
|
||||
@Test
|
||||
@Transactional(propagation = Propagation.NOT_SUPPORTED)
|
||||
void v30_partialUniqueIndex_preventsTwoRunningTrainingRuns() {
|
||||
jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
|
||||
VALUES (gen_random_uuid(), 'RUNNING', 10, 2, 'kurrent_v1')
|
||||
""");
|
||||
|
||||
// A second RUNNING row violates the partial unique index
|
||||
assertThatThrownBy(() ->
|
||||
jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
|
||||
VALUES (gen_random_uuid(), 'RUNNING', 5, 1, 'kurrent_v1')
|
||||
""")
|
||||
).isInstanceOf(DataIntegrityViolationException.class);
|
||||
|
||||
// Clean up — runs outside the DataJpaTest transaction, so must be explicit
|
||||
jdbc.update("DELETE FROM ocr_training_runs WHERE status = 'RUNNING'");
|
||||
}
|
||||
|
||||
@Test
|
||||
void v30_partialUniqueIndex_allowsMultipleDoneRuns() {
|
||||
int rows1 = jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
|
||||
VALUES (gen_random_uuid(), 'DONE', 10, 2, 'kurrent_v1')
|
||||
""");
|
||||
int rows2 = jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name)
|
||||
VALUES (gen_random_uuid(), 'DONE', 15, 3, 'kurrent_v2')
|
||||
""");
|
||||
|
||||
assertThat(rows1).isEqualTo(1);
|
||||
assertThat(rows2).isEqualTo(1);
|
||||
}
|
||||
|
||||
// ─── helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
private UUID createDocument() {
|
||||
Document doc = documentRepository.save(Document.builder()
|
||||
.title("Testdokument")
|
||||
.originalFilename("test.pdf")
|
||||
.status(DocumentStatus.UPLOADED)
|
||||
.build());
|
||||
// Flush so the row is visible to subsequent JdbcTemplate queries within the same transaction
|
||||
em.flush();
|
||||
return doc.getId();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.PostgresContainerConfig;
|
||||
import org.raddatz.familienarchiv.config.FlywayConfig;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import java.util.Set;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.jdbc.test.autoconfigure.AutoConfigureTestDatabase;
|
||||
import org.springframework.boot.data.jpa.test.autoconfigure.DataJpaTest;
|
||||
import org.springframework.context.annotation.Import;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@DataJpaTest
|
||||
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
|
||||
@Import({PostgresContainerConfig.class, FlywayConfig.class})
|
||||
class TrainingBlockQueryTest {
|
||||
|
||||
@Autowired TranscriptionBlockRepository blockRepository;
|
||||
@Autowired DocumentRepository documentRepository;
|
||||
@Autowired AnnotationRepository annotationRepository;
|
||||
|
||||
private UUID kurrentDocId;
|
||||
private UUID typewriterDocId;
|
||||
private UUID kurrentAnnotationId;
|
||||
private UUID typewriterAnnotationId;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
Document kurrentDoc = documentRepository.save(Document.builder()
|
||||
.title("Kurrent Brief")
|
||||
.originalFilename("kurrent.pdf")
|
||||
.status(DocumentStatus.UPLOADED)
|
||||
.trainingLabels(new java.util.HashSet<>(Set.of(TrainingLabel.KURRENT_RECOGNITION)))
|
||||
.build());
|
||||
kurrentDocId = kurrentDoc.getId();
|
||||
|
||||
Document typewriterDoc = documentRepository.save(Document.builder()
|
||||
.title("Getippter Brief")
|
||||
.originalFilename("typed.pdf")
|
||||
.status(DocumentStatus.UPLOADED)
|
||||
.build());
|
||||
typewriterDocId = typewriterDoc.getId();
|
||||
|
||||
kurrentAnnotationId = annotationRepository.save(annotation(kurrentDocId)).getId();
|
||||
typewriterAnnotationId = annotationRepository.save(annotation(typewriterDocId)).getId();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_includesReviewedManualBlock() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, true));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).hasSize(1);
|
||||
assertThat(result.get(0).getSource()).isEqualTo(BlockSource.MANUAL);
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_excludesUnreviewedManualBlock() {
|
||||
// MANUAL blocks filled by guided OCR but not yet reviewed should not count as checked text
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, false));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_includesReviewedOcrBlock() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, true));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).hasSize(1);
|
||||
assertThat(result.get(0).isReviewed()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_excludesUnreviewedOcrBlock() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, false));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_excludesNonEnrolledDocument() {
|
||||
blockRepository.save(block(typewriterDocId, typewriterAnnotationId, BlockSource.MANUAL, false));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_returnsAllReviewedBlocksAcrossBothSources() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, true));
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, true));
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, false)); // excluded
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, false)); // excluded
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).hasSize(2);
|
||||
}
|
||||
|
||||
// ─── helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
private DocumentAnnotation annotation(UUID docId) {
|
||||
return DocumentAnnotation.builder()
|
||||
.documentId(docId)
|
||||
.pageNumber(1)
|
||||
.x(0.1).y(0.2).width(0.3).height(0.4)
|
||||
.color("#00C7B1")
|
||||
.build();
|
||||
}
|
||||
|
||||
private TranscriptionBlock block(UUID docId, UUID annotId, BlockSource source, boolean reviewed) {
|
||||
return TranscriptionBlock.builder()
|
||||
.annotationId(annotId)
|
||||
.documentId(docId)
|
||||
.text("Liebe Tante")
|
||||
.sortOrder(0)
|
||||
.source(source)
|
||||
.reviewed(reviewed)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import org.raddatz.familienarchiv.dto.CreateAnnotationDTO;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -20,7 +21,6 @@ import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.http.HttpStatus.CONFLICT;
|
||||
import static org.springframework.http.HttpStatus.FORBIDDEN;
|
||||
import static org.springframework.http.HttpStatus.NOT_FOUND;
|
||||
|
||||
@@ -28,39 +28,19 @@ import static org.springframework.http.HttpStatus.NOT_FOUND;
|
||||
class AnnotationServiceTest {
|
||||
|
||||
@Mock AnnotationRepository annotationRepository;
|
||||
@Mock TranscriptionBlockRepository blockRepository;
|
||||
@InjectMocks AnnotationService annotationService;
|
||||
|
||||
// ─── createAnnotation ─────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createAnnotation_throwsConflict_whenAnnotationOverlapsExisting() {
|
||||
void createAnnotation_savesAnnotation() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#ff0000");
|
||||
|
||||
DocumentAnnotation existing = DocumentAnnotation.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
|
||||
.x(0.2).y(0.2).width(0.3).height(0.3).color("#00ff00").build();
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1))
|
||||
.thenReturn(List.of(existing));
|
||||
|
||||
assertThatThrownBy(() -> annotationService.createAnnotation(docId, dto, userId, null))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(CONFLICT));
|
||||
|
||||
verify(annotationRepository, never()).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_savesAndReturns_whenNoOverlap() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
|
||||
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of());
|
||||
DocumentAnnotation saved = DocumentAnnotation.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
|
||||
.x(0.0).y(0.0).width(0.05).height(0.05).color("#ff0000").createdBy(userId).build();
|
||||
.x(0.1).y(0.1).width(0.3).height(0.3).color("#ff0000").createdBy(userId).build();
|
||||
when(annotationRepository.save(any())).thenReturn(saved);
|
||||
|
||||
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, null);
|
||||
@@ -69,6 +49,77 @@ class AnnotationServiceTest {
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_allowsOverlappingAnnotations() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#ff0000");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
// Should not throw even when overlapping annotations exist on the same page
|
||||
annotationService.createAnnotation(docId, dto, userId, null);
|
||||
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_setsFileHash_whenProvided() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, "abc123");
|
||||
|
||||
assertThat(result.getFileHash()).isEqualTo("abc123");
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_setsNullFileHash_whenNoneProvided() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, null);
|
||||
|
||||
assertThat(result.getFileHash()).isNull();
|
||||
}
|
||||
|
||||
// ─── createOcrAnnotation ──────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_savesWithPolygon() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
|
||||
List<List<Double>> polygon = List.of(
|
||||
List.of(0.1, 0.1), List.of(0.9, 0.11),
|
||||
List.of(0.89, 0.14), List.of(0.11, 0.13));
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createOcrAnnotation(
|
||||
docId, dto, userId, "filehash", polygon);
|
||||
|
||||
assertThat(result.getPolygon()).isEqualTo(polygon);
|
||||
assertThat(result.getDocumentId()).isEqualTo(docId);
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_savesWithNullPolygon_whenPolygonNotProvided() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createOcrAnnotation(
|
||||
docId, dto, userId, "filehash", null);
|
||||
|
||||
assertThat(result.getPolygon()).isNull();
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
// ─── deleteAnnotation ─────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
@@ -114,36 +165,42 @@ class AnnotationServiceTest {
|
||||
|
||||
annotationService.deleteAnnotation(docId, annotId, ownerId);
|
||||
|
||||
verify(blockRepository).deleteByAnnotationId(annotId);
|
||||
verify(annotationRepository).delete(annotation);
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_setsFileHash_whenProvided() {
|
||||
void deleteAnnotation_deletesLinkedTranscriptionBlock_beforeAnnotation() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
|
||||
String fileHash = "abc123";
|
||||
UUID annotId = UUID.randomUUID();
|
||||
UUID ownerId = UUID.randomUUID();
|
||||
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of());
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
DocumentAnnotation annotation = DocumentAnnotation.builder()
|
||||
.id(annotId).documentId(docId).createdBy(ownerId).build();
|
||||
when(annotationRepository.findByIdAndDocumentId(annotId, docId))
|
||||
.thenReturn(Optional.of(annotation));
|
||||
|
||||
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, fileHash);
|
||||
annotationService.deleteAnnotation(docId, annotId, ownerId);
|
||||
|
||||
assertThat(result.getFileHash()).isEqualTo(fileHash);
|
||||
var inOrder = org.mockito.Mockito.inOrder(blockRepository, annotationRepository);
|
||||
inOrder.verify(blockRepository).deleteByAnnotationId(annotId);
|
||||
inOrder.verify(annotationRepository).delete(annotation);
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_setsNullFileHash_whenNoneProvided() {
|
||||
void deleteAnnotation_throwsForbidden_whenUserIdIsNull() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.05, 0.05, "#ff0000");
|
||||
UUID annotId = UUID.randomUUID();
|
||||
UUID ownerId = UUID.randomUUID();
|
||||
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of());
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
DocumentAnnotation annotation = DocumentAnnotation.builder()
|
||||
.id(annotId).documentId(docId).createdBy(ownerId).build();
|
||||
when(annotationRepository.findByIdAndDocumentId(annotId, docId))
|
||||
.thenReturn(Optional.of(annotation));
|
||||
|
||||
DocumentAnnotation result = annotationService.createAnnotation(docId, dto, userId, null);
|
||||
|
||||
assertThat(result.getFileHash()).isNull();
|
||||
assertThatThrownBy(() -> annotationService.deleteAnnotation(docId, annotId, null))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(FORBIDDEN));
|
||||
}
|
||||
|
||||
// ─── listAnnotations ──────────────────────────────────────────────────────
|
||||
@@ -183,149 +240,4 @@ class AnnotationServiceTest {
|
||||
|
||||
verify(annotationRepository, never()).save(any());
|
||||
}
|
||||
|
||||
// ─── deleteAnnotation — null userId ───────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void deleteAnnotation_throwsForbidden_whenUserIdIsNull() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID annotId = UUID.randomUUID();
|
||||
UUID ownerId = UUID.randomUUID();
|
||||
|
||||
DocumentAnnotation annotation = DocumentAnnotation.builder()
|
||||
.id(annotId).documentId(docId).createdBy(ownerId).build();
|
||||
when(annotationRepository.findByIdAndDocumentId(annotId, docId))
|
||||
.thenReturn(Optional.of(annotation));
|
||||
|
||||
assertThatThrownBy(() -> annotationService.deleteAnnotation(docId, annotId, null))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(FORBIDDEN));
|
||||
}
|
||||
|
||||
// ─── overlaps — partial overlap cases ────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createAnnotation_noConflict_whenAnnotationIsToTheLeft() {
|
||||
// existing: x=0.5, w=0.3 (x2=0.8); dto: x=0.0, w=0.4 (dx2=0.4)
|
||||
// existing.getX() < dx2 → 0.5 < 0.4 → FALSE → no overlap (first && fails)
|
||||
UUID docId = UUID.randomUUID();
|
||||
DocumentAnnotation existing = DocumentAnnotation.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
|
||||
.x(0.5).y(0.0).width(0.3).height(0.5).color("#ff0000").build();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.0, 0.0, 0.4, 0.5, "#0000ff");
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
|
||||
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_noConflict_whenAnnotationIsToTheRight() {
|
||||
// existing: x=0.0, w=0.1 (ex2=0.1); dto: x=0.2, w=0.3 (dx2=0.5)
|
||||
// existing.getX() < dx2 → 0.0 < 0.5 → TRUE
|
||||
// ex2 > dto.getX() → 0.1 > 0.2 → FALSE → no overlap (second && fails)
|
||||
UUID docId = UUID.randomUUID();
|
||||
DocumentAnnotation existing = DocumentAnnotation.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
|
||||
.x(0.0).y(0.0).width(0.1).height(0.5).color("#ff0000").build();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.2, 0.0, 0.3, 0.5, "#0000ff");
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
|
||||
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createAnnotation_noConflict_whenAnnotationIsBelow() {
|
||||
// x ranges overlap, but y ranges don't
|
||||
// existing: x=0.0, w=0.5, y=0.5, h=0.2 (ey2=0.7)
|
||||
// dto: x=0.1, w=0.3 (dx2=0.4), y=0.0, h=0.4 (dy2=0.4)
|
||||
// existing.getX() < dx2 → 0.0 < 0.4 → TRUE
|
||||
// ex2 > dto.getX() → 0.5 > 0.1 → TRUE
|
||||
// existing.getY() < dy2 → 0.5 < 0.4 → FALSE → no overlap (third && fails)
|
||||
UUID docId = UUID.randomUUID();
|
||||
DocumentAnnotation existing = DocumentAnnotation.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
|
||||
.x(0.0).y(0.5).width(0.5).height(0.2).color("#ff0000").build();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.0, 0.3, 0.4, "#0000ff");
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
|
||||
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
// ─── createOcrAnnotation ──────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_skipsOverlapCheck_andSavesWithPolygon() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
|
||||
List<List<Double>> polygon = List.of(
|
||||
List.of(0.1, 0.1), List.of(0.9, 0.11),
|
||||
List.of(0.89, 0.14), List.of(0.11, 0.13));
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createOcrAnnotation(
|
||||
docId, dto, userId, "filehash", polygon);
|
||||
|
||||
assertThat(result.getPolygon()).isEqualTo(polygon);
|
||||
assertThat(result.getDocumentId()).isEqualTo(docId);
|
||||
verify(annotationRepository).save(any());
|
||||
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_savesWithNullPolygon_whenPolygonNotProvided() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createOcrAnnotation(
|
||||
docId, dto, userId, "filehash", null);
|
||||
|
||||
assertThat(result.getPolygon()).isNull();
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_doesNotCheckOverlap_evenWhenOverlappingAnnotationExists() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#00C7B1");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
annotationService.createOcrAnnotation(docId, dto, userId, "hash", null);
|
||||
|
||||
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
|
||||
}
|
||||
|
||||
// ─── overlaps — partial overlap cases ────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createAnnotation_noConflict_whenAnnotationIsAbove() {
|
||||
// x ranges overlap, y ranges don't — existing is ABOVE the new annotation
|
||||
// existing: x=0.0, w=0.5, y=0.0, h=0.1 (ey2=0.1)
|
||||
// dto: x=0.1, w=0.3 (dx2=0.4), y=0.2, h=0.3 (dy2=0.5)
|
||||
// A: 0.0 < 0.4 → TRUE, B: 0.5 > 0.1 → TRUE, C: 0.0 < 0.5 → TRUE
|
||||
// D: ey2 > dto.getY() → 0.1 > 0.2 → FALSE → no overlap (fourth && fails)
|
||||
UUID docId = UUID.randomUUID();
|
||||
DocumentAnnotation existing = DocumentAnnotation.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).pageNumber(1)
|
||||
.x(0.0).y(0.0).width(0.5).height(0.1).color("#ff0000").build();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.2, 0.3, 0.3, "#0000ff");
|
||||
when(annotationRepository.findByDocumentIdAndPageNumber(docId, 1)).thenReturn(List.of(existing));
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
annotationService.createAnnotation(docId, dto, UUID.randomUUID(), null);
|
||||
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,8 +60,8 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -84,7 +84,7 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -112,12 +112,12 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(1));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -142,7 +142,7 @@ class OcrAsyncRunnerTest {
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -174,18 +174,18 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
List<String> progressMessages = new ArrayList<>();
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2", null))));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of(
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3", null))));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Done(3, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -215,14 +215,14 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Error(1, "failed"));
|
||||
handler.accept(new OcrStreamEvent.Page(2, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -251,13 +251,13 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(2));
|
||||
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
|
||||
@@ -12,13 +12,19 @@ class OcrClientDefaultStreamTest {
|
||||
|
||||
@Test
|
||||
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
|
||||
OcrClient client = (pdfUrl, scriptType) -> List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||
OcrClient client = new OcrClient() {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
||||
return List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3", null));
|
||||
}
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||
|
||||
assertThat(events).hasSize(4);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
@@ -42,10 +48,14 @@ class OcrClientDefaultStreamTest {
|
||||
|
||||
@Test
|
||||
void defaultStreamBlocksHandlesEmptyResults() {
|
||||
OcrClient client = (pdfUrl, scriptType) -> List.of();
|
||||
OcrClient client = new OcrClient() {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||
|
||||
assertThat(events).hasSize(2);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
|
||||
@@ -142,7 +142,7 @@ class OcrServiceTest {
|
||||
UUID result = ocrService.startOcr(docId, null, userId);
|
||||
|
||||
assertThat(result).isEqualTo(jobId);
|
||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
|
||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -17,7 +17,7 @@ class OcrStreamEventTest {
|
||||
|
||||
@Test
|
||||
void pageRecordHoldsBlocksAndPageNumber() {
|
||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
|
||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test", null);
|
||||
var page = new OcrStreamEvent.Page(0, List.of(block));
|
||||
assertThat(page.pageNumber()).isEqualTo(0);
|
||||
assertThat(page.blocks()).hasSize(1);
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.argThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class OcrTrainingServiceTest {
|
||||
|
||||
OcrTrainingRunRepository runRepository;
|
||||
TrainingDataExportService exportService;
|
||||
SegmentationTrainingExportService segExportService;
|
||||
OcrClient ocrClient;
|
||||
OcrHealthClient healthClient;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
TransactionTemplate txTemplate;
|
||||
OcrTrainingService service;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
runRepository = mock(OcrTrainingRunRepository.class);
|
||||
exportService = mock(TrainingDataExportService.class);
|
||||
segExportService = mock(SegmentationTrainingExportService.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
healthClient = mock(OcrHealthClient.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource
|
||||
when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||
TransactionCallback<?> callback = inv.getArgument(0);
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
|
||||
when(blockRepository.count()).thenReturn(0L);
|
||||
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
|
||||
}
|
||||
|
||||
// ─── Concurrent guard ─────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_throws409_whenRunningRunExists() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
|
||||
.thenReturn(Optional.of(OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("german_kurrent").build()));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.extracting("status")
|
||||
.satisfies(s -> assertThat(s.toString()).contains("409"));
|
||||
}
|
||||
|
||||
// ─── Threshold guard ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_throws422_whenFewerThan5Blocks() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
|
||||
));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerTraining(null))
|
||||
.isInstanceOf(DomainException.class);
|
||||
}
|
||||
|
||||
// ─── Happy path ───────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_createsRunWithCorrectCounts_andMarksDone() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
UUID docB = UUID.randomUUID();
|
||||
List<TranscriptionBlock> blocks = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("german_kurrent").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerTraining(null);
|
||||
|
||||
// Verify run created with correct counts and then updated to DONE
|
||||
verify(runRepository, times(2)).save(argThat(run ->
|
||||
run.getBlockCount() == 5 || run.getStatus() == TrainingStatus.DONE));
|
||||
}
|
||||
|
||||
// ─── Failure path ─────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
List<TranscriptionBlock> blocks = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenThrow(new RuntimeException("OCR service timeout"));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerTraining(null);
|
||||
|
||||
verify(runRepository, atLeastOnce()).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void recoverOrphanedRuns_marksRunFailed_whenOlderThanOneHour() {
|
||||
OcrTrainingRun orphan = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||
.createdAt(Instant.now().minusSeconds(7200))
|
||||
.build();
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(orphan));
|
||||
when(runRepository.save(any())).thenReturn(orphan);
|
||||
|
||||
service.recoverOrphanedRuns();
|
||||
|
||||
verify(runRepository).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.FAILED
|
||||
&& run.getErrorMessage().contains("Abgebrochen")));
|
||||
}
|
||||
|
||||
@Test
|
||||
void recoverOrphanedRuns_doesNothing_whenRunIsRecent() {
|
||||
OcrTrainingRun recent = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent")
|
||||
.createdAt(Instant.now().minusSeconds(60))
|
||||
.build();
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.of(recent));
|
||||
|
||||
service.recoverOrphanedRuns();
|
||||
|
||||
verify(runRepository, never()).save(any());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.apache.pdfbox.Loader;
|
||||
import org.apache.pdfbox.pdmodel.PDDocument;
|
||||
import org.apache.pdfbox.pdmodel.PDPage;
|
||||
import org.apache.pdfbox.pdmodel.common.PDRectangle;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.PostgresContainerConfig;
|
||||
import org.raddatz.familienarchiv.config.FlywayConfig;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.DocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.jdbc.test.autoconfigure.AutoConfigureTestDatabase;
|
||||
import org.springframework.boot.data.jpa.test.autoconfigure.DataJpaTest;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipInputStream;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@DataJpaTest
|
||||
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
|
||||
@Import({PostgresContainerConfig.class, FlywayConfig.class})
|
||||
class TrainingDataExportServiceTest {
|
||||
|
||||
@Autowired TranscriptionBlockRepository blockRepository;
|
||||
@Autowired DocumentRepository documentRepository;
|
||||
@Autowired AnnotationRepository annotationRepository;
|
||||
|
||||
static byte[] minimalPdfBytes;
|
||||
|
||||
@BeforeAll
|
||||
static void createMinimalPdf() throws Exception {
|
||||
try (PDDocument doc = new PDDocument()) {
|
||||
doc.addPage(new PDPage(PDRectangle.A4));
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
doc.save(out);
|
||||
minimalPdfBytes = out.toByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Query: enrollment filter ─────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void export_includesManualBlockFromEnrolledDocument() throws Exception {
|
||||
UUID docId = enrolledDoc("enrolled.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
blockRepository.save(manualBlock(docId, annotId, "Liebe Mutter"));
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
StreamingResponseBody body = service.exportToZip();
|
||||
byte[] zipBytes = stream(body);
|
||||
assertThat(zipEntryNames(zipBytes)).isNotEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void export_excludesUnreviewedManualBlockFromEnrolledDocument() throws Exception {
|
||||
// MANUAL blocks whose text hasn't been verified yet should not count as training data
|
||||
UUID docId = enrolledDoc("unreviewed-manual.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.annotationId(annotId).documentId(docId)
|
||||
.text("Liebe Tante").sortOrder(0)
|
||||
.source(BlockSource.MANUAL).reviewed(false).build();
|
||||
blockRepository.save(block);
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
StreamingResponseBody body = service.exportToZip();
|
||||
assertThat(zipEntryNames(stream(body))).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void export_excludesManualBlockFromNonEnrolledDocument() throws Exception {
|
||||
UUID docId = nonEnrolledDoc("notenrolled.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
blockRepository.save(manualBlock(docId, annotId, "Liebe Tante"));
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
StreamingResponseBody body = service.exportToZip();
|
||||
byte[] zipBytes = stream(body);
|
||||
assertThat(zipEntryNames(zipBytes)).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void export_includesReviewedOcrBlockFromEnrolledDocument() throws Exception {
|
||||
UUID docId = enrolledDoc("ocr-reviewed.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.annotationId(annotId).documentId(docId)
|
||||
.text("OCR text").sortOrder(0)
|
||||
.source(BlockSource.OCR).reviewed(true).build();
|
||||
blockRepository.save(block);
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
StreamingResponseBody body = service.exportToZip();
|
||||
assertThat(zipEntryNames(stream(body))).isNotEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void export_excludesUnreviewedOcrBlockFromEnrolledDocument() throws Exception {
|
||||
UUID docId = enrolledDoc("ocr-unreviewed.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.annotationId(annotId).documentId(docId)
|
||||
.text("Raw OCR").sortOrder(0)
|
||||
.source(BlockSource.OCR).reviewed(false).build();
|
||||
blockRepository.save(block);
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
StreamingResponseBody body = service.exportToZip();
|
||||
assertThat(zipEntryNames(stream(body))).isEmpty();
|
||||
}
|
||||
|
||||
// ─── ZIP structure ────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void export_producesExactly2EntriesPerBlock_pngAndTxt() throws Exception {
|
||||
UUID docId = enrolledDoc("zip-struct.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
blockRepository.save(manualBlock(docId, annotId, "Erste Zeile"));
|
||||
blockRepository.save(manualBlock(docId, annotId, "Zweite Zeile"));
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
byte[] zipBytes = stream(service.exportToZip());
|
||||
var names = zipEntryNames(zipBytes);
|
||||
assertThat(names).hasSize(4); // 2 blocks × 2 entries each
|
||||
assertThat(names.stream().filter(n -> n.endsWith(".png")).count()).isEqualTo(2);
|
||||
assertThat(names.stream().filter(n -> n.endsWith(".xml")).count()).isEqualTo(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void export_pageXmlContainsBlockText() throws Exception {
|
||||
UUID docId = enrolledDoc("txt-content.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
String expectedText = "Sehr geehrte Frau";
|
||||
blockRepository.save(manualBlock(docId, annotId, expectedText));
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
byte[] zipBytes = stream(service.exportToZip());
|
||||
String xmlContent = readZipEntry(zipBytes, ".xml");
|
||||
assertThat(xmlContent).contains("<Unicode>" + expectedText + "</Unicode>");
|
||||
}
|
||||
|
||||
@Test
|
||||
void export_pageXmlEscapesSpecialCharacters() throws Exception {
|
||||
UUID docId = enrolledDoc("special-chars.pdf");
|
||||
UUID annotId = annotation(docId);
|
||||
blockRepository.save(manualBlock(docId, annotId, "A & B < C > D"));
|
||||
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
byte[] zipBytes = stream(service.exportToZip());
|
||||
String xmlContent = readZipEntry(zipBytes, ".xml");
|
||||
assertThat(xmlContent).contains("<Unicode>A & B < C > D</Unicode>");
|
||||
}
|
||||
|
||||
// ─── S3 failure resilience ────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void export_skipsDocumentWhenS3DownloadFails_andStillIncludesOtherDocuments() throws Exception {
|
||||
UUID failDocId = enrolledDoc("fail.pdf");
|
||||
UUID okDocId = enrolledDoc("ok.pdf");
|
||||
UUID failAnnotId = annotation(failDocId);
|
||||
UUID okAnnotId = annotation(okDocId);
|
||||
blockRepository.save(manualBlock(failDocId, failAnnotId, "Will fail"));
|
||||
blockRepository.save(manualBlock(okDocId, okAnnotId, "Will succeed"));
|
||||
|
||||
FileService fileService = mock(FileService.class);
|
||||
when(fileService.downloadFileBytes("fail.pdf")).thenThrow(new FileService.StorageFileNotFoundException("missing"));
|
||||
when(fileService.downloadFileBytes("ok.pdf")).thenReturn(minimalPdfBytes);
|
||||
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
byte[] zipBytes = stream(service.exportToZip());
|
||||
var names = zipEntryNames(zipBytes);
|
||||
// ok.pdf block produces 2 entries; fail.pdf block is skipped
|
||||
assertThat(names).hasSize(2);
|
||||
}
|
||||
|
||||
// ─── Empty export ─────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void queryEligibleBlocks_returnsEmpty_whenNoEnrolledDocuments() {
|
||||
FileService fileService = mockFileService();
|
||||
TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService);
|
||||
|
||||
assertThat(service.queryEligibleBlocks()).isEmpty();
|
||||
}
|
||||
|
||||
// ─── helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
private UUID enrolledDoc(String filename) {
|
||||
Document doc = documentRepository.save(Document.builder()
|
||||
.title(filename).originalFilename(filename).filePath(filename)
|
||||
.status(DocumentStatus.UPLOADED)
|
||||
.trainingLabels(new HashSet<>(Set.of(TrainingLabel.KURRENT_RECOGNITION)))
|
||||
.build());
|
||||
return doc.getId();
|
||||
}
|
||||
|
||||
private UUID nonEnrolledDoc(String filename) {
|
||||
Document doc = documentRepository.save(Document.builder()
|
||||
.title(filename).originalFilename(filename).filePath(filename)
|
||||
.status(DocumentStatus.UPLOADED)
|
||||
.build());
|
||||
return doc.getId();
|
||||
}
|
||||
|
||||
private UUID annotation(UUID docId) {
|
||||
return annotationRepository.save(DocumentAnnotation.builder()
|
||||
.documentId(docId).pageNumber(1)
|
||||
.x(0.1).y(0.1).width(0.8).height(0.1).color("#00C7B1")
|
||||
.build()).getId();
|
||||
}
|
||||
|
||||
private TranscriptionBlock manualBlock(UUID docId, UUID annotId, String text) {
|
||||
return TranscriptionBlock.builder()
|
||||
.annotationId(annotId).documentId(docId)
|
||||
.text(text).sortOrder(0)
|
||||
.source(BlockSource.MANUAL).reviewed(true).build();
|
||||
}
|
||||
|
||||
private FileService mockFileService() {
|
||||
FileService fs = mock(FileService.class);
|
||||
try {
|
||||
when(fs.downloadFileBytes(anyString())).thenReturn(minimalPdfBytes);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return fs;
|
||||
}
|
||||
|
||||
private static byte[] stream(StreamingResponseBody body) throws Exception {
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
body.writeTo(out);
|
||||
return out.toByteArray();
|
||||
}
|
||||
|
||||
private static java.util.List<String> zipEntryNames(byte[] zipBytes) throws Exception {
|
||||
var names = new java.util.ArrayList<String>();
|
||||
try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(zipBytes))) {
|
||||
ZipEntry entry;
|
||||
while ((entry = zis.getNextEntry()) != null) {
|
||||
names.add(entry.getName());
|
||||
zis.closeEntry();
|
||||
}
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
private static String readZipEntry(byte[] zipBytes, String suffix) throws Exception {
|
||||
try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(zipBytes))) {
|
||||
ZipEntry entry;
|
||||
while ((entry = zis.getNextEntry()) != null) {
|
||||
if (entry.getName().endsWith(suffix)) {
|
||||
return new String(zis.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8);
|
||||
}
|
||||
zis.closeEntry();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.model.BlockSource;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockVersionRepository;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class TranscriptionServiceGuidedTest {
|
||||
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
TranscriptionBlockVersionRepository versionRepository;
|
||||
AnnotationRepository annotationRepository;
|
||||
AnnotationService annotationService;
|
||||
DocumentService documentService;
|
||||
TranscriptionService service;
|
||||
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID annId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
versionRepository = mock(TranscriptionBlockVersionRepository.class);
|
||||
annotationRepository = mock(AnnotationRepository.class);
|
||||
annotationService = mock(AnnotationService.class);
|
||||
documentService = mock(DocumentService.class);
|
||||
|
||||
service = new TranscriptionService(blockRepository, versionRepository,
|
||||
annotationRepository, annotationService, documentService);
|
||||
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_createsNewBlock_whenAnnotationHasNoBlock() {
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.empty());
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "Hallo", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("Hallo");
|
||||
assertThat(result.getSource()).isEqualTo(BlockSource.OCR);
|
||||
assertThat(result.getAnnotationId()).isEqualTo(annId);
|
||||
verify(blockRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_updatesExistingOcrBlock_whenAnnotationHasOcrBlock() {
|
||||
TranscriptionBlock existing = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.annotationId(annId)
|
||||
.documentId(docId)
|
||||
.text("old text")
|
||||
.source(BlockSource.OCR)
|
||||
.sortOrder(0)
|
||||
.build();
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(existing));
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "new text", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("new text");
|
||||
verify(blockRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_doesNotOverwriteNonEmptyManualBlock() {
|
||||
TranscriptionBlock manual = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.annotationId(annId)
|
||||
.documentId(docId)
|
||||
.text("manually written")
|
||||
.source(BlockSource.MANUAL)
|
||||
.sortOrder(0)
|
||||
.build();
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(manual));
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "ocr result", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("manually written");
|
||||
verify(blockRepository, never()).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_fillsEmptyManualBlock_withOcrText() {
|
||||
TranscriptionBlock emptyManual = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.annotationId(annId)
|
||||
.documentId(docId)
|
||||
.text("")
|
||||
.source(BlockSource.MANUAL)
|
||||
.sortOrder(0)
|
||||
.build();
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(emptyManual));
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "ocr result", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("ocr result");
|
||||
verify(blockRepository).save(any());
|
||||
}
|
||||
}
|
||||
@@ -72,6 +72,9 @@ services:
|
||||
- archive-net
|
||||
|
||||
# --- OCR: Python microservice (Surya + Kraken) ---
|
||||
# Single-node only: OCR training reloads the model in-process after each run.
|
||||
# Running multiple replicas would cause training conflicts and model-state divergence.
|
||||
# See ADR-001 for the architectural rationale.
|
||||
ocr-service:
|
||||
build:
|
||||
context: ./ocr-service
|
||||
@@ -87,6 +90,7 @@ services:
|
||||
- ocr_cache:/root/.cache
|
||||
environment:
|
||||
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
|
||||
TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
|
||||
OCR_CONFIDENCE_THRESHOLD: "0.3"
|
||||
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
|
||||
RECOGNITION_BATCH_SIZE: "16"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{
|
||||
"$schema": "https://inlang.com/schema/inlang-message-format",
|
||||
"error_annotation_not_found": "Die Annotation wurde nicht gefunden.",
|
||||
"error_annotation_overlap": "Die Annotation überschneidet sich mit einer vorhandenen.",
|
||||
"annotation_outdated_notice": "Einige Annotationen beziehen sich auf eine frühere Dateiversion und werden nicht angezeigt.",
|
||||
"error_document_not_found": "Das Dokument wurde nicht gefunden.",
|
||||
"error_document_no_file": "Diesem Dokument ist noch keine Datei zugeordnet.",
|
||||
@@ -480,6 +479,7 @@
|
||||
"scan_collapse": "Scan verkleinern",
|
||||
"transcription_empty_title": "Noch keine Transkription",
|
||||
"transcription_empty_desc": "Zeichne Bereiche auf dem Scan und tippe den Text ab, um eine Transkription zu erstellen.",
|
||||
"transcription_empty_draw_hint": "Zeichnen Sie Bereiche auf dem Dokument, um mit der Transkription zu beginnen.",
|
||||
"transcription_panel_close": "Panel schließen",
|
||||
"person_alias_heading": "Namensverlauf",
|
||||
"person_alias_empty": "Noch keine Namensaenderungen erfasst.",
|
||||
@@ -505,6 +505,7 @@
|
||||
"error_ocr_job_not_found": "Der OCR-Auftrag wurde nicht gefunden.",
|
||||
"error_ocr_document_not_uploaded": "Das Dokument hat keine Datei — OCR ist nicht möglich.",
|
||||
"error_ocr_processing_failed": "Die OCR-Verarbeitung ist fehlgeschlagen.",
|
||||
"error_training_already_running": "Es läuft bereits ein Trainings-Vorgang.",
|
||||
"ocr_script_type_typewriter": "Schreibmaschine",
|
||||
"ocr_script_type_handwriting_latin": "Handschrift (lateinisch)",
|
||||
"ocr_script_type_handwriting_kurrent": "Handschrift (Kurrent/Sütterlin)",
|
||||
@@ -530,7 +531,35 @@
|
||||
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
||||
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
||||
"ocr_status_error": "OCR fehlgeschlagen",
|
||||
"ocr_trigger_no_annotations": "Zeichnen Sie zuerst Bereiche auf dem Dokument ein.",
|
||||
"ocr_section_heading": "OCR ausführen",
|
||||
"transcription_block_review": "Als geprüft markieren",
|
||||
"transcription_block_unreview": "Markierung aufheben",
|
||||
"transcription_reviewed_count": "{reviewed} von {total} geprüft"
|
||||
"transcription_reviewed_count": "{reviewed} von {total} geprüft",
|
||||
"training_ocr_heading": "Kurrent-Erkennung trainieren",
|
||||
"training_ocr_description": "Starte ein neues Training mit den bisher geprüften OCR-Blöcken, um die Erkennungsgenauigkeit für Kurrentschrift zu verbessern.",
|
||||
"training_ocr_blocks_ready": "{blocks} geprüfte Blöcke bereit / {docs} Dokumente",
|
||||
"training_ocr_blocks_total": "(von {total} OCR-Blöcken gesamt)",
|
||||
"training_start_btn": "Training starten",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Training wurde gestartet und abgeschlossen.",
|
||||
"training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).",
|
||||
"training_service_down": "OCR-Dienst ist nicht erreichbar.",
|
||||
"training_history_heading": "Verlauf",
|
||||
"training_history_empty": "Noch keine Trainings-Läufe.",
|
||||
"training_history_col_date": "Datum",
|
||||
"training_history_col_status": "Status",
|
||||
"training_history_col_blocks": "Blöcke",
|
||||
"training_history_col_docs": "Dokumente",
|
||||
"training_history_col_cer": "Fehlerrate",
|
||||
"training_status_done": "Fertig",
|
||||
"training_status_failed": "Fehler",
|
||||
"training_status_running": "Läuft…",
|
||||
"training_seg_heading": "Segmentierung trainieren",
|
||||
"training_seg_description": "Starte ein neues Training mit annotierten Segmentierungsbereichen, um die Texterkennung zu verbessern.",
|
||||
"training_seg_blocks_ready": "{blocks} Segmentierungsblöcke bereit",
|
||||
"training_seg_too_few_blocks": "Mindestens 5 Segmentierungsblöcke erforderlich (aktuell: {available}).",
|
||||
"transcription_block_segmentation_only": "Nur Segmentierung",
|
||||
"training_chip_kurrent": "Kurrent-Erkennung",
|
||||
"training_chip_segmentation": "Segmentierung"
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{
|
||||
"$schema": "https://inlang.com/schema/inlang-message-format",
|
||||
"error_annotation_not_found": "Annotation not found.",
|
||||
"error_annotation_overlap": "The annotation overlaps an existing one.",
|
||||
"annotation_outdated_notice": "Some annotations refer to an earlier file version and are not shown.",
|
||||
"error_document_not_found": "Document not found.",
|
||||
"error_document_no_file": "No file is associated with this document.",
|
||||
@@ -480,6 +479,7 @@
|
||||
"scan_collapse": "Collapse scan",
|
||||
"transcription_empty_title": "No transcription yet",
|
||||
"transcription_empty_desc": "Draw regions on the scan and type the text to create a transcription.",
|
||||
"transcription_empty_draw_hint": "Draw regions on the document to start transcribing.",
|
||||
"transcription_panel_close": "Close panel",
|
||||
"person_alias_heading": "Name history",
|
||||
"person_alias_empty": "No name changes recorded yet.",
|
||||
@@ -505,6 +505,7 @@
|
||||
"error_ocr_job_not_found": "The OCR job was not found.",
|
||||
"error_ocr_document_not_uploaded": "The document has no file — OCR is not possible.",
|
||||
"error_ocr_processing_failed": "OCR processing failed.",
|
||||
"error_training_already_running": "A training run is already in progress.",
|
||||
"ocr_script_type_typewriter": "Typewriter",
|
||||
"ocr_script_type_handwriting_latin": "Handwriting (Latin)",
|
||||
"ocr_script_type_handwriting_kurrent": "Handwriting (Kurrent/Sütterlin)",
|
||||
@@ -530,7 +531,35 @@
|
||||
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
||||
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
||||
"ocr_status_error": "OCR failed",
|
||||
"ocr_trigger_no_annotations": "Draw regions on the document first.",
|
||||
"ocr_section_heading": "Run OCR",
|
||||
"transcription_block_review": "Mark as reviewed",
|
||||
"transcription_block_unreview": "Unmark as reviewed",
|
||||
"transcription_reviewed_count": "{reviewed} of {total} reviewed"
|
||||
"transcription_reviewed_count": "{reviewed} of {total} reviewed",
|
||||
"training_ocr_heading": "Train Kurrent recognition",
|
||||
"training_ocr_description": "Start a new training run using the reviewed OCR blocks to improve recognition accuracy for Kurrent script.",
|
||||
"training_ocr_blocks_ready": "{blocks} reviewed blocks ready / {docs} documents",
|
||||
"training_ocr_blocks_total": "(of {total} OCR blocks total)",
|
||||
"training_start_btn": "Start training",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Training started and completed.",
|
||||
"training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).",
|
||||
"training_service_down": "OCR service is unavailable.",
|
||||
"training_history_heading": "History",
|
||||
"training_history_empty": "No training runs yet.",
|
||||
"training_history_col_date": "Date",
|
||||
"training_history_col_status": "Status",
|
||||
"training_history_col_blocks": "Blocks",
|
||||
"training_history_col_docs": "Documents",
|
||||
"training_history_col_cer": "Error Rate",
|
||||
"training_status_done": "Done",
|
||||
"training_status_failed": "Failed",
|
||||
"training_status_running": "Running…",
|
||||
"training_seg_heading": "Train segmentation",
|
||||
"training_seg_description": "Start a new training run using annotated segmentation regions to improve text detection.",
|
||||
"training_seg_blocks_ready": "{blocks} segmentation blocks ready",
|
||||
"training_seg_too_few_blocks": "At least 5 segmentation blocks required (currently: {available}).",
|
||||
"transcription_block_segmentation_only": "Segmentation only",
|
||||
"training_chip_kurrent": "Kurrent recognition",
|
||||
"training_chip_segmentation": "Segmentation"
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{
|
||||
"$schema": "https://inlang.com/schema/inlang-message-format",
|
||||
"error_annotation_not_found": "Anotación no encontrada.",
|
||||
"error_annotation_overlap": "La anotación se superpone con una existente.",
|
||||
"annotation_outdated_notice": "Algunas anotaciones hacen referencia a una versión anterior del archivo y no se muestran.",
|
||||
"error_document_not_found": "Documento no encontrado.",
|
||||
"error_document_no_file": "No hay ningún archivo asociado a este documento.",
|
||||
@@ -480,6 +479,7 @@
|
||||
"scan_collapse": "Reducir escaneo",
|
||||
"transcription_empty_title": "Sin transcripcion",
|
||||
"transcription_empty_desc": "Dibuja regiones en el escaneo y escribe el texto para crear una transcripcion.",
|
||||
"transcription_empty_draw_hint": "Dibuje regiones en el documento para comenzar a transcribir.",
|
||||
"transcription_panel_close": "Cerrar panel",
|
||||
"person_alias_heading": "Historial de nombres",
|
||||
"person_alias_empty": "Aun no se han registrado cambios de nombre.",
|
||||
@@ -505,6 +505,7 @@
|
||||
"error_ocr_job_not_found": "No se encontró el trabajo OCR.",
|
||||
"error_ocr_document_not_uploaded": "El documento no tiene archivo — OCR no es posible.",
|
||||
"error_ocr_processing_failed": "El procesamiento OCR ha fallado.",
|
||||
"error_training_already_running": "Ya hay un proceso de entrenamiento en curso.",
|
||||
"ocr_script_type_typewriter": "Máquina de escribir",
|
||||
"ocr_script_type_handwriting_latin": "Escritura manuscrita (latina)",
|
||||
"ocr_script_type_handwriting_kurrent": "Escritura manuscrita (Kurrent/Sütterlin)",
|
||||
@@ -530,7 +531,35 @@
|
||||
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
||||
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
||||
"ocr_status_error": "OCR fallido",
|
||||
"ocr_trigger_no_annotations": "Dibuje regiones en el documento primero.",
|
||||
"ocr_section_heading": "Ejecutar OCR",
|
||||
"transcription_block_review": "Marcar como revisado",
|
||||
"transcription_block_unreview": "Desmarcar como revisado",
|
||||
"transcription_reviewed_count": "{reviewed} de {total} revisados"
|
||||
"transcription_reviewed_count": "{reviewed} de {total} revisados",
|
||||
"training_ocr_heading": "Entrenar reconocimiento Kurrent",
|
||||
"training_ocr_description": "Inicia un nuevo entrenamiento con los bloques OCR revisados para mejorar la precisión de reconocimiento del script Kurrent.",
|
||||
"training_ocr_blocks_ready": "{blocks} bloques revisados listos / {docs} documentos",
|
||||
"training_ocr_blocks_total": "(de {total} bloques OCR en total)",
|
||||
"training_start_btn": "Iniciar entrenamiento",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Entrenamiento iniciado y completado.",
|
||||
"training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).",
|
||||
"training_service_down": "El servicio OCR no está disponible.",
|
||||
"training_history_heading": "Historial",
|
||||
"training_history_empty": "Todavía no hay ejecuciones de entrenamiento.",
|
||||
"training_history_col_date": "Fecha",
|
||||
"training_history_col_status": "Estado",
|
||||
"training_history_col_blocks": "Bloques",
|
||||
"training_history_col_docs": "Documentos",
|
||||
"training_history_col_cer": "Tasa de error",
|
||||
"training_status_done": "Listo",
|
||||
"training_status_failed": "Error",
|
||||
"training_status_running": "Ejecutando…",
|
||||
"training_seg_heading": "Entrenar segmentación",
|
||||
"training_seg_description": "Inicia un nuevo entrenamiento con regiones de segmentación anotadas para mejorar la detección de texto.",
|
||||
"training_seg_blocks_ready": "{blocks} bloques de segmentación listos",
|
||||
"training_seg_too_few_blocks": "Se requieren al menos 5 bloques de segmentación (actualmente: {available}).",
|
||||
"transcription_block_segmentation_only": "Solo segmentación",
|
||||
"training_chip_kurrent": "Reconocimiento Kurrent",
|
||||
"training_chip_segmentation": "Segmentación"
|
||||
}
|
||||
|
||||
@@ -72,7 +72,9 @@ $effect(() => {
|
||||
></div>
|
||||
</div>
|
||||
<p class="mt-2 text-right text-sm text-gray-500">
|
||||
{m.ocr_progress_page({ current: String(currentPage), total: String(totalPages) })}
|
||||
<span aria-live="polite" aria-atomic="true">
|
||||
{m.ocr_progress_page({ current: String(currentPage), total: String(totalPages) })}
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
{:else if status === 'error'}
|
||||
|
||||
92
frontend/src/lib/components/OcrTrainingCard.svelte
Normal file
92
frontend/src/lib/components/OcrTrainingCard.svelte
Normal file
@@ -0,0 +1,92 @@
|
||||
<script lang="ts">
|
||||
import TrainingHistory from './TrainingHistory.svelte';
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
interface Run {
|
||||
id: string;
|
||||
status: 'RUNNING' | 'DONE' | 'FAILED';
|
||||
blockCount: number;
|
||||
documentCount: number;
|
||||
modelName: string;
|
||||
errorMessage?: string;
|
||||
triggeredBy?: string;
|
||||
createdAt: string;
|
||||
completedAt?: string;
|
||||
}
|
||||
|
||||
interface TrainingInfo {
|
||||
availableBlocks?: number;
|
||||
totalOcrBlocks?: number;
|
||||
availableDocuments?: number;
|
||||
ocrServiceAvailable?: boolean;
|
||||
lastRun?: Run | null;
|
||||
runs?: Run[];
|
||||
}
|
||||
|
||||
interface Props {
|
||||
trainingInfo: TrainingInfo | null;
|
||||
}
|
||||
|
||||
let { trainingInfo }: Props = $props();
|
||||
|
||||
let training = $state(false);
|
||||
let successMessage = $state<string | null>(null);
|
||||
|
||||
const available = $derived(trainingInfo?.availableBlocks ?? 0);
|
||||
const tooFewBlocks = $derived(available < 5);
|
||||
const serviceDown = $derived(trainingInfo?.ocrServiceAvailable === false);
|
||||
const disabled = $derived(training || tooFewBlocks || serviceDown);
|
||||
|
||||
async function startTraining() {
|
||||
training = true;
|
||||
successMessage = null;
|
||||
try {
|
||||
const res = await fetch('/api/ocr/train', { method: 'POST' });
|
||||
if (res.ok) {
|
||||
successMessage = m.training_success();
|
||||
setTimeout(() => {
|
||||
successMessage = null;
|
||||
}, 5000);
|
||||
}
|
||||
} finally {
|
||||
training = false;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="rounded-sm border border-line bg-surface p-6 shadow-sm">
|
||||
<h2 class="mb-1 font-sans text-sm font-bold text-ink">{m.training_ocr_heading()}</h2>
|
||||
<p class="mb-4 text-sm text-ink-2">{m.training_ocr_description()}</p>
|
||||
|
||||
<p class="mb-3 text-sm text-ink">
|
||||
{m.training_ocr_blocks_ready({ blocks: available, docs: trainingInfo?.availableDocuments ?? 0 })}
|
||||
<span class="text-ink-2"
|
||||
>{m.training_ocr_blocks_total({ total: trainingInfo?.totalOcrBlocks ?? 0 })}</span
|
||||
>
|
||||
</p>
|
||||
|
||||
<button
|
||||
onclick={startTraining}
|
||||
disabled={disabled}
|
||||
class="rounded-sm bg-primary px-5 py-2 font-sans text-xs font-bold tracking-widest text-primary-fg uppercase transition-opacity hover:opacity-80 focus-visible:ring-2 focus-visible:ring-brand-navy disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
{training ? '…' : m.training_start_btn()}
|
||||
</button>
|
||||
|
||||
{#if tooFewBlocks}
|
||||
<p class="mt-2 text-xs text-ink-3">
|
||||
{m.training_too_few_blocks({ available })}
|
||||
</p>
|
||||
{:else if serviceDown}
|
||||
<p class="mt-2 text-xs text-orange-600">{m.training_service_down()}</p>
|
||||
{/if}
|
||||
|
||||
{#if successMessage}
|
||||
<p class="mt-2 text-xs text-green-700">{successMessage}</p>
|
||||
{/if}
|
||||
|
||||
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||
{m.training_history_heading()}
|
||||
</h3>
|
||||
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
|
||||
</div>
|
||||
96
frontend/src/lib/components/OcrTrainingCard.svelte.spec.ts
Normal file
96
frontend/src/lib/components/OcrTrainingCard.svelte.spec.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { cleanup, render } from 'vitest-browser-svelte';
|
||||
import { page } from 'vitest/browser';
|
||||
import OcrTrainingCard from './OcrTrainingCard.svelte';
|
||||
|
||||
afterEach(cleanup);
|
||||
afterEach(() => vi.restoreAllMocks());
|
||||
|
||||
const baseInfo = {
|
||||
availableBlocks: 10,
|
||||
totalOcrBlocks: 20,
|
||||
availableDocuments: 3,
|
||||
ocrServiceAvailable: true,
|
||||
lastRun: null,
|
||||
runs: []
|
||||
};
|
||||
|
||||
describe('OcrTrainingCard — disabled states', () => {
|
||||
it('disables button and shows hint when availableBlocks is 0', async () => {
|
||||
render(OcrTrainingCard, { trainingInfo: { ...baseInfo, availableBlocks: 0 } });
|
||||
|
||||
const btn = page.getByRole('button', { name: /Training starten/i });
|
||||
await expect.element(btn).toBeDisabled();
|
||||
await expect
|
||||
.element(page.getByText(/Mindestens 5 geprüfte Blöcke erforderlich/i))
|
||||
.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('disables button and shows hint when availableBlocks is less than 5', async () => {
|
||||
render(OcrTrainingCard, { trainingInfo: { ...baseInfo, availableBlocks: 3 } });
|
||||
|
||||
const btn = page.getByRole('button', { name: /Training starten/i });
|
||||
await expect.element(btn).toBeDisabled();
|
||||
await expect.element(page.getByText(/Mindestens 5/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('disables button and shows service-down warning when ocrServiceAvailable is false', async () => {
|
||||
render(OcrTrainingCard, { trainingInfo: { ...baseInfo, ocrServiceAvailable: false } });
|
||||
|
||||
const btn = page.getByRole('button', { name: /Training starten/i });
|
||||
await expect.element(btn).toBeDisabled();
|
||||
await expect.element(page.getByText(/OCR-Dienst ist nicht erreichbar/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not show service-down warning when blocks are insufficient', async () => {
|
||||
// tooFewBlocks hint takes priority over serviceDown hint
|
||||
render(OcrTrainingCard, {
|
||||
trainingInfo: { ...baseInfo, availableBlocks: 2, ocrServiceAvailable: false }
|
||||
});
|
||||
|
||||
await expect.element(page.getByText(/Mindestens 5/i)).toBeInTheDocument();
|
||||
// serviceDown text should NOT appear because tooFewBlocks branch hides it
|
||||
const serviceMsg = document.querySelector('.text-orange-600');
|
||||
expect(serviceMsg).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('OcrTrainingCard — enabled state', () => {
|
||||
it('enables button when availableBlocks >= 5 and service is up', async () => {
|
||||
render(OcrTrainingCard, { trainingInfo: baseInfo });
|
||||
|
||||
const btn = page.getByRole('button', { name: /Training starten/i });
|
||||
await expect.element(btn).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it('shows block count info text', async () => {
|
||||
render(OcrTrainingCard, {
|
||||
trainingInfo: { ...baseInfo, availableBlocks: 7, totalOcrBlocks: 15 }
|
||||
});
|
||||
|
||||
await expect.element(page.getByText(/7/)).toBeInTheDocument();
|
||||
await expect.element(page.getByText(/von 15 OCR-Blöcken/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('OcrTrainingCard — in-flight state', () => {
|
||||
it('shows "…" while POST is in-flight', async () => {
|
||||
let resolveFetch!: (v: unknown) => void;
|
||||
const pendingFetch = new Promise((resolve) => {
|
||||
resolveFetch = resolve;
|
||||
});
|
||||
|
||||
vi.stubGlobal('fetch', vi.fn().mockReturnValue(pendingFetch));
|
||||
|
||||
render(OcrTrainingCard, { trainingInfo: baseInfo });
|
||||
|
||||
const btn = page.getByRole('button', { name: /Training starten/i });
|
||||
await btn.click();
|
||||
|
||||
// While fetch is still pending the button label becomes "…"
|
||||
await expect.element(page.getByRole('button', { name: '…' })).toBeInTheDocument();
|
||||
|
||||
// Cleanup: resolve the pending promise
|
||||
resolveFetch({ ok: false });
|
||||
});
|
||||
});
|
||||
@@ -1,37 +1,23 @@
|
||||
<script lang="ts">
|
||||
import { untrack } from 'svelte';
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
import { getConfirmService } from '$lib/services/confirm.svelte';
|
||||
import ScriptTypeSelect from './ScriptTypeSelect.svelte';
|
||||
|
||||
interface Props {
|
||||
existingBlockCount: number;
|
||||
blockCount: number;
|
||||
storedScriptType: string;
|
||||
onTrigger: (scriptType: string) => void;
|
||||
onTrigger: (scriptType: string, useExistingAnnotations: boolean) => void;
|
||||
}
|
||||
|
||||
let { existingBlockCount, storedScriptType, onTrigger }: Props = $props();
|
||||
|
||||
const { confirm } = getConfirmService();
|
||||
let { blockCount, storedScriptType, onTrigger }: Props = $props();
|
||||
|
||||
import { untrack } from 'svelte';
|
||||
let selectedScriptType: string = $state(
|
||||
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
|
||||
);
|
||||
|
||||
async function handleClick() {
|
||||
function handleClick() {
|
||||
if (!selectedScriptType) return;
|
||||
|
||||
if (existingBlockCount > 0) {
|
||||
const confirmed = await confirm({
|
||||
title: m.ocr_confirm_title(),
|
||||
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
|
||||
confirmLabel: m.ocr_confirm_btn(),
|
||||
destructive: true
|
||||
});
|
||||
if (!confirmed) return;
|
||||
}
|
||||
|
||||
onTrigger(selectedScriptType);
|
||||
onTrigger(selectedScriptType, true);
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -39,11 +25,14 @@ async function handleClick() {
|
||||
<ScriptTypeSelect bind:value={selectedScriptType} />
|
||||
<button
|
||||
type="button"
|
||||
disabled={!selectedScriptType}
|
||||
disabled={!selectedScriptType || blockCount === 0}
|
||||
title={!selectedScriptType ? m.ocr_trigger_btn_disabled() : undefined}
|
||||
onclick={handleClick}
|
||||
class="min-h-[44px] w-full rounded-sm bg-brand-navy font-sans text-sm font-medium text-white transition-colors hover:bg-brand-navy/90 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
{m.ocr_trigger_btn()}
|
||||
</button>
|
||||
{#if blockCount === 0}
|
||||
<p class="text-xs text-ink-3">{m.ocr_trigger_no_annotations()}</p>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
86
frontend/src/lib/components/SegmentationTrainingCard.svelte
Normal file
86
frontend/src/lib/components/SegmentationTrainingCard.svelte
Normal file
@@ -0,0 +1,86 @@
|
||||
<script lang="ts">
|
||||
import TrainingHistory from './TrainingHistory.svelte';
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
interface Run {
|
||||
id: string;
|
||||
status: 'RUNNING' | 'DONE' | 'FAILED';
|
||||
blockCount: number;
|
||||
documentCount: number;
|
||||
modelName: string;
|
||||
errorMessage?: string;
|
||||
triggeredBy?: string;
|
||||
createdAt: string;
|
||||
completedAt?: string;
|
||||
}
|
||||
|
||||
interface TrainingInfo {
|
||||
availableSegBlocks?: number;
|
||||
ocrServiceAvailable?: boolean;
|
||||
runs?: Run[];
|
||||
}
|
||||
|
||||
interface Props {
|
||||
trainingInfo: TrainingInfo | null;
|
||||
}
|
||||
|
||||
let { trainingInfo }: Props = $props();
|
||||
|
||||
let training = $state(false);
|
||||
let successMessage = $state<string | null>(null);
|
||||
|
||||
const available = $derived(trainingInfo?.availableSegBlocks ?? 0);
|
||||
const tooFewBlocks = $derived(available < 5);
|
||||
const serviceDown = $derived(trainingInfo?.ocrServiceAvailable === false);
|
||||
const disabled = $derived(training || tooFewBlocks || serviceDown);
|
||||
|
||||
async function startTraining() {
|
||||
training = true;
|
||||
successMessage = null;
|
||||
try {
|
||||
const res = await fetch('/api/ocr/segtrain', { method: 'POST' });
|
||||
if (res.ok) {
|
||||
successMessage = m.training_success();
|
||||
setTimeout(() => {
|
||||
successMessage = null;
|
||||
}, 5000);
|
||||
}
|
||||
} finally {
|
||||
training = false;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="rounded-sm border border-line bg-surface p-6 shadow-sm">
|
||||
<h2 class="mb-1 font-sans text-sm font-bold text-ink">{m.training_seg_heading()}</h2>
|
||||
<p class="mb-4 text-sm text-ink-2">{m.training_seg_description()}</p>
|
||||
|
||||
<p class="mb-3 text-sm text-ink">
|
||||
{m.training_seg_blocks_ready({ blocks: available })}
|
||||
</p>
|
||||
|
||||
<button
|
||||
onclick={startTraining}
|
||||
disabled={disabled}
|
||||
class="rounded-sm bg-primary px-5 py-2 font-sans text-xs font-bold tracking-widest text-primary-fg uppercase transition-opacity hover:opacity-80 focus-visible:ring-2 focus-visible:ring-brand-navy disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
{training ? '…' : m.training_start_btn()}
|
||||
</button>
|
||||
|
||||
{#if tooFewBlocks}
|
||||
<p class="mt-2 text-xs text-ink-3">
|
||||
{m.training_seg_too_few_blocks({ available })}
|
||||
</p>
|
||||
{:else if serviceDown}
|
||||
<p class="mt-2 text-xs text-orange-600">{m.training_service_down()}</p>
|
||||
{/if}
|
||||
|
||||
{#if successMessage}
|
||||
<p class="mt-2 text-xs text-green-700">{successMessage}</p>
|
||||
{/if}
|
||||
|
||||
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||
{m.training_history_heading()}
|
||||
</h3>
|
||||
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName === 'blla')} />
|
||||
</div>
|
||||
119
frontend/src/lib/components/TrainingHistory.svelte
Normal file
119
frontend/src/lib/components/TrainingHistory.svelte
Normal file
@@ -0,0 +1,119 @@
|
||||
<script lang="ts">
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
interface Run {
|
||||
id: string;
|
||||
status: 'RUNNING' | 'DONE' | 'FAILED';
|
||||
blockCount: number;
|
||||
documentCount: number;
|
||||
modelName: string;
|
||||
errorMessage?: string;
|
||||
triggeredBy?: string;
|
||||
createdAt: string;
|
||||
completedAt?: string;
|
||||
cer?: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
runs: Run[];
|
||||
}
|
||||
|
||||
let { runs }: Props = $props();
|
||||
|
||||
const dateFormatter = new Intl.DateTimeFormat('de-DE', {
|
||||
day: 'numeric',
|
||||
month: 'short',
|
||||
year: 'numeric'
|
||||
});
|
||||
|
||||
function formatDate(iso: string): string {
|
||||
return dateFormatter.format(new Date(iso));
|
||||
}
|
||||
|
||||
function formatCer(cer: number | undefined | null): string {
|
||||
if (cer == null) return '—';
|
||||
return (cer * 100).toFixed(1) + ' %';
|
||||
}
|
||||
</script>
|
||||
|
||||
<table class="w-full text-sm">
|
||||
<thead>
|
||||
<tr class="border-b border-line text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||
<th class="pb-2 text-left">{m.training_history_col_date()}</th>
|
||||
<th class="pb-2 text-left">{m.training_history_col_status()}</th>
|
||||
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
|
||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
|
||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#if runs.length === 0}
|
||||
<tr>
|
||||
<td colspan="5" class="py-4 text-center text-sm text-ink-2">
|
||||
{m.training_history_empty()}
|
||||
</td>
|
||||
</tr>
|
||||
{:else}
|
||||
{#each runs as run (run.id)}
|
||||
<tr class="border-b border-line/50 last:border-0">
|
||||
<td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td>
|
||||
<td class="py-2">
|
||||
{#if run.status === 'DONE'}
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-green-100 px-1.5 py-0.5 text-xs font-medium text-green-700"
|
||||
>
|
||||
<svg
|
||||
aria-hidden="true"
|
||||
class="h-3 w-3 shrink-0"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{m.training_status_done()}
|
||||
</span>
|
||||
{:else if run.status === 'FAILED'}
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-red-100 px-1.5 py-0.5 text-xs font-medium text-red-700"
|
||||
title={run.errorMessage}
|
||||
>
|
||||
<svg
|
||||
aria-hidden="true"
|
||||
class="h-3 w-3 shrink-0"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M4.293 4.293a1 1 0 011.414 0L10 8.586l4.293-4.293a1 1 0 111.414 1.414L11.414 10l4.293 4.293a1 1 0 01-1.414 1.414L10 11.414l-4.293 4.293a1 1 0 01-1.414-1.414L8.586 10 4.293 5.707a1 1 0 010-1.414z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{m.training_status_failed()}
|
||||
</span>
|
||||
{:else}
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-yellow-100 px-1.5 py-0.5 text-xs font-medium text-yellow-700"
|
||||
>
|
||||
<span
|
||||
aria-hidden="true"
|
||||
class="h-1.5 w-1.5 animate-pulse rounded-full bg-yellow-500"
|
||||
></span>
|
||||
{m.training_status_running()}
|
||||
</span>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="py-2 text-right text-ink-2">{run.blockCount}</td>
|
||||
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
|
||||
<td class="hidden py-2 text-right md:table-cell"
|
||||
>{run.status === 'DONE' && run.cer != null ? formatCer(run.cer) : '—'}</td
|
||||
>
|
||||
</tr>
|
||||
{/each}
|
||||
{/if}
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -27,6 +27,7 @@ type Props = {
|
||||
onMoveDown?: () => void;
|
||||
isFirst?: boolean;
|
||||
isLast?: boolean;
|
||||
source?: 'MANUAL' | 'OCR';
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -48,7 +49,8 @@ let {
|
||||
onMoveUp,
|
||||
onMoveDown,
|
||||
isFirst = false,
|
||||
isLast = false
|
||||
isLast = false,
|
||||
source = 'MANUAL'
|
||||
}: Props = $props();
|
||||
|
||||
let localText = $state(text);
|
||||
@@ -172,6 +174,11 @@ function handleTextareaMouseUp() {
|
||||
{label}
|
||||
</span>
|
||||
{/if}
|
||||
{#if (!text || text.trim() === '') && source === 'MANUAL'}
|
||||
<span class="rounded bg-muted px-1.5 py-0.5 text-xs font-medium text-ink-3"
|
||||
>{m.transcription_block_segmentation_only()}</span
|
||||
>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Textarea -->
|
||||
|
||||
@@ -19,7 +19,10 @@ type Props = {
|
||||
onSaveBlock: (blockId: string, text: string) => Promise<void>;
|
||||
onDeleteBlock: (blockId: string) => Promise<void>;
|
||||
onReviewToggle: (blockId: string) => Promise<void>;
|
||||
onTriggerOcr?: (scriptType: string) => void;
|
||||
onTriggerOcr?: (scriptType: string, useExistingAnnotations: boolean) => void;
|
||||
canWrite?: boolean;
|
||||
trainingLabels?: string[];
|
||||
onToggleTrainingLabel?: (label: string, enrolled: boolean) => Promise<void>;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -34,10 +37,14 @@ let {
|
||||
onSaveBlock,
|
||||
onDeleteBlock,
|
||||
onReviewToggle,
|
||||
onTriggerOcr
|
||||
onTriggerOcr,
|
||||
canWrite = false,
|
||||
trainingLabels = [],
|
||||
onToggleTrainingLabel
|
||||
}: Props = $props();
|
||||
|
||||
let activeBlockId: string | null = $state(null);
|
||||
let localLabels: string[] = $derived.by(() => [...trainingLabels]);
|
||||
|
||||
// Sync: when an annotation is clicked on the PDF, activate the corresponding block
|
||||
$effect(() => {
|
||||
@@ -50,6 +57,9 @@ let debounceTimers = new SvelteMap<string, ReturnType<typeof setTimeout>>();
|
||||
let pendingTexts = new SvelteMap<string, string>();
|
||||
let sortedBlocks = $derived([...blocks].sort((a, b) => a.sortOrder - b.sortOrder));
|
||||
let hasBlocks = $derived(blocks.length > 0);
|
||||
let reviewedCount = $derived(blocks.filter((b) => b.reviewed).length);
|
||||
let totalCount = $derived(blocks.length);
|
||||
let reviewProgress = $derived(totalCount > 0 ? (reviewedCount / totalCount) * 100 : 0);
|
||||
|
||||
function getSaveState(blockId: string): SaveState {
|
||||
return saveStates.get(blockId) ?? 'idle';
|
||||
@@ -185,7 +195,7 @@ let dropTargetIdx: number | null = $state(null);
|
||||
let dragOffsetY: number = $state(0);
|
||||
let dragStartY = 0;
|
||||
let capturedEl: HTMLElement | null = null;
|
||||
let listEl: HTMLElement | null = null;
|
||||
let listEl: HTMLElement | null = $state(null);
|
||||
|
||||
function handleGripDown(e: PointerEvent, blockId: string) {
|
||||
if (!(e.target as HTMLElement).closest('[data-drag-handle]')) return;
|
||||
@@ -237,6 +247,23 @@ function handlePointerUp() {
|
||||
capturedEl = null;
|
||||
}
|
||||
|
||||
async function handleLabelToggle(label: string) {
|
||||
if (!onToggleTrainingLabel) return;
|
||||
const enrolled = !localLabels.includes(label);
|
||||
// Optimistic update
|
||||
if (enrolled) {
|
||||
localLabels = [...localLabels, label];
|
||||
} else {
|
||||
localLabels = localLabels.filter((l) => l !== label);
|
||||
}
|
||||
try {
|
||||
await onToggleTrainingLabel(label, enrolled);
|
||||
} catch {
|
||||
// Revert on failure
|
||||
localLabels = [...trainingLabels];
|
||||
}
|
||||
}
|
||||
|
||||
function flushViaBeacon() {
|
||||
for (const [blockId, text] of pendingTexts) {
|
||||
clearDebounce(blockId);
|
||||
@@ -263,78 +290,91 @@ $effect(() => {
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex h-full flex-col overflow-y-auto bg-surface p-4">
|
||||
<div class="flex h-full flex-col overflow-y-auto bg-surface">
|
||||
{#if hasBlocks}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
class="flex flex-col gap-3"
|
||||
bind:this={listEl}
|
||||
onpointermove={handlePointerMove}
|
||||
onpointerup={handlePointerUp}
|
||||
>
|
||||
{#each sortedBlocks as block, i (block.id)}
|
||||
{#if dropTargetIdx === i}
|
||||
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
|
||||
{/if}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<!-- Sticky review progress header -->
|
||||
<div class="sticky top-0 z-10 border-b border-line bg-surface px-4 pt-3 pb-2">
|
||||
<p class="font-sans text-xs text-ink-2">
|
||||
<span class="font-semibold text-ink">{reviewedCount} / {totalCount}</span> geprüft
|
||||
</p>
|
||||
<div class="bg-brand-sand mt-1.5 h-0.5 w-full overflow-hidden rounded-full">
|
||||
<div
|
||||
data-block-wrapper
|
||||
onblur={handleBlur}
|
||||
onpointerdown={(e) => handleGripDown(e, block.id)}
|
||||
class="relative transition-all duration-150 {draggedBlockId === block.id ? 'z-10 rounded-lg shadow-xl ring-2 ring-turquoise/40' : ''}"
|
||||
style={draggedBlockId === block.id ? `transform: translateY(${dragOffsetY}px) scale(1.02); opacity: 0.9;` : ''}
|
||||
>
|
||||
<TranscriptionBlock
|
||||
blockId={block.id}
|
||||
documentId={documentId}
|
||||
blockNumber={i + 1}
|
||||
text={block.text}
|
||||
label={block.label}
|
||||
active={activeBlockId === block.id}
|
||||
reviewed={block.reviewed ?? false}
|
||||
saveState={getSaveState(block.id)}
|
||||
canComment={canComment}
|
||||
currentUserId={currentUserId}
|
||||
onTextChange={(text) => handleTextChange(block.id, text)}
|
||||
onFocus={() => handleFocus(block.id)}
|
||||
onDeleteClick={() => handleDelete(block.id)}
|
||||
onRetry={() => handleRetry(block.id)}
|
||||
onReviewToggle={() => onReviewToggle(block.id)}
|
||||
onMoveUp={() => handleMoveUp(block.id)}
|
||||
onMoveDown={() => handleMoveDown(block.id)}
|
||||
isFirst={i === 0}
|
||||
isLast={i === sortedBlocks.length - 1}
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
|
||||
{#if dropTargetIdx === sortedBlocks.length}
|
||||
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
|
||||
{/if}
|
||||
|
||||
<!-- Next block CTA — dashed outline hint -->
|
||||
<div
|
||||
class="flex items-center justify-center rounded border border-dashed border-line px-4 py-5 text-center font-sans text-sm text-ink-3"
|
||||
>
|
||||
{m.transcription_next_block_cta({ number: sortedBlocks.length + 1 })}
|
||||
class="h-full rounded-full bg-brand-mint transition-all duration-300"
|
||||
style="width: {reviewProgress}%"
|
||||
></div>
|
||||
</div>
|
||||
|
||||
{#if canRunOcr && onTriggerOcr}
|
||||
<details class="mt-6">
|
||||
<summary
|
||||
class="cursor-pointer font-sans text-xs font-medium text-ink-3 transition-colors hover:text-brand-navy"
|
||||
</div>
|
||||
<div class="p-4">
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
class="flex flex-col gap-3"
|
||||
bind:this={listEl}
|
||||
onpointermove={handlePointerMove}
|
||||
onpointerup={handlePointerUp}
|
||||
>
|
||||
{#each sortedBlocks as block, i (block.id)}
|
||||
{#if dropTargetIdx === i}
|
||||
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
|
||||
{/if}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
data-block-wrapper
|
||||
onblur={handleBlur}
|
||||
onpointerdown={(e) => handleGripDown(e, block.id)}
|
||||
class="relative transition-all duration-150 {draggedBlockId === block.id ? 'z-10 rounded-lg shadow-xl ring-2 ring-turquoise/40' : ''}"
|
||||
style={draggedBlockId === block.id ? `transform: translateY(${dragOffsetY}px) scale(1.02); opacity: 0.9;` : ''}
|
||||
>
|
||||
{m.ocr_rerun_label()}
|
||||
</summary>
|
||||
<div class="mt-3 max-w-xs">
|
||||
<OcrTrigger
|
||||
existingBlockCount={blocks.length}
|
||||
storedScriptType={storedScriptType}
|
||||
onTrigger={onTriggerOcr}
|
||||
<TranscriptionBlock
|
||||
blockId={block.id}
|
||||
documentId={documentId}
|
||||
blockNumber={i + 1}
|
||||
text={block.text}
|
||||
label={block.label}
|
||||
active={activeBlockId === block.id}
|
||||
reviewed={block.reviewed ?? false}
|
||||
saveState={getSaveState(block.id)}
|
||||
canComment={canComment}
|
||||
currentUserId={currentUserId}
|
||||
onTextChange={(text) => handleTextChange(block.id, text)}
|
||||
onFocus={() => handleFocus(block.id)}
|
||||
onDeleteClick={() => handleDelete(block.id)}
|
||||
onRetry={() => handleRetry(block.id)}
|
||||
onReviewToggle={() => onReviewToggle(block.id)}
|
||||
onMoveUp={() => handleMoveUp(block.id)}
|
||||
onMoveDown={() => handleMoveDown(block.id)}
|
||||
isFirst={i === 0}
|
||||
isLast={i === sortedBlocks.length - 1}
|
||||
source={block.source}
|
||||
/>
|
||||
</div>
|
||||
</details>
|
||||
{/if}
|
||||
{/each}
|
||||
|
||||
{#if dropTargetIdx === sortedBlocks.length}
|
||||
<div class="h-1 rounded-full bg-turquoise transition-all"></div>
|
||||
{/if}
|
||||
|
||||
<!-- Next block CTA — dashed outline hint -->
|
||||
<div
|
||||
class="flex items-center justify-center rounded border border-dashed border-line px-4 py-5 text-center font-sans text-sm text-ink-3"
|
||||
>
|
||||
{m.transcription_next_block_cta({ number: sortedBlocks.length + 1 })}
|
||||
</div>
|
||||
|
||||
{#if canRunOcr && onTriggerOcr}
|
||||
<div class="mt-6">
|
||||
<p class="mb-3 font-sans text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||
{m.ocr_section_heading()}
|
||||
</p>
|
||||
<div class="max-w-xs">
|
||||
<OcrTrigger
|
||||
blockCount={blocks.length}
|
||||
storedScriptType={storedScriptType}
|
||||
onTrigger={onTriggerOcr}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex flex-1 flex-col items-center justify-center px-6 py-12 text-center">
|
||||
@@ -352,25 +392,28 @@ $effect(() => {
|
||||
/>
|
||||
</svg>
|
||||
|
||||
{#if canRunOcr && onTriggerOcr}
|
||||
<p class="mb-6 max-w-xs text-sm leading-relaxed text-ink-3">
|
||||
{m.transcription_empty_title()}
|
||||
</p>
|
||||
<div class="w-full max-w-xs">
|
||||
<OcrTrigger
|
||||
existingBlockCount={0}
|
||||
storedScriptType={storedScriptType}
|
||||
onTrigger={onTriggerOcr}
|
||||
/>
|
||||
</div>
|
||||
<p class="mt-4 text-xs text-ink-3">
|
||||
{m.transcription_empty_desc()}
|
||||
</p>
|
||||
{:else}
|
||||
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
|
||||
{m.transcription_empty_cta()}
|
||||
</p>
|
||||
{/if}
|
||||
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
|
||||
{m.transcription_empty_draw_hint()}
|
||||
</p>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if canWrite}
|
||||
<div class="border-t border-line px-4 py-3">
|
||||
<p class="mb-2 font-sans text-xs font-medium text-ink-2">Für Training vormerken</p>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
{#each [{ label: 'KURRENT_RECOGNITION', display: m.training_chip_kurrent() }, { label: 'KURRENT_SEGMENTATION', display: m.training_chip_segmentation() }] as chip (chip.label)}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => handleLabelToggle(chip.label)}
|
||||
class="rounded-full border px-3 py-1 font-sans text-xs font-medium transition-colors {localLabels.includes(chip.label)
|
||||
? 'border-brand-mint bg-brand-mint text-brand-navy'
|
||||
: 'border-line bg-surface text-ink-3 hover:border-brand-mint hover:text-brand-navy'}"
|
||||
>
|
||||
{chip.display}
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -13,7 +13,9 @@ const block1 = {
|
||||
text: 'Block eins',
|
||||
label: null,
|
||||
sortOrder: 0,
|
||||
version: 0
|
||||
version: 0,
|
||||
source: 'MANUAL' as const,
|
||||
reviewed: false
|
||||
};
|
||||
const block2 = {
|
||||
id: 'b2',
|
||||
@@ -22,7 +24,9 @@ const block2 = {
|
||||
text: 'Block zwei',
|
||||
label: null,
|
||||
sortOrder: 1,
|
||||
version: 0
|
||||
version: 0,
|
||||
source: 'OCR' as const,
|
||||
reviewed: true
|
||||
};
|
||||
|
||||
function renderView(overrides: Record<string, unknown> = {}, service = createConfirmService()) {
|
||||
@@ -36,6 +40,7 @@ function renderView(overrides: Record<string, unknown> = {}, service = createCon
|
||||
onBlockFocus: vi.fn(),
|
||||
onSaveBlock: vi.fn(),
|
||||
onDeleteBlock: vi.fn(),
|
||||
onReviewToggle: vi.fn(),
|
||||
...overrides
|
||||
},
|
||||
context: new Map([[CONFIRM_KEY, service]])
|
||||
@@ -58,7 +63,7 @@ describe('TranscriptionEditView — rendering', () => {
|
||||
|
||||
it('shows empty state when no blocks', async () => {
|
||||
renderView({ blocks: [] });
|
||||
await expect.element(page.getByText(/Markiere einen Bereich/)).toBeInTheDocument();
|
||||
await expect.element(page.getByText(/Zeichnen Sie Bereiche/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -232,3 +237,23 @@ describe('TranscriptionEditView — delete block', () => {
|
||||
expect(onDeleteBlock).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Review progress counter ──────────────────────────────────────────────────
|
||||
|
||||
describe('TranscriptionEditView — review progress counter', () => {
|
||||
it('shows reviewed count and total when blocks exist', async () => {
|
||||
// block1: reviewed=false, block2: reviewed=true → "1 / 2 geprüft"
|
||||
renderView();
|
||||
await expect.element(page.getByText(/1 \/ 2 geprüft/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows 0 reviewed when no blocks are reviewed', async () => {
|
||||
renderView({ blocks: [block1] }); // block1.reviewed = false
|
||||
await expect.element(page.getByText(/0 \/ 1 geprüft/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not show progress counter when there are no blocks', async () => {
|
||||
renderView({ blocks: [] });
|
||||
await expect.element(page.getByText(/geprüft/)).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -18,7 +18,6 @@ export type ErrorCode =
|
||||
| 'IMPORT_ALREADY_RUNNING'
|
||||
| 'INVALID_RESET_TOKEN'
|
||||
| 'ANNOTATION_NOT_FOUND'
|
||||
| 'ANNOTATION_OVERLAP'
|
||||
| 'TRANSCRIPTION_BLOCK_NOT_FOUND'
|
||||
| 'TRANSCRIPTION_BLOCK_CONFLICT'
|
||||
| 'COMMENT_NOT_FOUND'
|
||||
@@ -26,6 +25,7 @@ export type ErrorCode =
|
||||
| 'OCR_JOB_NOT_FOUND'
|
||||
| 'OCR_DOCUMENT_NOT_UPLOADED'
|
||||
| 'OCR_PROCESSING_FAILED'
|
||||
| 'TRAINING_ALREADY_RUNNING'
|
||||
| 'UNAUTHORIZED'
|
||||
| 'FORBIDDEN'
|
||||
| 'VALIDATION_ERROR'
|
||||
@@ -81,8 +81,6 @@ export function getErrorMessage(code: ErrorCode | string | undefined): string {
|
||||
return m.error_invalid_reset_token();
|
||||
case 'ANNOTATION_NOT_FOUND':
|
||||
return m.error_annotation_not_found();
|
||||
case 'ANNOTATION_OVERLAP':
|
||||
return m.error_annotation_overlap();
|
||||
case 'TRANSCRIPTION_BLOCK_NOT_FOUND':
|
||||
return m.error_transcription_block_not_found();
|
||||
case 'TRANSCRIPTION_BLOCK_CONFLICT':
|
||||
@@ -97,6 +95,8 @@ export function getErrorMessage(code: ErrorCode | string | undefined): string {
|
||||
return m.error_ocr_document_not_uploaded();
|
||||
case 'OCR_PROCESSING_FAILED':
|
||||
return m.error_ocr_processing_failed();
|
||||
case 'TRAINING_ALREADY_RUNNING':
|
||||
return m.error_training_already_running();
|
||||
case 'UNAUTHORIZED':
|
||||
return m.error_unauthorized();
|
||||
case 'FORBIDDEN':
|
||||
|
||||
@@ -116,6 +116,22 @@ export interface paths {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/{documentId}/transcription-blocks/{blockId}/review": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get?: never;
|
||||
put: operations["reviewBlock"];
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/{documentId}/transcription-blocks/reorder": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -212,6 +228,54 @@ export interface paths {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/train": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get?: never;
|
||||
put?: never;
|
||||
post: operations["triggerTraining"];
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/segtrain": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get?: never;
|
||||
put?: never;
|
||||
post: operations["triggerSegTraining"];
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/batch": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get?: never;
|
||||
put?: never;
|
||||
post: operations["triggerBatch"];
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/notifications/read-all": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -308,6 +372,22 @@ export interface paths {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/{documentId}/ocr": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get?: never;
|
||||
put?: never;
|
||||
post: operations["triggerOcr"];
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/{documentId}/comments": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -516,6 +596,22 @@ export interface paths {
|
||||
patch: operations["updateGroup"];
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/{id}/training-labels": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get?: never;
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch: operations["patchTrainingLabel"];
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/{documentId}/comments/{commentId}": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -628,6 +724,86 @@ export interface paths {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/training-info": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get: operations["getTrainingInfo"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/training-data/export": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get: operations["exportTrainingData"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/segmentation-training-data/export": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get: operations["exportSegmentationTrainingData"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/jobs/{jobId}": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get: operations["getJobStatus"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/ocr/jobs/{jobId}/progress": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get: operations["streamProgress"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/notifications": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -740,6 +916,22 @@ export interface paths {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/{documentId}/ocr-status": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get: operations["getDocumentOcrStatus"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/documents/search": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -940,6 +1132,7 @@ export interface components {
|
||||
name: string;
|
||||
};
|
||||
PersonUpdateDTO: {
|
||||
title?: string;
|
||||
firstName?: string;
|
||||
lastName?: string;
|
||||
alias?: string;
|
||||
@@ -978,6 +1171,8 @@ export interface components {
|
||||
receiverIds?: string[];
|
||||
tags?: string;
|
||||
metadataComplete?: boolean;
|
||||
/** @enum {string} */
|
||||
scriptType?: "UNKNOWN" | "TYPEWRITER" | "HANDWRITING_LATIN" | "HANDWRITING_KURRENT";
|
||||
};
|
||||
Document: {
|
||||
/** Format: uuid */
|
||||
@@ -1002,9 +1197,12 @@ export interface components {
|
||||
/** Format: date-time */
|
||||
updatedAt: string;
|
||||
metadataComplete: boolean;
|
||||
/** @enum {string} */
|
||||
scriptType: "UNKNOWN" | "TYPEWRITER" | "HANDWRITING_LATIN" | "HANDWRITING_KURRENT";
|
||||
receivers?: components["schemas"]["Person"][];
|
||||
sender?: components["schemas"]["Person"];
|
||||
tags?: components["schemas"]["Tag"][];
|
||||
trainingLabels?: ("KURRENT_RECOGNITION" | "KURRENT_SEGMENTATION")[];
|
||||
};
|
||||
UpdateTranscriptionBlockDTO: {
|
||||
text?: string;
|
||||
@@ -1017,10 +1215,13 @@ export interface components {
|
||||
annotationId: string;
|
||||
/** Format: uuid */
|
||||
documentId: string;
|
||||
text: string;
|
||||
text?: string;
|
||||
label?: string;
|
||||
/** Format: int32 */
|
||||
sortOrder: number;
|
||||
/** @enum {string} */
|
||||
source: "MANUAL" | "OCR";
|
||||
reviewed: boolean;
|
||||
/** Format: int32 */
|
||||
version: number;
|
||||
/** Format: uuid */
|
||||
@@ -1068,6 +1269,35 @@ export interface components {
|
||||
/** Format: date-time */
|
||||
createdAt: string;
|
||||
};
|
||||
OcrTrainingRun: {
|
||||
/** Format: uuid */
|
||||
id: string;
|
||||
/** @enum {string} */
|
||||
status: "RUNNING" | "DONE" | "FAILED";
|
||||
/** Format: int32 */
|
||||
blockCount: number;
|
||||
/** Format: int32 */
|
||||
documentCount: number;
|
||||
modelName: string;
|
||||
/** Format: double */
|
||||
cer?: number;
|
||||
/** Format: double */
|
||||
loss?: number;
|
||||
/** Format: double */
|
||||
accuracy?: number;
|
||||
/** Format: int32 */
|
||||
epochs?: number;
|
||||
errorMessage?: string;
|
||||
/** Format: uuid */
|
||||
triggeredBy?: string;
|
||||
/** Format: date-time */
|
||||
createdAt: string;
|
||||
/** Format: date-time */
|
||||
completedAt?: string;
|
||||
};
|
||||
BatchOcrDTO: {
|
||||
documentIds: string[];
|
||||
};
|
||||
GroupDTO: {
|
||||
name?: string;
|
||||
permissions?: string[];
|
||||
@@ -1118,6 +1348,11 @@ export interface components {
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
};
|
||||
TriggerOcrDTO: {
|
||||
/** @enum {string} */
|
||||
scriptType?: "UNKNOWN" | "TYPEWRITER" | "HANDWRITING_LATIN" | "HANDWRITING_KURRENT";
|
||||
useExistingAnnotations?: boolean;
|
||||
};
|
||||
CreateAnnotationDTO: {
|
||||
/** Format: int32 */
|
||||
pageNumber?: number;
|
||||
@@ -1130,6 +1365,7 @@ export interface components {
|
||||
/** Format: double */
|
||||
height?: number;
|
||||
color?: string;
|
||||
polygon?: number[][];
|
||||
};
|
||||
DocumentAnnotation: {
|
||||
/** Format: uuid */
|
||||
@@ -1148,6 +1384,7 @@ export interface components {
|
||||
height: number;
|
||||
color: string;
|
||||
fileHash?: string;
|
||||
polygon?: number[][];
|
||||
/** Format: uuid */
|
||||
createdBy?: string;
|
||||
/** Format: date-time */
|
||||
@@ -1199,6 +1436,10 @@ export interface components {
|
||||
actorName?: string;
|
||||
documentTitle?: string;
|
||||
};
|
||||
TrainingLabelRequest: {
|
||||
label?: string;
|
||||
enrolled?: boolean;
|
||||
};
|
||||
StatsDTO: {
|
||||
/** Format: int64 */
|
||||
totalPersons?: number;
|
||||
@@ -1218,15 +1459,54 @@ export interface components {
|
||||
deathYear?: number;
|
||||
alias?: string;
|
||||
notes?: string;
|
||||
personType?: string;
|
||||
/** Format: int64 */
|
||||
documentCount?: number;
|
||||
personType?: string;
|
||||
};
|
||||
TrainingInfoResponse: {
|
||||
/** Format: int32 */
|
||||
availableBlocks?: number;
|
||||
/** Format: int32 */
|
||||
totalOcrBlocks?: number;
|
||||
/** Format: int32 */
|
||||
availableDocuments?: number;
|
||||
/** Format: int32 */
|
||||
availableSegBlocks?: number;
|
||||
ocrServiceAvailable?: boolean;
|
||||
lastRun?: components["schemas"]["OcrTrainingRun"];
|
||||
runs?: components["schemas"]["OcrTrainingRun"][];
|
||||
};
|
||||
StreamingResponseBody: unknown;
|
||||
OcrJob: {
|
||||
/** Format: uuid */
|
||||
id: string;
|
||||
/** @enum {string} */
|
||||
status: "PENDING" | "RUNNING" | "DONE" | "FAILED";
|
||||
/** Format: int32 */
|
||||
totalDocuments: number;
|
||||
/** Format: int32 */
|
||||
processedDocuments: number;
|
||||
/** Format: int32 */
|
||||
errorCount: number;
|
||||
/** Format: int32 */
|
||||
skippedCount: number;
|
||||
progressMessage?: string;
|
||||
/** Format: uuid */
|
||||
createdBy?: string;
|
||||
/** Format: date-time */
|
||||
createdAt: string;
|
||||
/** Format: date-time */
|
||||
updatedAt: string;
|
||||
};
|
||||
SseEmitter: {
|
||||
/** Format: int64 */
|
||||
timeout?: number;
|
||||
};
|
||||
PageNotificationDTO: {
|
||||
/** Format: int64 */
|
||||
totalElements?: number;
|
||||
/** Format: int32 */
|
||||
totalPages?: number;
|
||||
/** Format: int64 */
|
||||
totalElements?: number;
|
||||
pageable?: components["schemas"]["PageableObject"];
|
||||
/** Format: int32 */
|
||||
size?: number;
|
||||
@@ -1256,10 +1536,6 @@ export interface components {
|
||||
empty?: boolean;
|
||||
unsorted?: boolean;
|
||||
};
|
||||
SseEmitter: {
|
||||
/** Format: int64 */
|
||||
timeout?: number;
|
||||
};
|
||||
DocumentVersionSummary: {
|
||||
/** Format: uuid */
|
||||
id: string;
|
||||
@@ -1292,6 +1568,15 @@ export interface components {
|
||||
/** Format: date-time */
|
||||
changedAt: string;
|
||||
};
|
||||
OcrStatusDTO: {
|
||||
status?: string;
|
||||
/** Format: uuid */
|
||||
jobId?: string;
|
||||
/** Format: int32 */
|
||||
currentPage?: number;
|
||||
/** Format: int32 */
|
||||
totalPages?: number;
|
||||
};
|
||||
DocumentSearchResult: {
|
||||
documents?: components["schemas"]["Document"][];
|
||||
/** Format: int64 */
|
||||
@@ -1702,6 +1987,29 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
reviewBlock: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path: {
|
||||
documentId: string;
|
||||
blockId: string;
|
||||
};
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["TranscriptionBlock"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
reorderBlocks: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -1914,6 +2222,72 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
triggerTraining: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description Created */
|
||||
201: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["OcrTrainingRun"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
triggerSegTraining: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description Created */
|
||||
201: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["OcrTrainingRun"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
triggerBatch: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["BatchOcrDTO"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Accepted */
|
||||
202: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": {
|
||||
[key: string]: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
markAllRead: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -2124,6 +2498,34 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
triggerOcr: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path: {
|
||||
documentId: string;
|
||||
};
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["TriggerOcrDTO"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Accepted */
|
||||
202: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": {
|
||||
[key: string]: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
getDocumentComments: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -2521,6 +2923,30 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
patchTrainingLabel: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path: {
|
||||
id: string;
|
||||
};
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["TrainingLabelRequest"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content?: never;
|
||||
};
|
||||
};
|
||||
};
|
||||
deleteComment: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -2701,6 +3127,110 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
getTrainingInfo: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["TrainingInfoResponse"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
exportTrainingData: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["StreamingResponseBody"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
exportSegmentationTrainingData: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["StreamingResponseBody"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
getJobStatus: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path: {
|
||||
jobId: string;
|
||||
};
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["OcrJob"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
streamProgress: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path: {
|
||||
jobId: string;
|
||||
};
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"text/event-stream": components["schemas"]["SseEmitter"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
getNotifications: {
|
||||
parameters: {
|
||||
query?: {
|
||||
@@ -2860,6 +3390,28 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
getDocumentOcrStatus: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path: {
|
||||
documentId: string;
|
||||
};
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description OK */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["OcrStatusDTO"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
search_1: {
|
||||
parameters: {
|
||||
query?: {
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
<script lang="ts">
|
||||
import { onDestroy } from 'svelte';
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
import OcrTrainingCard from '$lib/components/OcrTrainingCard.svelte';
|
||||
import SegmentationTrainingCard from '$lib/components/SegmentationTrainingCard.svelte';
|
||||
import type { components } from '$lib/generated/api.js';
|
||||
|
||||
type TrainingInfo = components['schemas']['TrainingInfoResponse'];
|
||||
|
||||
let trainingInfo: TrainingInfo | null = $state(null);
|
||||
|
||||
let backfillResult: number | null = $state(null);
|
||||
let backfillLoading = $state(false);
|
||||
@@ -51,8 +58,16 @@ async function triggerImport() {
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchTrainingInfo() {
|
||||
const res = await fetch('/api/ocr/training-info');
|
||||
if (res.ok) {
|
||||
trainingInfo = await res.json();
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
fetchImportStatus();
|
||||
fetchTrainingInfo();
|
||||
});
|
||||
|
||||
onDestroy(() => stopPolling());
|
||||
@@ -88,6 +103,12 @@ async function backfillFileHashes() {
|
||||
|
||||
<div class="flex-1 overflow-y-auto p-6">
|
||||
<div class="mx-auto max-w-2xl space-y-5">
|
||||
<!-- OCR Recognition Training -->
|
||||
<OcrTrainingCard trainingInfo={trainingInfo} />
|
||||
|
||||
<!-- OCR Segmentation Training -->
|
||||
<SegmentationTrainingCard trainingInfo={trainingInfo} />
|
||||
|
||||
<!-- Backfill versions -->
|
||||
<div class="rounded-sm border border-line bg-surface p-6 shadow-sm">
|
||||
<h2 class="mb-1 font-sans text-sm font-bold text-ink">{m.admin_system_backfill_heading()}</h2>
|
||||
|
||||
@@ -78,6 +78,8 @@ describe('Admin system page — mass import card', () => {
|
||||
startedAt: null
|
||||
})
|
||||
})
|
||||
// training info fetch → empty
|
||||
.mockResolvedValueOnce({ ok: true, json: async () => ({}) })
|
||||
// trigger POST → returns RUNNING immediately
|
||||
.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
|
||||
@@ -129,20 +129,29 @@ async function reviewToggle(blockId: string) {
|
||||
transcriptionBlocks = transcriptionBlocks.map((b) => (b.id === blockId ? updated : b));
|
||||
}
|
||||
|
||||
async function toggleTrainingLabel(label: string, enrolled: boolean) {
|
||||
const res = await fetch(`/api/documents/${doc.id}/training-labels`, {
|
||||
method: 'PATCH',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ label, enrolled })
|
||||
});
|
||||
if (!res.ok) throw new Error('Failed to update training label');
|
||||
}
|
||||
|
||||
let ocrRunning = $state(false);
|
||||
let ocrProgressMessage = $state('');
|
||||
let ocrErrorMessage = $state('');
|
||||
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
|
||||
let ocrSkippedPages = $state(0);
|
||||
|
||||
async function triggerOcr(scriptType: string) {
|
||||
async function triggerOcr(scriptType: string, useExistingAnnotations: boolean = false) {
|
||||
ocrRunning = true;
|
||||
ocrErrorMessage = '';
|
||||
try {
|
||||
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ scriptType })
|
||||
body: JSON.stringify({ scriptType, useExistingAnnotations })
|
||||
});
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
@@ -449,11 +458,14 @@ onMount(() => {
|
||||
activeAnnotationId={activeAnnotationId}
|
||||
storedScriptType={doc.scriptType ?? ''}
|
||||
canRunOcr={canWrite && !!doc.filePath}
|
||||
canWrite={canWrite}
|
||||
trainingLabels={doc.trainingLabels ?? []}
|
||||
onBlockFocus={handleBlockFocus}
|
||||
onSaveBlock={saveBlock}
|
||||
onDeleteBlock={deleteBlock}
|
||||
onReviewToggle={reviewToggle}
|
||||
onTriggerOcr={triggerOcr}
|
||||
onToggleTrainingLabel={toggleTrainingLabel}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -23,4 +23,4 @@ COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
|
||||
@@ -47,7 +47,7 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
pred_it = rpred.rpred(_model, image, baseline_seg)
|
||||
|
||||
for record in pred_it:
|
||||
polygon_pts = record.cuts if hasattr(record, "cuts") else []
|
||||
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
|
||||
|
||||
if polygon_pts:
|
||||
xs = [p[0] for p in polygon_pts]
|
||||
@@ -55,8 +55,8 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
x1, y1 = min(xs), min(ys)
|
||||
x2, y2 = max(xs), max(ys)
|
||||
else:
|
||||
xs = [p[0] for p in record.line]
|
||||
ys = [p[1] for p in record.line]
|
||||
xs = [p[0] for p in record.baseline]
|
||||
ys = [p[1] for p in record.baseline]
|
||||
x1, y1 = min(xs), min(ys) - 5
|
||||
x2, y2 = max(xs), max(ys) + 5
|
||||
|
||||
@@ -79,6 +79,54 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
return blocks
|
||||
|
||||
|
||||
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
"""Crop image to a normalized region and run Kraken recognition on the crop.
|
||||
|
||||
Used for guided OCR — skips full-page layout detection entirely.
|
||||
A single synthetic baseline spanning the full crop width is used so that
|
||||
blla.segment() (which crashes on small crops) is never called.
|
||||
Coordinates are normalized to [0, 1].
|
||||
"""
|
||||
from kraken import rpred
|
||||
from kraken.containers import Segmentation, BaselineLine
|
||||
|
||||
if _model is None:
|
||||
raise RuntimeError("Kraken model is not loaded")
|
||||
|
||||
pw, ph = image.size
|
||||
x1 = max(0, int(x * pw))
|
||||
y1 = max(0, int(y * ph))
|
||||
x2 = min(pw, int((x + w) * pw))
|
||||
y2 = min(ph, int((y + h) * ph))
|
||||
crop = image.crop((x1, y1, x2, y2))
|
||||
|
||||
cw, ch = crop.size
|
||||
if cw == 0 or ch == 0:
|
||||
return ""
|
||||
|
||||
# Single synthetic baseline at vertical midpoint, spanning full crop width.
|
||||
# Kraken's bounds check is >= (not >), so all coordinates must be < image
|
||||
# dimension — use cw-1 / ch-1 to stay strictly inside.
|
||||
mid_y = ch // 2
|
||||
synthetic_seg = Segmentation(
|
||||
type="baselines",
|
||||
imagename="",
|
||||
text_direction="horizontal-lr",
|
||||
script_detection=False,
|
||||
lines=[
|
||||
BaselineLine(
|
||||
id="line0",
|
||||
baseline=[(0, mid_y), (cw - 1, mid_y)],
|
||||
boundary=[(0, 0), (cw - 1, 0), (cw - 1, ch - 1), (0, ch - 1)],
|
||||
)
|
||||
],
|
||||
regions={},
|
||||
line_orders=[],
|
||||
)
|
||||
pred_it = rpred.rpred(_model, crop, synthetic_seg)
|
||||
return " ".join(r.prediction for r in pred_it)
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
"""Run Kraken segmentation + recognition on a list of PIL images.
|
||||
|
||||
|
||||
@@ -81,6 +81,25 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
return blocks
|
||||
|
||||
|
||||
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
"""Crop image to a normalized region and run Surya recognition on the crop.
|
||||
|
||||
Used for guided OCR — skips full-page layout detection and only processes
|
||||
the given bounding box. Coordinates are normalized to [0, 1].
|
||||
"""
|
||||
load_models()
|
||||
|
||||
pw, ph = image.size
|
||||
x1 = max(0, int(x * pw))
|
||||
y1 = max(0, int(y * ph))
|
||||
x2 = min(pw, int((x + w) * pw))
|
||||
y2 = min(ph, int((y + h) * ph))
|
||||
crop = image.crop((x1, y1, x2, y2))
|
||||
|
||||
predictions = _recognition_predictor([crop], det_predictor=_detection_predictor)
|
||||
return " ".join(line.text for line in predictions[0].text_lines)
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
"""Run Surya OCR on a list of PIL images (one per page).
|
||||
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
|
||||
|
||||
import asyncio
|
||||
import glob
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import zipfile
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import pypdfium2 as pdfium
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, Header, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from PIL import Image
|
||||
|
||||
@@ -19,6 +26,9 @@ from engines import kraken as kraken_engine
|
||||
from engines import surya as surya_engine
|
||||
from models import OcrBlock, OcrRequest
|
||||
|
||||
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
|
||||
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,7 +110,12 @@ async def run_ocr(request: OcrRequest):
|
||||
|
||||
@app.post("/ocr/stream")
|
||||
async def run_ocr_stream(request: OcrRequest):
|
||||
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page."""
|
||||
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.
|
||||
|
||||
When request.regions is provided, runs in guided mode: each region is cropped
|
||||
and recognized individually, skipping full-page layout detection. The response
|
||||
blocks include the annotationId from the region.
|
||||
"""
|
||||
if not _models_ready:
|
||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||
|
||||
@@ -115,6 +130,81 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
detail="Kraken model not available — cannot process Kurrent script",
|
||||
)
|
||||
|
||||
engine = kraken_engine if use_kraken else surya_engine
|
||||
|
||||
if request.regions:
|
||||
# Guided mode: recognize only the user-drawn annotation regions
|
||||
regions_by_page: dict[int, list] = {}
|
||||
for region in request.regions:
|
||||
regions_by_page.setdefault(region.pageNumber, []).append(region)
|
||||
|
||||
async def generate_guided():
|
||||
total_pages = len(images)
|
||||
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
||||
|
||||
total_blocks = 0
|
||||
skipped_pages = 0
|
||||
|
||||
for page_idx, image in enumerate(images, start=1):
|
||||
page_regions = regions_by_page.get(page_idx, [])
|
||||
if not page_regions:
|
||||
yield json.dumps({
|
||||
"type": "page",
|
||||
"pageNumber": page_idx,
|
||||
"blocks": [],
|
||||
}) + "\n"
|
||||
del image
|
||||
continue
|
||||
|
||||
try:
|
||||
blocks = []
|
||||
for region in page_regions:
|
||||
text = await asyncio.to_thread(
|
||||
engine.extract_region_text, image,
|
||||
region.x, region.y, region.width, region.height,
|
||||
)
|
||||
blocks.append({
|
||||
"pageNumber": page_idx,
|
||||
"x": region.x,
|
||||
"y": region.y,
|
||||
"width": region.width,
|
||||
"height": region.height,
|
||||
"polygon": None,
|
||||
"text": text,
|
||||
"annotationId": region.annotationId,
|
||||
})
|
||||
|
||||
total_blocks += len(blocks)
|
||||
yield json.dumps({
|
||||
"type": "page",
|
||||
"pageNumber": page_idx,
|
||||
"blocks": blocks,
|
||||
}) + "\n"
|
||||
|
||||
except Exception:
|
||||
logger.exception("Guided OCR failed on page %d", page_idx)
|
||||
skipped_pages += 1
|
||||
yield json.dumps({
|
||||
"type": "error",
|
||||
"pageNumber": page_idx,
|
||||
"message": f"Guided OCR processing failed on page {page_idx}",
|
||||
}) + "\n"
|
||||
|
||||
finally:
|
||||
del image
|
||||
|
||||
yield json.dumps({
|
||||
"type": "done",
|
||||
"totalBlocks": total_blocks,
|
||||
"skippedPages": skipped_pages,
|
||||
}) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_guided(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
|
||||
)
|
||||
|
||||
async def generate():
|
||||
total_pages = len(images)
|
||||
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
||||
@@ -124,7 +214,6 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
|
||||
for page_idx, image in enumerate(images, start=1):
|
||||
try:
|
||||
engine = kraken_engine if use_kraken else surya_engine
|
||||
blocks = await asyncio.to_thread(
|
||||
engine.extract_page_blocks, image, page_idx, request.language
|
||||
)
|
||||
@@ -169,6 +258,262 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
)
|
||||
|
||||
|
||||
def _check_training_token(x_training_token: str | None) -> None:
|
||||
"""Validate training token — fails closed when TRAINING_TOKEN is not configured."""
|
||||
if not TRAINING_TOKEN:
|
||||
raise HTTPException(status_code=503, detail="Training not configured on this node")
|
||||
if x_training_token != TRAINING_TOKEN:
|
||||
raise HTTPException(status_code=403, detail="Invalid or missing X-Training-Token")
|
||||
|
||||
|
||||
def _validate_zip_entry(name: str, extract_dir: str) -> None:
|
||||
"""Reject ZIP Slip attacks: path traversal and absolute paths."""
|
||||
if os.path.isabs(name) or name.startswith(".."):
|
||||
raise HTTPException(status_code=400, detail=f"Unsafe ZIP entry: {name}")
|
||||
resolved = os.path.realpath(os.path.join(extract_dir, name))
|
||||
if not resolved.startswith(os.path.realpath(extract_dir)):
|
||||
raise HTTPException(status_code=400, detail=f"ZIP Slip detected: {name}")
|
||||
|
||||
|
||||
def _rotate_backups(model_path: str, keep: int = 3) -> None:
|
||||
"""Keep only the last `keep` timestamped backups of the model."""
|
||||
pattern = model_path + ".*.bak"
|
||||
backups = sorted(glob.glob(pattern))
|
||||
for old in backups[:-keep]:
|
||||
try:
|
||||
os.remove(old)
|
||||
except OSError:
|
||||
logger.warning("Could not remove old backup: %s", old)
|
||||
|
||||
|
||||
def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
|
||||
"""Parse checkpoint filenames to find the best validation metric and epoch count.
|
||||
|
||||
Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
|
||||
Returns (best_accuracy, epoch_count).
|
||||
"""
|
||||
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
||||
best_acc: float | None = None
|
||||
max_epoch = 0
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
m = pattern.match(fname)
|
||||
if m:
|
||||
epoch = int(m.group(1))
|
||||
acc = float(m.group(2))
|
||||
max_epoch = max(max_epoch, epoch)
|
||||
if best_acc is None or acc > best_acc:
|
||||
best_acc = acc
|
||||
return best_acc, max_epoch
|
||||
|
||||
|
||||
def _find_best_model(checkpoint_dir: str) -> str | None:
|
||||
"""Return the best final model file produced by ketos train.
|
||||
|
||||
With --weights-format coreml, ketos writes ``best_<score>.mlmodel``.
|
||||
Falls back to any .mlmodel in the directory.
|
||||
"""
|
||||
# Prefer the named best file (e.g. best_0.8256.mlmodel or best_0.8256.safetensors)
|
||||
best_pattern = re.compile(r"best_([0-9.]+)\.(mlmodel|safetensors)$")
|
||||
best_acc: float | None = None
|
||||
best_path: str | None = None
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
m = best_pattern.match(fname)
|
||||
if m:
|
||||
acc = float(m.group(1))
|
||||
if best_acc is None or acc > best_acc:
|
||||
best_acc = acc
|
||||
best_path = os.path.join(checkpoint_dir, fname)
|
||||
if best_path:
|
||||
return best_path
|
||||
# Fallback: any .mlmodel file
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
if fname.endswith(".mlmodel"):
|
||||
return os.path.join(checkpoint_dir, fname)
|
||||
return None
|
||||
|
||||
|
||||
@app.post("/train")
|
||||
async def train_model(
|
||||
file: UploadFile,
|
||||
x_training_token: str | None = Header(default=None),
|
||||
):
|
||||
"""Fine-tune the Kurrent recognition model with uploaded training data.
|
||||
|
||||
Accepts a ZIP archive containing .png/.gt.txt training pairs exported
|
||||
by the Java backend. Training mutates in-process model state — not safe
|
||||
if the service is replicated.
|
||||
"""
|
||||
_check_training_token(x_training_token)
|
||||
|
||||
if not _models_ready:
|
||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||
|
||||
zip_bytes = await file.read()
|
||||
|
||||
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
|
||||
log.info("Starting training run %s", training_run_id)
|
||||
|
||||
def _run_training() -> dict:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Extract ZIP with safety checks
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
for entry in zf.namelist():
|
||||
_validate_zip_entry(entry, tmp_dir)
|
||||
zf.extractall(tmp_dir)
|
||||
|
||||
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
|
||||
|
||||
ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
|
||||
if not ground_truth:
|
||||
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
|
||||
|
||||
log.info("Training on %d ground-truth pairs", len(ground_truth))
|
||||
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
cmd = [
|
||||
"ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
|
||||
"train",
|
||||
"-f", "page",
|
||||
"--weights-format", "coreml",
|
||||
"-o", checkpoint_dir,
|
||||
"-q", "fixed",
|
||||
"-N", "10",
|
||||
"-B", "1",
|
||||
]
|
||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||
cmd += ["-i", KRAKEN_MODEL_PATH]
|
||||
cmd += ground_truth
|
||||
|
||||
log.info("Running: %s", " ".join(cmd[:6]) + " ...")
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
|
||||
if proc.stderr:
|
||||
log.info("ketos train stderr: %s", proc.stderr[-2000:])
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||
|
||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||
log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
|
||||
|
||||
# Find the best model file produced by training
|
||||
best_model = _find_best_model(checkpoint_dir)
|
||||
if best_model is None:
|
||||
raise RuntimeError("Training produced no model file")
|
||||
|
||||
# Backup existing model and replace
|
||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
backup_path = f"{KRAKEN_MODEL_PATH}.{timestamp}.bak"
|
||||
shutil.copy2(KRAKEN_MODEL_PATH, backup_path)
|
||||
log.info("Backed up model to %s", backup_path)
|
||||
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
|
||||
|
||||
shutil.copy2(best_model, KRAKEN_MODEL_PATH)
|
||||
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
|
||||
|
||||
# Reload model in-process
|
||||
kraken_engine.load_models()
|
||||
log.info("Reloaded Kraken model in-process")
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_training)
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/segtrain")
|
||||
async def segtrain_model(
|
||||
file: UploadFile,
|
||||
x_training_token: str | None = Header(default=None),
|
||||
):
|
||||
"""Fine-tune the blla segmentation model with uploaded PAGE XML training data.
|
||||
|
||||
Accepts a ZIP archive containing .png/.xml (PAGE XML) training pairs exported
|
||||
by the Java backend. Training mutates in-process model state — not safe
|
||||
if the service is replicated.
|
||||
"""
|
||||
_check_training_token(x_training_token)
|
||||
|
||||
if not _models_ready:
|
||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||
|
||||
zip_bytes = await file.read()
|
||||
|
||||
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
|
||||
log.info("Starting segmentation training run %s", training_run_id)
|
||||
|
||||
blla_model_path = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")
|
||||
|
||||
def _run_segtrain() -> dict:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
for entry in zf.namelist():
|
||||
_validate_zip_entry(entry, tmp_dir)
|
||||
zf.extractall(tmp_dir)
|
||||
|
||||
log.info("Extracted %d ZIP entries for segmentation training", len(os.listdir(tmp_dir)))
|
||||
|
||||
xml_files = glob.glob(os.path.join(tmp_dir, "*.xml"))
|
||||
if not xml_files:
|
||||
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
|
||||
|
||||
log.info("Training on %d PAGE XML files", len(xml_files))
|
||||
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
cmd = [
|
||||
"ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
|
||||
"segtrain",
|
||||
"-o", checkpoint_dir,
|
||||
"-q", "fixed",
|
||||
"-N", "10",
|
||||
]
|
||||
if os.path.exists(blla_model_path):
|
||||
cmd += ["-i", blla_model_path, "--resize", "both"]
|
||||
else:
|
||||
# No pretrained model — train from scratch with reduced height (800px)
|
||||
# to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+)
|
||||
cmd += [
|
||||
"-s",
|
||||
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
|
||||
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]",
|
||||
]
|
||||
cmd += xml_files
|
||||
|
||||
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
|
||||
if proc.stderr:
|
||||
log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||
|
||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
|
||||
|
||||
best_model = _find_best_model(checkpoint_dir)
|
||||
if best_model is None:
|
||||
raise RuntimeError("Segmentation training produced no model file")
|
||||
|
||||
if os.path.exists(blla_model_path):
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
backup_path = f"{blla_model_path}.{timestamp}.bak"
|
||||
shutil.copy2(blla_model_path, backup_path)
|
||||
_rotate_backups(blla_model_path, keep=3)
|
||||
|
||||
shutil.copy2(best_model, blla_model_path)
|
||||
log.info("Replaced blla model at %s", blla_model_path)
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
return result
|
||||
|
||||
|
||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
||||
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
||||
_validate_url(url)
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class OcrRegion(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
annotationId: str
|
||||
pageNumber: int
|
||||
x: float
|
||||
y: float
|
||||
width: float
|
||||
height: float
|
||||
|
||||
|
||||
class OcrRequest(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
pdfUrl: str
|
||||
scriptType: str = "UNKNOWN"
|
||||
language: str = "de"
|
||||
regions: list[OcrRegion] | None = None
|
||||
|
||||
|
||||
class OcrBlock(BaseModel):
|
||||
@@ -19,3 +31,4 @@ class OcrBlock(BaseModel):
|
||||
height: float
|
||||
polygon: list[list[float]] | None = None
|
||||
text: str
|
||||
annotationId: str | None = None
|
||||
|
||||
69
ocr-service/test_training_auth.py
Normal file
69
ocr-service/test_training_auth.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tests for /train and /segtrain endpoint authentication."""
|
||||
|
||||
import io
|
||||
import zipfile
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
def _minimal_zip() -> bytes:
|
||||
"""Return a ZIP with one .xml file so endpoint validation passes."""
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr("page_01.xml", "<PcGts/>")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
# ─── Missing TRAINING_TOKEN → fail closed ─────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_train_returns_503_when_training_token_not_configured():
    """POST /train must answer 503 when no TRAINING_TOKEN is configured.

    With an empty token the service was started without training set up;
    letting the request through would expose the training endpoint without
    any authentication, so the service must fail closed.
    """
    transport = ASGITransport(app=app)
    with patch("main.TRAINING_TOKEN", ""), patch("main._models_ready", True):
        async with AsyncClient(transport=transport, base_url="http://test") as http:
            resp = await http.post(
                "/train",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
            )

    assert resp.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_segtrain_returns_503_when_training_token_not_configured():
    """POST /segtrain must answer 503 when no TRAINING_TOKEN is configured."""
    transport = ASGITransport(app=app)
    with patch("main.TRAINING_TOKEN", ""), patch("main._models_ready", True):
        async with AsyncClient(transport=transport, base_url="http://test") as http:
            resp = await http.post(
                "/segtrain",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
            )

    assert resp.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_train_returns_403_when_token_wrong():
    """POST /train must answer 403 when the X-Training-Token header is wrong."""
    transport = ASGITransport(app=app)
    with patch("main.TRAINING_TOKEN", "secret-token"), patch("main._models_ready", True):
        async with AsyncClient(transport=transport, base_url="http://test") as http:
            resp = await http.post(
                "/train",
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
                headers={"X-Training-Token": "wrong-token"},
            )

    assert resp.status_code == 403
|
||||
Reference in New Issue
Block a user