feat: OCR pipeline with NDJSON streaming and real-time progress (#226, #227, #231) #229

Merged
marcel merged 74 commits from feat/issue-226-227-ocr-pipeline-polygon into main 2026-04-13 12:39:04 +02:00
89 changed files with 5026 additions and 99 deletions

View File

@@ -16,10 +16,10 @@ public class AsyncConfig {
@Bean
public Executor taskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
-executor.setCorePoolSize(1);
-executor.setMaxPoolSize(1);
-executor.setQueueCapacity(1);
-executor.setThreadNamePrefix("Import-");
+executor.setCorePoolSize(2);
+executor.setMaxPoolSize(2);
+executor.setQueueCapacity(10);
+executor.setThreadNamePrefix("Async-");
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
return executor;
}

View File

@@ -5,6 +5,7 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3Configuration;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.context.annotation.Bean;
@@ -44,6 +45,19 @@ public class MinioConfig {
.build();
}
@Bean
public S3Presigner s3Presigner() {
return S3Presigner.builder()
.endpointOverride(URI.create(endpoint))
.serviceConfiguration(S3Configuration.builder()
.pathStyleAccessEnabled(true)
.build())
.region(Region.of(region))
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create(accessKey, secretKey)))
.build();
}
@Bean
public CommandLineRunner testS3Connection(S3Client s3Client) {
return args -> {

View File

@@ -0,0 +1,88 @@
package org.raddatz.familienarchiv.controller;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.dto.BatchOcrDTO;
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.model.AppUser;
import org.raddatz.familienarchiv.model.OcrJob;
import org.raddatz.familienarchiv.security.Permission;
import org.raddatz.familienarchiv.security.RequirePermission;
import org.raddatz.familienarchiv.service.OcrBatchService;
import org.raddatz.familienarchiv.service.OcrProgressService;
import org.raddatz.familienarchiv.service.OcrService;
import org.raddatz.familienarchiv.service.UserService;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import jakarta.validation.Valid;
import java.util.Map;
import java.util.UUID;
@RestController
@RequiredArgsConstructor
@Slf4j
public class OcrController {
private final OcrService ocrService;
private final OcrBatchService ocrBatchService;
private final OcrProgressService ocrProgressService;
private final UserService userService;
@PostMapping("/api/documents/{documentId}/ocr")
@ResponseStatus(HttpStatus.ACCEPTED)
@RequirePermission(Permission.WRITE_ALL)
public Map<String, UUID> triggerOcr(
@PathVariable UUID documentId,
@RequestBody TriggerOcrDTO dto,
Authentication authentication) {
UUID userId = resolveUserId(authentication);
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
return Map.of("jobId", jobId);
}
@PostMapping("/api/ocr/batch")
@ResponseStatus(HttpStatus.ACCEPTED)
@RequirePermission(Permission.ADMIN)
public Map<String, UUID> triggerBatch(
@RequestBody @Valid BatchOcrDTO dto,
Authentication authentication) {
UUID userId = resolveUserId(authentication);
UUID jobId = ocrBatchService.startBatch(dto.getDocumentIds(), userId);
return Map.of("jobId", jobId);
}
@GetMapping("/api/ocr/jobs/{jobId}")
@RequirePermission(Permission.READ_ALL)
public OcrJob getJobStatus(@PathVariable UUID jobId) {
return ocrService.getJob(jobId);
}
@GetMapping(value = "/api/ocr/jobs/{jobId}/progress", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@RequirePermission(Permission.READ_ALL)
public SseEmitter streamProgress(@PathVariable UUID jobId) {
ocrService.getJob(jobId);
return ocrProgressService.register(jobId);
}
@GetMapping("/api/documents/{documentId}/ocr-status")
@RequirePermission(Permission.READ_ALL)
public OcrStatusDTO getDocumentOcrStatus(@PathVariable UUID documentId) {
return ocrService.getDocumentOcrStatus(documentId);
}
private UUID resolveUserId(Authentication authentication) {
if (authentication == null || !authentication.isAuthenticated()) return null;
try {
AppUser user = userService.findByUsername(authentication.getName());
return user != null ? user.getId() : null;
} catch (Exception e) {
log.warn("Failed to resolve user ID for authentication: {}", authentication.getName(), e);
return null;
}
}
}

View File

@@ -81,6 +81,14 @@ public class TranscriptionBlockController {
return transcriptionService.listBlocks(documentId);
}
@PutMapping("/{blockId}/review")
@RequirePermission(Permission.WRITE_ALL)
public TranscriptionBlock reviewBlock(
@PathVariable UUID documentId,
@PathVariable UUID blockId) {
return transcriptionService.reviewBlock(documentId, blockId);
}
@GetMapping("/{blockId}/history")
@RequirePermission(Permission.READ_ALL)
public List<TranscriptionBlockVersion> getBlockHistory(

View File

@@ -0,0 +1,19 @@
package org.raddatz.familienarchiv.dto;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.Size;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
import java.util.UUID;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class BatchOcrDTO {
@NotEmpty
@Size(max = 500, message = "batch size must not exceed 500 documents")
private List<UUID> documentIds;
}

View File

@@ -1,9 +1,15 @@
package org.raddatz.familienarchiv.dto;
import jakarta.validation.Valid;
import jakarta.validation.constraints.DecimalMax;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.Size;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
@@ -14,4 +20,19 @@ public class CreateAnnotationDTO {
private double width;
private double height;
private String color;
@Size(min = 4, max = 4, message = "polygon must have exactly 4 points")
@UniquePoints
@Valid
private List<@Size(min = 2, max = 2, message = "each point must have exactly 2 coordinates")
List<@DecimalMin("0.0") @DecimalMax("1.0") Double>> polygon;
public CreateAnnotationDTO(int pageNumber, double x, double y, double width, double height, String color) {
this.pageNumber = pageNumber;
this.x = x;
this.y = y;
this.width = width;
this.height = height;
this.color = color;
}
}

View File

@@ -5,6 +5,7 @@ import java.util.List;
import java.util.UUID;
import lombok.Data;
import org.raddatz.familienarchiv.model.ScriptType;
@Data
public class DocumentUpdateDTO {
@@ -18,4 +19,5 @@ public class DocumentUpdateDTO {
private List<UUID> receiverIds;
private String tags;
private Boolean metadataComplete;
private ScriptType scriptType;
}

View File

@@ -0,0 +1,19 @@
package org.raddatz.familienarchiv.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.UUID;
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrStatusDTO {
private String status;
private UUID jobId;
private int currentPage;
private int totalPages;
}

View File

@@ -0,0 +1,13 @@
package org.raddatz.familienarchiv.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.raddatz.familienarchiv.model.ScriptType;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class TriggerOcrDTO {
private ScriptType scriptType;
}

View File

@@ -0,0 +1,16 @@
package org.raddatz.familienarchiv.dto;
import jakarta.validation.Constraint;
import jakarta.validation.Payload;
import java.lang.annotation.*;
@Documented
@Constraint(validatedBy = UniquePointsValidator.class)
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
public @interface UniquePoints {
String message() default "polygon must contain 4 unique points";
Class<?>[] groups() default {};
Class<? extends Payload>[] payload() default {};
}

View File

@@ -0,0 +1,16 @@
package org.raddatz.familienarchiv.dto;
import jakarta.validation.ConstraintValidator;
import jakarta.validation.ConstraintValidatorContext;
import java.util.HashSet;
import java.util.List;
public class UniquePointsValidator implements ConstraintValidator<UniquePoints, List<List<Double>>> {
@Override
public boolean isValid(List<List<Double>> polygon, ConstraintValidatorContext context) {
if (polygon == null) return true;
return new HashSet<>(polygon).size() == polygon.size();
}
}

View File

@@ -66,6 +66,16 @@ public enum ErrorCode {
/** The notification with the given ID does not exist. 404 */
NOTIFICATION_NOT_FOUND,
// --- OCR ---
/** The OCR service is not available or not healthy. 503 */
OCR_SERVICE_UNAVAILABLE,
/** The OCR job with the given ID does not exist. 404 */
OCR_JOB_NOT_FOUND,
/** The document is not in UPLOADED status and cannot be OCR'd. 400 */
OCR_DOCUMENT_NOT_UPLOADED,
/** OCR processing failed for the document. 500 */
OCR_PROCESSING_FAILED,
// --- Generic ---
/** Request validation failed (missing or malformed fields). 400 */
VALIDATION_ERROR,

View File

@@ -0,0 +1,6 @@
package org.raddatz.familienarchiv.model;
public enum BlockSource {
MANUAL,
OCR
}

View File

@@ -91,6 +91,12 @@ public class Document {
@Builder.Default
private boolean metadataComplete = false;
@Enumerated(EnumType.STRING)
@Column(name = "script_type", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private ScriptType scriptType = ScriptType.UNKNOWN;
@ManyToMany(fetch = FetchType.EAGER)
@JoinTable(name = "document_receivers", joinColumns = @JoinColumn(name = "document_id"), inverseJoinColumns = @JoinColumn(name = "person_id"))
@Builder.Default

View File

@@ -4,8 +4,11 @@ import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.*;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
import java.time.LocalDateTime;
import java.util.List;
import java.util.UUID;
@Entity
@@ -52,6 +55,10 @@ public class DocumentAnnotation {
@Column(name = "file_hash", length = 64)
private String fileHash;
@JdbcTypeCode(SqlTypes.JSON)
@Column(columnDefinition = "jsonb")
private List<List<Double>> polygon;
@Column(name = "created_by")
private UUID createdBy;

View File

@@ -0,0 +1,9 @@
package org.raddatz.familienarchiv.model;
public enum OcrDocumentStatus {
PENDING,
RUNNING,
DONE,
FAILED,
SKIPPED
}

View File

@@ -0,0 +1,65 @@
package org.raddatz.familienarchiv.model;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.*;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.UpdateTimestamp;
import java.time.LocalDateTime;
import java.util.UUID;
@Entity
@Table(name = "ocr_jobs")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrJob {
@Id
@GeneratedValue(strategy = GenerationType.UUID)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID id;
@Enumerated(EnumType.STRING)
@Column(nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private OcrJobStatus status = OcrJobStatus.PENDING;
@Column(name = "total_documents", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int totalDocuments;
@Column(name = "processed_documents", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private int processedDocuments = 0;
@Column(name = "error_count", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private int errorCount = 0;
@Column(name = "skipped_count", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private int skippedCount = 0;
@Column(name = "progress_message")
private String progressMessage;
@Column(name = "created_by")
private UUID createdBy;
@Column(name = "created_at", nullable = false, updatable = false)
@CreationTimestamp
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime createdAt;
@Column(name = "updated_at", nullable = false)
@UpdateTimestamp
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime updatedAt;
}

View File

@@ -0,0 +1,59 @@
package org.raddatz.familienarchiv.model;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.*;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.UpdateTimestamp;
import java.time.LocalDateTime;
import java.util.UUID;
@Entity
@Table(name = "ocr_job_documents")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrJobDocument {
@Id
@GeneratedValue(strategy = GenerationType.UUID)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID id;
@Column(name = "job_id", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID jobId;
@Column(name = "document_id", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID documentId;
@Enumerated(EnumType.STRING)
@Column(nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private OcrDocumentStatus status = OcrDocumentStatus.PENDING;
@Column(name = "error_message")
private String errorMessage;
@Column(name = "current_page")
@Builder.Default
private int currentPage = 0;
@Column(name = "total_pages")
@Builder.Default
private int totalPages = 0;
@Column(name = "created_at", nullable = false, updatable = false)
@CreationTimestamp
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime createdAt;
@Column(name = "updated_at", nullable = false)
@UpdateTimestamp
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime updatedAt;
}

View File

@@ -0,0 +1,8 @@
package org.raddatz.familienarchiv.model;
public enum OcrJobStatus {
PENDING,
RUNNING,
DONE,
FAILED
}

View File

@@ -0,0 +1,36 @@
package org.raddatz.familienarchiv.model;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.persistence.AttributeConverter;
import jakarta.persistence.Converter;
import java.util.List;
@Converter
public class PolygonConverter implements AttributeConverter<List<List<Double>>, String> {
private static final ObjectMapper MAPPER = new ObjectMapper();
private static final TypeReference<List<List<Double>>> TYPE_REF = new TypeReference<>() {};
@Override
public String convertToDatabaseColumn(List<List<Double>> polygon) {
if (polygon == null) return null;
try {
return MAPPER.writeValueAsString(polygon);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException("Failed to serialize polygon", e);
}
}
@Override
public List<List<Double>> convertToEntityAttribute(String json) {
if (json == null || json.isEmpty()) return null;
try {
return MAPPER.readValue(json, TYPE_REF);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException("Failed to deserialize polygon", e);
}
}
}

View File

@@ -0,0 +1,8 @@
package org.raddatz.familienarchiv.model;
public enum ScriptType {
UNKNOWN,
TYPEWRITER,
HANDWRITING_LATIN,
HANDWRITING_KURRENT
}

View File

@@ -41,6 +41,17 @@ public class TranscriptionBlock {
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int sortOrder;
@Enumerated(EnumType.STRING)
@Column(nullable = false, length = 10)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private BlockSource source = BlockSource.MANUAL;
@Column(nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
@Builder.Default
private boolean reviewed = false;
@Version
@Column(nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)

View File

@@ -0,0 +1,20 @@
package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.OcrDocumentStatus;
import org.raddatz.familienarchiv.model.OcrJobDocument;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public interface OcrJobDocumentRepository extends JpaRepository<OcrJobDocument, UUID> {
List<OcrJobDocument> findByJobIdOrderByCreatedAtAsc(UUID jobId);
List<OcrJobDocument> findByJobIdAndStatus(UUID jobId, OcrDocumentStatus status);
Optional<OcrJobDocument> findByJobIdAndDocumentId(UUID jobId, UUID documentId);
Optional<OcrJobDocument> findFirstByDocumentIdAndStatusIn(UUID documentId, List<OcrDocumentStatus> statuses);
}

View File

@@ -0,0 +1,9 @@
package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.OcrJob;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.UUID;
public interface OcrJobRepository extends JpaRepository<OcrJob, UUID> {
}

View File

@@ -48,6 +48,26 @@ public class AnnotationService {
return annotationRepository.save(annotation);
}
@Transactional
public DocumentAnnotation createOcrAnnotation(UUID documentId, CreateAnnotationDTO dto,
UUID userId, String fileHash,
List<List<Double>> polygon) {
DocumentAnnotation annotation = DocumentAnnotation.builder()
.documentId(documentId)
.pageNumber(dto.getPageNumber())
.x(dto.getX())
.y(dto.getY())
.width(dto.getWidth())
.height(dto.getHeight())
.color(dto.getColor())
.fileHash(fileHash)
.createdBy(userId)
.polygon(polygon)
.build();
return annotationRepository.save(annotation);
}
@Transactional
public void deleteAnnotation(UUID documentId, UUID annotationId, UUID userId) {
DocumentAnnotation annotation = annotationRepository

View File

@@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.dto.IncompleteDocumentDTO;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.dto.DocumentSort;
import org.raddatz.familienarchiv.model.DocumentStatus;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.Tag;
import org.raddatz.familienarchiv.repository.DocumentRepository;
@@ -222,6 +223,10 @@ public class DocumentService {
doc.setMetadataComplete(dto.getMetadataComplete());
}
if (dto.getScriptType() != null) {
doc.setScriptType(dto.getScriptType());
}
// 4. Replace the file (only if a new one was selected)
if (newFile != null && !newFile.isEmpty()) {
FileService.UploadResult upload = fileService.uploadFile(newFile, newFile.getOriginalFilename());
@@ -373,6 +378,13 @@ public class DocumentService {
return documentRepository.findAll(conversation, Sort.by(Sort.Direction.ASC, "documentDate"));
}
@Transactional
public void updateScriptType(UUID documentId, ScriptType scriptType) {
Document doc = getDocumentById(documentId);
doc.setScriptType(scriptType);
documentRepository.save(doc);
}
public Document getDocumentById(UUID id) {
return documentRepository.findById(id)
.orElseThrow(() -> DomainException.notFound(ErrorCode.DOCUMENT_NOT_FOUND, "Document not found: " + id));

View File

@@ -4,6 +4,8 @@ import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.*;
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -16,6 +18,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.UUID;
@Service
@@ -24,10 +27,13 @@ public class FileService {
private static final Logger log = LoggerFactory.getLogger(FileService.class);
private final S3Client s3Client;
private final S3Presigner s3Presigner;
private final String bucketName;
-public FileService(S3Client s3Client, @Value("${app.s3.bucket}") String bucketName) {
+public FileService(S3Client s3Client, S3Presigner s3Presigner,
+@Value("${app.s3.bucket}") String bucketName) {
this.s3Client = s3Client;
this.s3Presigner = s3Presigner;
this.bucketName = bucketName;
}
@@ -106,6 +112,24 @@ public class FileService {
}
}
/**
* Generates a presigned URL for downloading an object from S3/MinIO.
* Valid for 15 minutes — enough for OCR processing on CPU.
*/
public String generatePresignedUrl(String s3Key) {
GetObjectRequest getObjectRequest = GetObjectRequest.builder()
.bucket(bucketName)
.key(s3Key)
.build();
GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
.signatureDuration(Duration.ofMinutes(15))
.getObjectRequest(getObjectRequest)
.build();
return s3Presigner.presignGetObject(presignRequest).url().toString();
}
// ─── private helpers ──────────────────────────────────────────────────────
private static String sha256Hex(byte[] bytes) {

View File

@@ -0,0 +1,217 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.dto.CreateAnnotationDTO;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
import org.raddatz.familienarchiv.repository.OcrJobRepository;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
@Component
@RequiredArgsConstructor
@Slf4j
public class OcrAsyncRunner {
private static final String OCR_ANNOTATION_COLOR = "#00C7B1";
private final OcrClient ocrClient;
private final DocumentService documentService;
private final TranscriptionService transcriptionService;
private final AnnotationService annotationService;
private final FileService fileService;
private final OcrJobRepository ocrJobRepository;
private final OcrJobDocumentRepository ocrJobDocumentRepository;
private final OcrProgressService ocrProgressService;
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
if (job == null) return;
job.setStatus(OcrJobStatus.RUNNING);
updateProgress(job, "PREPARING");
OcrJobDocument jobDoc = ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, documentId)
.orElse(null);
if (jobDoc != null) {
jobDoc.setStatus(OcrDocumentStatus.RUNNING);
ocrJobDocumentRepository.save(jobDoc);
}
Document doc = documentService.getDocumentById(documentId);
try {
updateProgress(job, "LOADING");
clearExistingBlocks(documentId);
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
AtomicInteger blockCounter = new AtomicInteger(0);
AtomicInteger currentPage = new AtomicInteger(0);
AtomicInteger skippedPages = new AtomicInteger(0);
AtomicInteger totalPages = new AtomicInteger(0);
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
switch (event) {
case OcrStreamEvent.Start start -> {
totalPages.set(start.totalPages());
if (jobDoc != null) {
jobDoc.setTotalPages(start.totalPages());
ocrJobDocumentRepository.save(jobDoc);
}
}
case OcrStreamEvent.Page page -> {
for (OcrBlockResult block : page.blocks()) {
createSingleBlock(documentId, block, userId,
doc.getFileHash(), blockCounter.getAndIncrement());
}
currentPage.incrementAndGet();
if (jobDoc != null) {
jobDoc.setCurrentPage(currentPage.get());
ocrJobDocumentRepository.save(jobDoc);
}
updateProgress(job, "ANALYZING_PAGE:" + currentPage.get()
+ ":" + totalPages.get() + ":" + blockCounter.get());
}
case OcrStreamEvent.Error error -> {
log.warn("OCR page {} failed for document {}: {}",
error.pageNumber(), documentId, error.message());
skippedPages.incrementAndGet();
currentPage.incrementAndGet();
if (jobDoc != null) {
jobDoc.setCurrentPage(currentPage.get());
ocrJobDocumentRepository.save(jobDoc);
}
}
case OcrStreamEvent.Done done -> {
if (jobDoc != null) {
jobDoc.setCurrentPage(totalPages.get());
ocrJobDocumentRepository.save(jobDoc);
}
}
}
});
job.setStatus(OcrJobStatus.DONE);
job.setProcessedDocuments(1);
updateProgress(job, "DONE:" + blockCounter.get() + ":" + skippedPages.get());
if (jobDoc != null) {
jobDoc.setStatus(OcrDocumentStatus.DONE);
ocrJobDocumentRepository.save(jobDoc);
}
} catch (Exception e) {
log.error("OCR processing failed for document {}", documentId, e);
job.setStatus(OcrJobStatus.FAILED);
job.setErrorCount(1);
updateProgress(job, "ERROR");
if (jobDoc != null) {
jobDoc.setStatus(OcrDocumentStatus.FAILED);
jobDoc.setErrorMessage(e.getMessage());
ocrJobDocumentRepository.save(jobDoc);
}
}
}
private void updateProgress(OcrJob job, String message) {
job.setProgressMessage(message);
ocrJobRepository.save(job);
}
@Async
public void runBatch(UUID jobId, UUID userId) {
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
if (job == null) return;
job.setStatus(OcrJobStatus.RUNNING);
ocrJobRepository.save(job);
List<OcrJobDocument> jobDocs = ocrJobDocumentRepository.findByJobIdOrderByCreatedAtAsc(jobId);
for (OcrJobDocument jobDoc : jobDocs) {
Document doc = documentService.getDocumentById(jobDoc.getDocumentId());
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
jobDoc.setStatus(OcrDocumentStatus.SKIPPED);
ocrJobDocumentRepository.save(jobDoc);
job.setSkippedCount(job.getSkippedCount() + 1);
ocrJobRepository.save(job);
ocrProgressService.emit(jobId, "document", Map.of(
"documentId", jobDoc.getDocumentId(),
"status", "SKIPPED",
"processed", job.getProcessedDocuments(),
"total", job.getTotalDocuments()));
continue;
}
jobDoc.setStatus(OcrDocumentStatus.RUNNING);
ocrJobDocumentRepository.save(jobDoc);
try {
processDocument(jobDoc.getDocumentId(), doc, userId);
jobDoc.setStatus(OcrDocumentStatus.DONE);
job.setProcessedDocuments(job.getProcessedDocuments() + 1);
} catch (Exception e) {
log.error("OCR batch: failed document {}", jobDoc.getDocumentId(), e);
jobDoc.setStatus(OcrDocumentStatus.FAILED);
jobDoc.setErrorMessage(e.getMessage());
job.setErrorCount(job.getErrorCount() + 1);
}
ocrJobDocumentRepository.save(jobDoc);
ocrJobRepository.save(job);
ocrProgressService.emit(jobId, "document", Map.of(
"documentId", jobDoc.getDocumentId(),
"status", jobDoc.getStatus().name(),
"processed", job.getProcessedDocuments(),
"total", job.getTotalDocuments()));
}
job.setStatus(OcrJobStatus.DONE);
ocrJobRepository.save(job);
ocrProgressService.emit(jobId, "done", Map.of(
"processed", job.getProcessedDocuments(),
"errors", job.getErrorCount(),
"skipped", job.getSkippedCount()));
ocrProgressService.complete(jobId);
}
void processDocument(UUID documentId, Document doc, UUID userId) {
clearExistingBlocks(documentId);
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
List<OcrBlockResult> blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType());
createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash());
}
private void clearExistingBlocks(UUID documentId) {
transcriptionService.deleteAllBlocksByDocument(documentId);
}
private void createTranscriptionBlocks(UUID documentId, List<OcrBlockResult> blocks,
UUID userId, String fileHash) {
for (int i = 0; i < blocks.size(); i++) {
createSingleBlock(documentId, blocks.get(i), userId, fileHash, i);
}
}
void createSingleBlock(UUID documentId, OcrBlockResult block,
UUID userId, String fileHash, int sortOrder) {
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
block.pageNumber(), block.x(), block.y(),
block.width(), block.height(), OCR_ANNOTATION_COLOR);
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
documentId, annotationDTO, userId, fileHash, block.polygon());
transcriptionService.createOcrBlock(documentId, annotation.getId(),
block.text(), sortOrder, userId);
}
}

View File

@@ -0,0 +1,50 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
import org.raddatz.familienarchiv.repository.OcrJobRepository;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class OcrBatchService {
private final OcrHealthClient ocrHealthClient;
private final OcrJobRepository ocrJobRepository;
private final OcrJobDocumentRepository ocrJobDocumentRepository;
private final OcrAsyncRunner ocrAsyncRunner;
public UUID startBatch(List<UUID> documentIds, UUID userId) {
if (!ocrHealthClient.isHealthy()) {
throw DomainException.internal(ErrorCode.OCR_SERVICE_UNAVAILABLE,
"OCR service is not available");
}
OcrJob job = OcrJob.builder()
.totalDocuments(documentIds.size())
.createdBy(userId)
.status(OcrJobStatus.PENDING)
.build();
job = ocrJobRepository.save(job);
for (UUID docId : documentIds) {
OcrJobDocument jobDoc = OcrJobDocument.builder()
.jobId(job.getId())
.documentId(docId)
.status(OcrDocumentStatus.PENDING)
.build();
ocrJobDocumentRepository.save(jobDoc);
}
ocrAsyncRunner.runBatch(job.getId(), userId);
return job.getId();
}
}

View File

@@ -0,0 +1,16 @@
package org.raddatz.familienarchiv.service;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.util.List;
@JsonIgnoreProperties(ignoreUnknown = true)
public record OcrBlockResult(
int pageNumber,
double x,
double y,
double width,
double height,
List<List<Double>> polygon,
String text
) {}

View File

@@ -0,0 +1,35 @@
package org.raddatz.familienarchiv.service;
import org.raddatz.familienarchiv.model.ScriptType;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.function.Consumer;
public interface OcrClient {
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
/**
* Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration
* — it calls extractBlocks() and synthesizes events from the collected result.
*/
default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
for (OcrBlockResult block : allBlocks) {
byPage.computeIfAbsent(block.pageNumber(), k -> new ArrayList<>()).add(block);
}
int totalPages = byPage.isEmpty() ? 0 : byPage.keySet().stream().mapToInt(i -> i).max().orElse(0) + 1;
handler.accept(new OcrStreamEvent.Start(totalPages));
for (var entry : byPage.entrySet()) {
handler.accept(new OcrStreamEvent.Page(entry.getKey(), entry.getValue()));
}
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
}
}

View File

@@ -0,0 +1,5 @@
package org.raddatz.familienarchiv.service;
public interface OcrHealthClient {
boolean isHealthy();
}

View File

@@ -0,0 +1,69 @@
package org.raddatz.familienarchiv.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@Service
@Slf4j
public class OcrProgressService {
private static final long SSE_TIMEOUT = 5 * 60 * 1000L;
private final ConcurrentHashMap<UUID, List<SseEmitter>> emitters = new ConcurrentHashMap<>();
public SseEmitter register(UUID jobId) {
SseEmitter emitter = new SseEmitter(SSE_TIMEOUT);
emitters.computeIfAbsent(jobId, k -> new CopyOnWriteArrayList<>()).add(emitter);
emitter.onCompletion(() -> removeEmitter(jobId, emitter));
emitter.onTimeout(() -> removeEmitter(jobId, emitter));
emitter.onError(e -> removeEmitter(jobId, emitter));
return emitter;
}
public void emit(UUID jobId, String eventType, Object data) {
List<SseEmitter> jobEmitters = emitters.get(jobId);
if (jobEmitters == null) return;
for (SseEmitter emitter : jobEmitters) {
try {
emitter.send(SseEmitter.event().name(eventType).data(data));
} catch (IOException e) {
log.debug("SSE send failed for job {} — removing emitter", jobId);
removeEmitter(jobId, emitter);
}
}
}
public void complete(UUID jobId) {
List<SseEmitter> jobEmitters = emitters.remove(jobId);
if (jobEmitters == null) return;
for (SseEmitter emitter : jobEmitters) {
try {
emitter.complete();
} catch (Exception e) {
log.debug("SSE complete failed for job {}", jobId);
}
}
}
private void removeEmitter(UUID jobId, SseEmitter emitter) {
List<SseEmitter> jobEmitters = emitters.get(jobId);
if (jobEmitters != null) {
jobEmitters.remove(emitter);
if (jobEmitters.isEmpty()) {
emitters.remove(jobId);
}
}
}
}
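
For orientation (not part of the diff): the intended lifecycle is that the SSE controller registers an emitter per subscriber, the async runner emits named events while the job runs, and complete() closes all open streams when the job finishes. A minimal sketch under those assumptions; jobId and payloads are invented:

// Hypothetical lifecycle sketch; jobId and event payloads are illustrative only.
UUID jobId = UUID.randomUUID();
SseEmitter emitter = ocrProgressService.register(jobId); // as done by the SSE controller endpoint
ocrProgressService.emit(jobId, "document", Map.of("status", "RUNNING", "processed", 1, "total", 3));
ocrProgressService.emit(jobId, "done", Map.of("processed", 3, "errors", 0, "skipped", 0));
ocrProgressService.complete(jobId); // closes every emitter registered for this job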

View File

@@ -0,0 +1,88 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
import org.raddatz.familienarchiv.repository.OcrJobRepository;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class OcrService {
private final OcrHealthClient ocrHealthClient;
private final DocumentService documentService;
private final OcrJobRepository ocrJobRepository;
private final OcrJobDocumentRepository ocrJobDocumentRepository;
private final OcrAsyncRunner ocrAsyncRunner;
public OcrJob getJob(UUID jobId) {
return ocrJobRepository.findById(jobId)
.orElseThrow(() -> DomainException.notFound(
ErrorCode.OCR_JOB_NOT_FOUND, "OCR job not found: " + jobId));
}
public OcrStatusDTO getDocumentOcrStatus(UUID documentId) {
List<OcrDocumentStatus> activeStatuses = List.of(
OcrDocumentStatus.PENDING, OcrDocumentStatus.RUNNING);
Optional<OcrJobDocument> activeJobDoc = ocrJobDocumentRepository
.findFirstByDocumentIdAndStatusIn(documentId, activeStatuses);
if (activeJobDoc.isEmpty()) {
return OcrStatusDTO.builder().status("NONE").build();
}
OcrJobDocument jobDoc = activeJobDoc.get();
return OcrStatusDTO.builder()
.status(jobDoc.getStatus().name())
.jobId(jobDoc.getJobId())
.currentPage(jobDoc.getCurrentPage())
.totalPages(jobDoc.getTotalPages())
.build();
}
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
Document doc = documentService.getDocumentById(documentId);
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
throw DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED,
"Document has no file attached: " + documentId);
}
if (!ocrHealthClient.isHealthy()) {
throw DomainException.internal(ErrorCode.OCR_SERVICE_UNAVAILABLE,
"OCR service is not available");
}
if (scriptTypeOverride != null) {
documentService.updateScriptType(documentId, scriptTypeOverride);
}
OcrJob job = OcrJob.builder()
.totalDocuments(1)
.createdBy(userId)
.status(OcrJobStatus.PENDING)
.build();
job = ocrJobRepository.save(job);
OcrJobDocument jobDoc = OcrJobDocument.builder()
.jobId(job.getId())
.documentId(documentId)
.status(OcrDocumentStatus.PENDING)
.build();
ocrJobDocumentRepository.save(jobDoc);
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
return job.getId();
}
}

View File

@@ -0,0 +1,14 @@
package org.raddatz.familienarchiv.service;
import java.util.List;
public sealed interface OcrStreamEvent {
record Start(int totalPages) implements OcrStreamEvent {}
record Page(int pageNumber, List<OcrBlockResult> blocks) implements OcrStreamEvent {}
record Error(int pageNumber, String message) implements OcrStreamEvent {}
record Done(int totalBlocks, int skippedPages) implements OcrStreamEvent {}
}
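
For orientation (not part of the diff): a minimal sketch of how a caller is expected to consume the sealed event hierarchy through OcrClient.streamBlocks, assuming an OcrClient instance and a presigned pdfUrl are available; the output lines are illustrative only:

// Hypothetical consumer sketch; the switch is exhaustive because the interface is sealed.
Consumer<OcrStreamEvent> handler = event -> {
    switch (event) {
        case OcrStreamEvent.Start start -> System.out.println("total pages: " + start.totalPages());
        case OcrStreamEvent.Page page -> System.out.println("page " + page.pageNumber() + ": " + page.blocks().size() + " blocks");
        case OcrStreamEvent.Error error -> System.out.println("page " + error.pageNumber() + " skipped: " + error.message());
        case OcrStreamEvent.Done done -> System.out.println("done: " + done.totalBlocks() + " blocks, " + done.skippedPages() + " pages skipped");
    }
};
ocrClient.streamBlocks(pdfUrl, ScriptType.HANDWRITING_KURRENT, handler);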

View File

@@ -0,0 +1,187 @@
package org.raddatz.familienarchiv.service;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.ScriptType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.MediaType;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestClient;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
@Component
@Slf4j
public class RestClientOcrClient implements OcrClient, OcrHealthClient {
private static final ObjectMapper NDJSON_MAPPER = new ObjectMapper()
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
private final RestClient restClient;
private final HttpClient streamingHttpClient;
private final String baseUrl;
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) {
this.baseUrl = baseUrl;
HttpClient httpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10))
.build();
JdkClientHttpRequestFactory requestFactory = new JdkClientHttpRequestFactory(httpClient);
requestFactory.setReadTimeout(Duration.ofMinutes(10));
this.restClient = RestClient.builder()
.baseUrl(baseUrl)
.requestFactory(requestFactory)
.build();
this.streamingHttpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.connectTimeout(Duration.ofSeconds(10))
.build();
}
@Override
public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
Map<String, String> body = Map.of(
"pdfUrl", pdfUrl,
"scriptType", scriptType.name(),
"language", "de");
List<OcrBlockJson> response = restClient.post()
.uri("/ocr")
.contentType(MediaType.APPLICATION_JSON)
.body(body)
.retrieve()
.body(new ParameterizedTypeReference<>() {});
if (response == null) return List.of();
return response.stream()
.map(OcrBlockJson::toResult)
.toList();
}
@Override
public boolean isHealthy() {
try {
restClient.get()
.uri("/health")
.retrieve()
.toBodilessEntity();
return true;
} catch (Exception e) {
log.warn("OCR service health check failed: {}", e.getMessage());
return false;
}
}
@Override
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
String body;
try {
body = NDJSON_MAPPER.writeValueAsString(Map.of(
"pdfUrl", pdfUrl,
"scriptType", scriptType.name(),
"language", "de"));
} catch (IOException e) {
throw new RuntimeException("Failed to serialize OCR request", e);
}
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(baseUrl + "/ocr/stream"))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(body))
.timeout(Duration.ofMinutes(5))
.build();
try {
HttpResponse<InputStream> response = streamingHttpClient.send(
request, HttpResponse.BodyHandlers.ofInputStream());
if (response.statusCode() == 404) {
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
return;
}
try (InputStream inputStream = response.body()) {
parseNdjsonStream(inputStream, handler);
}
} catch (IOException | InterruptedException e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
throw new RuntimeException("NDJSON stream failed: " + e.getMessage(), e);
}
}
static void parseNdjsonStream(InputStream inputStream, Consumer<OcrStreamEvent> handler) {
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
if (line.isBlank()) continue;
JsonNode node = NDJSON_MAPPER.readTree(line);
String type = node.path("type").asText();
switch (type) {
case "start" -> handler.accept(
new OcrStreamEvent.Start(node.path("totalPages").asInt()));
case "page" -> {
int pageNumber = node.path("pageNumber").asInt();
List<OcrBlockResult> blocks = NDJSON_MAPPER.convertValue(
node.path("blocks"),
new TypeReference<>() {});
handler.accept(new OcrStreamEvent.Page(pageNumber, blocks));
}
case "error" -> handler.accept(
new OcrStreamEvent.Error(
node.path("pageNumber").asInt(),
node.path("message").asText()));
case "done" -> handler.accept(
new OcrStreamEvent.Done(
node.path("totalBlocks").asInt(),
node.path("skippedPages").asInt()));
default -> log.debug("Ignoring unknown NDJSON event type: {}", type);
}
}
} catch (IOException e) {
throw new RuntimeException("Failed to parse NDJSON stream: " + e.getMessage(), e);
}
}
record OcrBlockJson(
@JsonProperty("pageNumber") int pageNumber,
double x,
double y,
double width,
double height,
List<List<Double>> polygon,
String text
) {
OcrBlockResult toResult() {
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text);
}
}
}
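
For reference (not part of the diff): a sketch of the NDJSON event shapes parseNdjsonStream accepts, fed from an in-memory stream. Only the field names come from the parser and OcrBlockResult above; the sample values are invented:

// Hypothetical sample payload; field names match the parser, values are made up.
String ndjson = String.join("\n",
    "{\"type\":\"start\",\"totalPages\":2}",
    "{\"type\":\"page\",\"pageNumber\":0,\"blocks\":[{\"pageNumber\":0,\"x\":0.1,\"y\":0.1," +
        "\"width\":0.8,\"height\":0.05,\"polygon\":[[0.1,0.1],[0.9,0.1],[0.9,0.15],[0.1,0.15]]," +
        "\"text\":\"example line\"}]}",
    "{\"type\":\"error\",\"pageNumber\":1,\"message\":\"page could not be rendered\"}",
    "{\"type\":\"done\",\"totalBlocks\":1,\"skippedPages\":1}");
RestClientOcrClient.parseNdjsonStream(
    new java.io.ByteArrayInputStream(ndjson.getBytes(java.nio.charset.StandardCharsets.UTF_8)),
    event -> System.out.println(event));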

View File

@@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.dto.ReorderTranscriptionBlocksDTO;
import org.raddatz.familienarchiv.dto.UpdateTranscriptionBlockDTO;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
@@ -75,6 +76,24 @@ public class TranscriptionService {
return saved;
}
@Transactional
public TranscriptionBlock createOcrBlock(UUID documentId, UUID annotationId,
String text, int sortOrder, UUID userId) {
String sanitized = sanitizeText(text);
TranscriptionBlock block = TranscriptionBlock.builder()
.annotationId(annotationId)
.documentId(documentId)
.text(sanitized)
.sortOrder(sortOrder)
.source(BlockSource.OCR)
.createdBy(userId)
.updatedBy(userId)
.build();
TranscriptionBlock saved = blockRepository.save(block);
saveVersion(saved, userId);
return saved;
}
@Transactional
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
UpdateTranscriptionBlockDTO dto, UUID userId) {
@@ -106,6 +125,21 @@ public class TranscriptionService {
blockId, annotationId, documentId);
}
@Transactional
public void deleteAllBlocksByDocument(UUID documentId) {
List<TranscriptionBlock> blocks = blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId);
if (blocks.isEmpty()) return;
List<UUID> annotationIds = blocks.stream()
.map(TranscriptionBlock::getAnnotationId)
.toList();
blockRepository.deleteAll(blocks);
blockRepository.flush();
annotationRepository.deleteAllById(annotationIds);
log.info("Bulk-deleted {} transcription blocks for document {}", blocks.size(), documentId);
}
@Transactional
public void reorderBlocks(UUID documentId, ReorderTranscriptionBlocksDTO dto) {
List<UUID> blockIds = dto.getBlockIds();
@@ -116,6 +150,13 @@ public class TranscriptionService {
}
}
@Transactional
public TranscriptionBlock reviewBlock(UUID documentId, UUID blockId) {
TranscriptionBlock block = getBlock(documentId, blockId);
block.setReviewed(!block.isReviewed());
return blockRepository.save(block);
}
public List<TranscriptionBlockVersion> getBlockHistory(UUID documentId, UUID blockId) {
getBlock(documentId, blockId);
return versionRepository.findByBlockIdOrderByChangedAtDesc(blockId);

View File

@@ -0,0 +1,8 @@
-- Add optional polygon field for quadrilateral annotation shapes (Kraken OCR output).
-- See ADR-002 for the design decision.
ALTER TABLE document_annotations ADD COLUMN polygon JSONB;
ALTER TABLE document_annotations
ADD CONSTRAINT chk_annotation_polygon_quad
CHECK (polygon IS NULL OR jsonb_array_length(polygon) = 4);

View File

@@ -0,0 +1 @@
ALTER TABLE documents ADD COLUMN script_type VARCHAR(30) NOT NULL DEFAULT 'UNKNOWN';

View File

@@ -0,0 +1,26 @@
CREATE TABLE ocr_jobs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
status VARCHAR(20) NOT NULL DEFAULT 'PENDING',
total_documents INT NOT NULL,
processed_documents INT NOT NULL DEFAULT 0,
error_count INT NOT NULL DEFAULT 0,
skipped_count INT NOT NULL DEFAULT 0,
created_by UUID,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE ocr_job_documents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
job_id UUID NOT NULL REFERENCES ocr_jobs(id) ON DELETE CASCADE,
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
status VARCHAR(20) NOT NULL DEFAULT 'PENDING',
error_message TEXT,
current_page INT DEFAULT 0,
total_pages INT DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX idx_ocr_job_documents_job_id ON ocr_job_documents(job_id);
CREATE INDEX idx_ocr_job_documents_document_id ON ocr_job_documents(document_id);

View File

@@ -0,0 +1,2 @@
ALTER TABLE transcription_blocks ADD COLUMN source VARCHAR(10) NOT NULL DEFAULT 'MANUAL';
ALTER TABLE transcription_blocks ADD COLUMN reviewed BOOLEAN NOT NULL DEFAULT FALSE;

View File

@@ -0,0 +1 @@
ALTER TABLE ocr_jobs ADD COLUMN progress_message TEXT;

View File

@@ -0,0 +1,2 @@
ALTER TABLE ocr_job_documents
ADD CONSTRAINT uq_ocr_job_document UNIQUE (job_id, document_id);

View File

@@ -0,0 +1,135 @@
package org.raddatz.familienarchiv.controller;
import tools.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.config.SecurityConfig;
import org.raddatz.familienarchiv.dto.BatchOcrDTO;
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.security.PermissionAspect;
import org.raddatz.familienarchiv.service.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.aop.AopAutoConfiguration;
import org.springframework.boot.webmvc.test.autoconfigure.WebMvcTest;
import org.springframework.context.annotation.Import;
import org.springframework.http.MediaType;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.bean.override.mockito.MockitoBean;
import org.springframework.test.web.servlet.MockMvc;
import java.util.List;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@WebMvcTest(OcrController.class)
@Import({SecurityConfig.class, PermissionAspect.class, AopAutoConfiguration.class})
class OcrControllerTest {
@Autowired MockMvc mockMvc;
private final ObjectMapper objectMapper = new ObjectMapper();
@MockitoBean OcrService ocrService;
@MockitoBean OcrBatchService ocrBatchService;
@MockitoBean OcrProgressService ocrProgressService;
@MockitoBean UserService userService;
@MockitoBean CustomUserDetailsService customUserDetailsService;
@Test
@WithMockUser(authorities = "WRITE_ALL")
void triggerOcr_returns202_withJobId() throws Exception {
UUID docId = UUID.randomUUID();
UUID jobId = UUID.randomUUID();
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(dto)))
.andExpect(status().isAccepted())
.andExpect(jsonPath("$.jobId").value(jobId.toString()));
}
@Test
@WithMockUser(authorities = "WRITE_ALL")
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
UUID docId = UUID.randomUUID();
when(ocrService.startOcr(eq(docId), any(), any()))
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
.contentType(MediaType.APPLICATION_JSON)
.content("{}"))
.andExpect(status().isBadRequest());
}
@Test
@WithMockUser(authorities = "READ_ALL")
void getJobStatus_returns404_whenJobNotFound() throws Exception {
UUID jobId = UUID.randomUUID();
when(ocrService.getJob(jobId))
.thenThrow(DomainException.notFound(ErrorCode.OCR_JOB_NOT_FOUND, "OCR job not found"));
mockMvc.perform(get("/api/ocr/jobs/{jobId}", jobId))
.andExpect(status().isNotFound());
}
@Test
@WithMockUser(authorities = "READ_ALL")
void getJobStatus_returnsJobInfo_whenFound() throws Exception {
UUID jobId = UUID.randomUUID();
OcrJob job = OcrJob.builder()
.id(jobId)
.status(OcrJobStatus.RUNNING)
.totalDocuments(10)
.processedDocuments(3)
.errorCount(1)
.skippedCount(0)
.build();
when(ocrService.getJob(jobId)).thenReturn(job);
mockMvc.perform(get("/api/ocr/jobs/{jobId}", jobId))
.andExpect(status().isOk())
.andExpect(jsonPath("$.status").value("RUNNING"))
.andExpect(jsonPath("$.totalDocuments").value(10))
.andExpect(jsonPath("$.processedDocuments").value(3));
}
@Test
@WithMockUser(authorities = "ADMIN")
void triggerBatch_returns202_withJobId() throws Exception {
UUID jobId = UUID.randomUUID();
List<UUID> docIds = List.of(UUID.randomUUID(), UUID.randomUUID());
BatchOcrDTO dto = new BatchOcrDTO(docIds);
when(ocrBatchService.startBatch(eq(docIds), any())).thenReturn(jobId);
mockMvc.perform(post("/api/ocr/batch")
.contentType(MediaType.APPLICATION_JSON)
.content(objectMapper.writeValueAsString(dto)))
.andExpect(status().isAccepted())
.andExpect(jsonPath("$.jobId").value(jobId.toString()));
}
@Test
@WithMockUser(authorities = "READ_ALL")
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
UUID docId = UUID.randomUUID();
when(ocrService.getDocumentOcrStatus(docId))
.thenReturn(OcrStatusDTO.builder().status("NONE").build());
mockMvc.perform(get("/api/documents/{id}/ocr-status", docId))
.andExpect(status().isOk())
.andExpect(jsonPath("$.status").value("NONE"));
}
}

View File

@@ -356,4 +356,20 @@ class TranscriptionBlockControllerTest {
.andExpect(status().isOk())
.andExpect(jsonPath("$").isEmpty());
}
// ─── PUT .../review ──────────────────────────────────────────────────────
@Test
@WithMockUser(authorities = "WRITE_ALL")
void reviewBlock_returns200_withToggledBlock() throws Exception {
TranscriptionBlock reviewed = TranscriptionBlock.builder()
.id(BLOCK_ID).documentId(DOC_ID).annotationId(UUID.randomUUID())
.text("text").sortOrder(0).reviewed(true).build();
when(transcriptionService.reviewBlock(DOC_ID, BLOCK_ID)).thenReturn(reviewed);
mockMvc.perform(put("/api/documents/{documentId}/transcription-blocks/{blockId}/review",
DOC_ID, BLOCK_ID))
.andExpect(status().isOk())
.andExpect(jsonPath("$.reviewed").value(true));
}
}

View File

@@ -0,0 +1,124 @@
package org.raddatz.familienarchiv.dto;
import jakarta.validation.ConstraintViolation;
import jakarta.validation.Validation;
import jakarta.validation.Validator;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
class UniquePointsValidatorTest {
private final Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
@Test
void shouldAcceptNull() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(null);
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).noneMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
@Test
void shouldAcceptFourUniquePoints() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(List.of(
List.of(0.1, 0.1),
List.of(0.9, 0.1),
List.of(0.9, 0.9),
List.of(0.1, 0.9)));
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).noneMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
@Test
void shouldRejectDuplicatePoints() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(List.of(
List.of(0.1, 0.1),
List.of(0.1, 0.1),
List.of(0.9, 0.9),
List.of(0.1, 0.9)));
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
@Test
void shouldRejectPolygonWithThreePoints() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(List.of(
List.of(0.1, 0.1),
List.of(0.9, 0.1),
List.of(0.9, 0.9)));
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
@Test
void shouldRejectPolygonWithFivePoints() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(List.of(
List.of(0.1, 0.1),
List.of(0.5, 0.1),
List.of(0.9, 0.1),
List.of(0.9, 0.9),
List.of(0.1, 0.9)));
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
@Test
void shouldRejectCoordinateOutOfRange() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(List.of(
List.of(1.5, 0.1),
List.of(0.9, 0.1),
List.of(0.9, 0.9),
List.of(0.1, 0.9)));
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().contains("polygon"));
}
@Test
void shouldRejectNegativeCoordinate() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(List.of(
List.of(-0.1, 0.1),
List.of(0.9, 0.1),
List.of(0.9, 0.9),
List.of(0.1, 0.9)));
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().contains("polygon"));
}
@Test
void shouldRejectPointWithOneCoordinate() {
var dto = new CreateAnnotationDTO();
dto.setPolygon(List.of(
List.of(0.1),
List.of(0.9, 0.1),
List.of(0.9, 0.9),
List.of(0.1, 0.9)));
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().contains("polygon"));
}
}

View File

@@ -0,0 +1,65 @@
package org.raddatz.familienarchiv.model;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
class PolygonConverterTest {
private final PolygonConverter converter = new PolygonConverter();
@Test
void convertToDatabaseColumn_returnsNull_whenPolygonIsNull() {
assertThat(converter.convertToDatabaseColumn(null)).isNull();
}
@Test
void convertToDatabaseColumn_returnsJsonArray_whenPolygonIsValid() {
List<List<Double>> polygon = List.of(
List.of(0.1, 0.2),
List.of(0.9, 0.2),
List.of(0.9, 0.8),
List.of(0.1, 0.8));
String json = converter.convertToDatabaseColumn(polygon);
assertThat(json).isEqualTo("[[0.1,0.2],[0.9,0.2],[0.9,0.8],[0.1,0.8]]");
}
@Test
void convertToEntityAttribute_returnsNull_whenJsonIsNull() {
assertThat(converter.convertToEntityAttribute(null)).isNull();
}
@Test
void convertToEntityAttribute_returnsNull_whenJsonIsEmpty() {
assertThat(converter.convertToEntityAttribute("")).isNull();
}
@Test
void convertToEntityAttribute_returnsPolygon_whenJsonIsValid() {
String json = "[[0.1,0.2],[0.9,0.2],[0.9,0.8],[0.1,0.8]]";
List<List<Double>> polygon = converter.convertToEntityAttribute(json);
assertThat(polygon).hasSize(4);
assertThat(polygon.get(0)).containsExactly(0.1, 0.2);
assertThat(polygon.get(3)).containsExactly(0.1, 0.8);
}
@Test
void roundTrip_preservesValues() {
List<List<Double>> original = List.of(
List.of(0.12, 0.08),
List.of(0.88, 0.09),
List.of(0.87, 0.14),
List.of(0.11, 0.13));
String json = converter.convertToDatabaseColumn(original);
List<List<Double>> restored = converter.convertToEntityAttribute(json);
assertThat(restored).isEqualTo(original);
}
}

View File

@@ -260,6 +260,55 @@ class AnnotationServiceTest {
verify(annotationRepository).save(any());
}
// ─── createOcrAnnotation ──────────────────────────────────────────────────
@Test
void createOcrAnnotation_skipsOverlapCheck_andSavesWithPolygon() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
List<List<Double>> polygon = List.of(
List.of(0.1, 0.1), List.of(0.9, 0.11),
List.of(0.89, 0.14), List.of(0.11, 0.13));
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createOcrAnnotation(
docId, dto, userId, "filehash", polygon);
assertThat(result.getPolygon()).isEqualTo(polygon);
assertThat(result.getDocumentId()).isEqualTo(docId);
verify(annotationRepository).save(any());
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
}
@Test
void createOcrAnnotation_savesWithNullPolygon_whenPolygonNotProvided() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
DocumentAnnotation result = annotationService.createOcrAnnotation(
docId, dto, userId, "filehash", null);
assertThat(result.getPolygon()).isNull();
verify(annotationRepository).save(any());
}
@Test
void createOcrAnnotation_doesNotCheckOverlap_evenWhenOverlappingAnnotationExists() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#00C7B1");
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
annotationService.createOcrAnnotation(docId, dto, userId, "hash", null);
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
}
// ─── overlaps — partial overlap cases ────────────────────────────────────
@Test
void createAnnotation_noConflict_whenAnnotationIsAbove() {
// x ranges overlap, y ranges don't — existing is ABOVE the new annotation

View File

@@ -32,7 +32,7 @@ class FileServiceTest {
@BeforeEach
void setUp() {
s3Client = mock(S3Client.class);
fileService = new FileService(s3Client, "test-bucket");
fileService = new FileService(s3Client, null, "test-bucket");
}
@Test

View File

@@ -0,0 +1,269 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.raddatz.familienarchiv.dto.CreateAnnotationDTO;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
import org.raddatz.familienarchiv.repository.OcrJobRepository;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class OcrAsyncRunnerTest {
@Mock OcrClient ocrClient;
@Mock DocumentService documentService;
@Mock TranscriptionService transcriptionService;
@Mock AnnotationService annotationService;
@Mock FileService fileService;
@Mock OcrJobRepository ocrJobRepository;
@Mock OcrJobDocumentRepository ocrJobDocumentRepository;
@Mock OcrProgressService ocrProgressService;
@InjectMocks OcrAsyncRunner ocrAsyncRunner;
@Test
void processDocument_clearsExistingBlocks() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of());
ocrAsyncRunner.processDocument(docId, doc, userId);
verify(transcriptionService).deleteAllBlocksByDocument(docId);
}
@Test
void processDocument_createsAnnotationAndBlock_forEachResult() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
UUID annId = UUID.randomUUID();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
ocrAsyncRunner.processDocument(docId, doc, userId);
verify(annotationService, times(2)).createOcrAnnotation(
eq(docId), any(CreateAnnotationDTO.class), eq(userId), eq("hash"), any());
verify(transcriptionService, times(2)).createOcrBlock(
eq(docId), eq(annId), any(), anyInt(), eq(userId));
}
@Test
void processDocument_delegatesBlockCreationToTranscriptionService() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
UUID annId = UUID.randomUUID();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
ocrAsyncRunner.processDocument(docId, doc, userId);
verify(transcriptionService).createOcrBlock(docId, annId, "Test", 0, userId);
}
@Test
void runSingleDocument_setsJobDone_onSuccess() {
UUID jobId = UUID.randomUUID();
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
.jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
.thenReturn(Optional.of(jobDoc));
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
assertThat(job.getStatus()).isEqualTo(OcrJobStatus.DONE);
}
@Test
void runSingleDocument_setsJobFailed_onError() {
UUID jobId = UUID.randomUUID();
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
.jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
.thenReturn(Optional.of(jobDoc));
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
assertThat(job.getStatus()).isEqualTo(OcrJobStatus.FAILED);
assertThat(job.getErrorCount()).isEqualTo(1);
}
@Test
void runSingleDocument_updatesProgressPerPage() {
UUID jobId = UUID.randomUUID();
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
.jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
DocumentAnnotation ann = DocumentAnnotation.builder().id(UUID.randomUUID()).build();
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
.thenReturn(Optional.of(jobDoc));
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Page(1, List.of(
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Done(3, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
assertThat(progressMessages.get(0)).isEqualTo("ANALYZING_PAGE:1:3:2");
assertThat(progressMessages.get(1)).isEqualTo("ANALYZING_PAGE:2:3:3");
assertThat(job.getProgressMessage()).isEqualTo("DONE:3:0");
}
@Test
void runSingleDocument_includesSkippedPagesInDoneMessage() {
UUID jobId = UUID.randomUUID();
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
.jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
.thenReturn(Optional.of(jobDoc));
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Error(1, "failed"));
handler.accept(new OcrStreamEvent.Page(2, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
assertThat(job.getStatus()).isEqualTo(OcrJobStatus.DONE);
assertThat(job.getProgressMessage()).isEqualTo("DONE:0:1");
}
@Test
void runSingleDocument_logsStreamErrorAtWarnWithoutSettingJobFailed() {
UUID jobId = UUID.randomUUID();
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
.jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
Document doc = Document.builder().id(docId).filePath("test.pdf")
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
.thenReturn(Optional.of(jobDoc));
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
handler.accept(new OcrStreamEvent.Start(2));
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
// Job should still be DONE, not FAILED (per-page errors don't fail the whole job)
assertThat(job.getStatus()).isEqualTo(OcrJobStatus.DONE);
// Raw error message should not leak to progress
assertThat(job.getProgressMessage()).doesNotContain("python traceback");
}
}

View File

@@ -0,0 +1,80 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
import org.raddatz.familienarchiv.repository.OcrJobRepository;
import java.util.List;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class OcrBatchServiceTest {
@Mock OcrHealthClient ocrHealthClient;
@Mock OcrJobRepository ocrJobRepository;
@Mock OcrJobDocumentRepository ocrJobDocumentRepository;
@Mock OcrAsyncRunner ocrAsyncRunner;
@InjectMocks OcrBatchService ocrBatchService;
@Test
void startBatch_throwsServiceUnavailable_whenOcrServiceIsDown() {
when(ocrHealthClient.isHealthy()).thenReturn(false);
assertThatThrownBy(() -> ocrBatchService.startBatch(List.of(UUID.randomUUID()), UUID.randomUUID()))
.isInstanceOf(DomainException.class)
.satisfies(e -> assertThat(((DomainException) e).getCode())
.isEqualTo(ErrorCode.OCR_SERVICE_UNAVAILABLE));
}
@Test
void startBatch_createsJobAndDispatchesAsync() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
UUID jobId = UUID.randomUUID();
when(ocrHealthClient.isHealthy()).thenReturn(true);
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
OcrJob job = inv.getArgument(0);
job.setId(jobId);
return job;
});
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
UUID result = ocrBatchService.startBatch(List.of(docId), userId);
assertThat(result).isEqualTo(jobId);
verify(ocrAsyncRunner).runBatch(jobId, userId);
}
@Test
void startBatch_createsJobDocumentForEachId() {
UUID doc1 = UUID.randomUUID();
UUID doc2 = UUID.randomUUID();
UUID userId = UUID.randomUUID();
when(ocrHealthClient.isHealthy()).thenReturn(true);
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
OcrJob job = inv.getArgument(0);
job.setId(UUID.randomUUID());
return job;
});
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
ocrBatchService.startBatch(List.of(doc1, doc2), userId);
verify(ocrJobDocumentRepository, times(2)).save(any());
}
}

View File

@@ -0,0 +1,55 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.model.ScriptType;
import java.util.ArrayList;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
class OcrClientDefaultStreamTest {
@Test
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
OcrClient client = (pdfUrl, scriptType) -> List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
assertThat(events).hasSize(4);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
assertThat(((OcrStreamEvent.Start) events.get(0)).totalPages()).isEqualTo(2);
assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Page.class);
var page0 = (OcrStreamEvent.Page) events.get(1);
assertThat(page0.pageNumber()).isEqualTo(0);
assertThat(page0.blocks()).hasSize(2);
assertThat(events.get(2)).isInstanceOf(OcrStreamEvent.Page.class);
var page1 = (OcrStreamEvent.Page) events.get(2);
assertThat(page1.pageNumber()).isEqualTo(1);
assertThat(page1.blocks()).hasSize(1);
assertThat(events.get(3)).isInstanceOf(OcrStreamEvent.Done.class);
var done = (OcrStreamEvent.Done) events.get(3);
assertThat(done.totalBlocks()).isEqualTo(3);
assertThat(done.skippedPages()).isEqualTo(0);
}
@Test
void defaultStreamBlocksHandlesEmptyResults() {
OcrClient client = (pdfUrl, scriptType) -> List.of();
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
assertThat(events).hasSize(2);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
assertThat(((OcrStreamEvent.Start) events.get(0)).totalPages()).isEqualTo(0);
assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Done.class);
}
}

View File

@@ -0,0 +1,33 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.Test;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
class OcrProgressServiceTest {
private final OcrProgressService progressService = new OcrProgressService();
@Test
void register_returnsNonNullEmitter() {
UUID jobId = UUID.randomUUID();
SseEmitter emitter = progressService.register(jobId);
assertThat(emitter).isNotNull();
}
@Test
void emit_doesNotThrow_whenNoEmittersRegistered() {
assertThatCode(() -> progressService.emit(UUID.randomUUID(), "test", "data"))
.doesNotThrowAnyException();
}
@Test
void complete_doesNotThrow_whenNoEmittersRegistered() {
assertThatCode(() -> progressService.complete(UUID.randomUUID()))
.doesNotThrowAnyException();
}
}

View File

@@ -0,0 +1,165 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.*;
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
import org.raddatz.familienarchiv.repository.OcrJobRepository;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.springframework.http.HttpStatus.*;
@ExtendWith(MockitoExtension.class)
class OcrServiceTest {
@Mock OcrHealthClient ocrHealthClient;
@Mock DocumentService documentService;
@Mock OcrJobRepository ocrJobRepository;
@Mock OcrJobDocumentRepository ocrJobDocumentRepository;
@Mock OcrAsyncRunner ocrAsyncRunner;
@InjectMocks OcrService ocrService;
// ─── getJob ──────────────────────────────────────────────────────────────────
@Test
void getJob_returnsJob_whenFound() {
UUID jobId = UUID.randomUUID();
OcrJob job = OcrJob.builder().id(jobId).status(OcrJobStatus.RUNNING).build();
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
OcrJob result = ocrService.getJob(jobId);
assertThat(result).isEqualTo(job);
}
@Test
void getJob_throwsNotFound_whenJobDoesNotExist() {
UUID jobId = UUID.randomUUID();
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.empty());
assertThatThrownBy(() -> ocrService.getJob(jobId))
.isInstanceOf(DomainException.class)
.satisfies(e -> {
DomainException de = (DomainException) e;
assertThat(de.getStatus()).isEqualTo(NOT_FOUND);
assertThat(de.getCode()).isEqualTo(ErrorCode.OCR_JOB_NOT_FOUND);
});
}
// ─── getDocumentOcrStatus ───────────────────────────────────────────────────
@Test
void getDocumentOcrStatus_returnsNone_whenNoActiveJob() {
UUID docId = UUID.randomUUID();
when(ocrJobDocumentRepository.findFirstByDocumentIdAndStatusIn(any(), any()))
.thenReturn(Optional.empty());
OcrStatusDTO result = ocrService.getDocumentOcrStatus(docId);
assertThat(result.getStatus()).isEqualTo("NONE");
assertThat(result.getJobId()).isNull();
}
@Test
void getDocumentOcrStatus_returnsActiveStatus_whenJobExists() {
UUID docId = UUID.randomUUID();
UUID jobId = UUID.randomUUID();
OcrJobDocument jobDoc = OcrJobDocument.builder()
.jobId(jobId).documentId(docId)
.status(OcrDocumentStatus.RUNNING)
.currentPage(2).totalPages(5)
.build();
when(ocrJobDocumentRepository.findFirstByDocumentIdAndStatusIn(any(), any()))
.thenReturn(Optional.of(jobDoc));
OcrStatusDTO result = ocrService.getDocumentOcrStatus(docId);
assertThat(result.getStatus()).isEqualTo("RUNNING");
assertThat(result.getJobId()).isEqualTo(jobId);
assertThat(result.getCurrentPage()).isEqualTo(2);
assertThat(result.getTotalPages()).isEqualTo(5);
}
// ─── startOcr ───────────────────────────────────────────────────────────────
@Test
void startOcr_throwsBadRequest_whenDocumentIsPlaceholder() {
UUID docId = UUID.randomUUID();
Document doc = Document.builder().id(docId).status(DocumentStatus.PLACEHOLDER).build();
when(documentService.getDocumentById(docId)).thenReturn(doc);
assertThatThrownBy(() -> ocrService.startOcr(docId, null, UUID.randomUUID()))
.isInstanceOf(DomainException.class)
.satisfies(e -> {
DomainException de = (DomainException) e;
assertThat(de.getStatus()).isEqualTo(BAD_REQUEST);
assertThat(de.getCode()).isEqualTo(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED);
});
}
@Test
void startOcr_throwsServiceUnavailable_whenOcrServiceIsDown() {
UUID docId = UUID.randomUUID();
Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED)
.filePath("test.pdf").build();
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(ocrHealthClient.isHealthy()).thenReturn(false);
assertThatThrownBy(() -> ocrService.startOcr(docId, null, UUID.randomUUID()))
.isInstanceOf(DomainException.class)
.satisfies(e -> assertThat(((DomainException) e).getCode())
.isEqualTo(ErrorCode.OCR_SERVICE_UNAVAILABLE));
}
@Test
void startOcr_createsJobAndDispatchesAsync() {
UUID docId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
UUID jobId = UUID.randomUUID();
Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED)
.filePath("test.pdf").scriptType(ScriptType.TYPEWRITER).build();
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(ocrHealthClient.isHealthy()).thenReturn(true);
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
OcrJob job = inv.getArgument(0);
job.setId(jobId);
return job;
});
UUID result = ocrService.startOcr(docId, null, userId);
assertThat(result).isEqualTo(jobId);
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
}
@Test
void startOcr_updatesScriptType_whenProvided() {
UUID docId = UUID.randomUUID();
Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED)
.filePath("test.pdf").scriptType(ScriptType.UNKNOWN).build();
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(ocrHealthClient.isHealthy()).thenReturn(true);
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
OcrJob job = inv.getArgument(0);
job.setId(UUID.randomUUID());
return job;
});
ocrService.startOcr(docId, ScriptType.HANDWRITING_LATIN, UUID.randomUUID());
verify(documentService).updateScriptType(docId, ScriptType.HANDWRITING_LATIN);
}
}

View File

@@ -0,0 +1,51 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
class OcrStreamEventTest {
@Test
void startRecordHoldsTotalPages() {
var start = new OcrStreamEvent.Start(5);
assertThat(start.totalPages()).isEqualTo(5);
assertThat(start).isInstanceOf(OcrStreamEvent.class);
}
@Test
void pageRecordHoldsBlocksAndPageNumber() {
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
var page = new OcrStreamEvent.Page(0, List.of(block));
assertThat(page.pageNumber()).isEqualTo(0);
assertThat(page.blocks()).hasSize(1);
}
@Test
void errorRecordHoldsPageAndMessage() {
var error = new OcrStreamEvent.Error(2, "failed");
assertThat(error.pageNumber()).isEqualTo(2);
assertThat(error.message()).isEqualTo("failed");
}
@Test
void doneRecordHoldsTotalBlocksAndSkippedPages() {
var done = new OcrStreamEvent.Done(12, 2);
assertThat(done.totalBlocks()).isEqualTo(12);
assertThat(done.skippedPages()).isEqualTo(2);
}
@Test
void patternMatchingWorksOnSealedInterface() {
OcrStreamEvent event = new OcrStreamEvent.Start(3);
String result = switch (event) {
case OcrStreamEvent.Start s -> "start:" + s.totalPages();
case OcrStreamEvent.Page p -> "page:" + p.pageNumber();
case OcrStreamEvent.Error e -> "error:" + e.pageNumber();
case OcrStreamEvent.Done d -> "done:" + d.totalBlocks();
};
assertThat(result).isEqualTo("start:3");
}
}

View File

@@ -0,0 +1,134 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.Test;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
class RestClientOcrClientStreamTest {
@Test
void parseNdjsonStream_dispatchesStartPageDoneInOrder() {
String ndjson = """
{"type":"start","totalPages":2}
{"type":"page","pageNumber":0,"blocks":[{"pageNumber":0,"x":0.1,"y":0.2,"width":0.8,"height":0.1,"polygon":null,"text":"Line 1"}]}
{"type":"page","pageNumber":1,"blocks":[{"pageNumber":1,"x":0.1,"y":0.3,"width":0.8,"height":0.1,"polygon":null,"text":"Line 2"}]}
{"type":"done","totalBlocks":2,"skippedPages":0}
""";
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
List<OcrStreamEvent> events = new ArrayList<>();
RestClientOcrClient.parseNdjsonStream(stream, events::add);
assertThat(events).hasSize(4);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
assertThat(((OcrStreamEvent.Start) events.get(0)).totalPages()).isEqualTo(2);
assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Page.class);
var page0 = (OcrStreamEvent.Page) events.get(1);
assertThat(page0.pageNumber()).isEqualTo(0);
assertThat(page0.blocks()).hasSize(1);
assertThat(page0.blocks().get(0).text()).isEqualTo("Line 1");
assertThat(events.get(2)).isInstanceOf(OcrStreamEvent.Page.class);
var page1 = (OcrStreamEvent.Page) events.get(2);
assertThat(page1.pageNumber()).isEqualTo(1);
assertThat(events.get(3)).isInstanceOf(OcrStreamEvent.Done.class);
var done = (OcrStreamEvent.Done) events.get(3);
assertThat(done.totalBlocks()).isEqualTo(2);
assertThat(done.skippedPages()).isEqualTo(0);
}
@Test
void parseNdjsonStream_parsesErrorEvents() {
String ndjson = """
{"type":"start","totalPages":3}
{"type":"page","pageNumber":0,"blocks":[]}
{"type":"error","pageNumber":1,"message":"OCR processing failed on page 1"}
{"type":"page","pageNumber":2,"blocks":[]}
{"type":"done","totalBlocks":0,"skippedPages":1}
""";
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
List<OcrStreamEvent> events = new ArrayList<>();
RestClientOcrClient.parseNdjsonStream(stream, events::add);
assertThat(events).hasSize(5);
assertThat(events.get(2)).isInstanceOf(OcrStreamEvent.Error.class);
var error = (OcrStreamEvent.Error) events.get(2);
assertThat(error.pageNumber()).isEqualTo(1);
assertThat(error.message()).contains("OCR processing failed");
}
@Test
void parseNdjsonStream_skipsBlankLines() {
String ndjson = """
{"type":"start","totalPages":1}
{"type":"page","pageNumber":0,"blocks":[]}
{"type":"done","totalBlocks":0,"skippedPages":0}
""";
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
List<OcrStreamEvent> events = new ArrayList<>();
RestClientOcrClient.parseNdjsonStream(stream, events::add);
assertThat(events).hasSize(3);
}
@Test
void parseNdjsonStream_ignoresUnknownEventTypes() {
String ndjson = """
{"type":"start","totalPages":1}
{"type":"unknown","foo":"bar"}
{"type":"done","totalBlocks":0,"skippedPages":0}
""";
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
List<OcrStreamEvent> events = new ArrayList<>();
RestClientOcrClient.parseNdjsonStream(stream, events::add);
assertThat(events).hasSize(2);
}
@Test
void parseNdjsonStream_handlesUnknownFieldsInBlocks() {
String ndjson = """
{"type":"start","totalPages":1}
{"type":"page","pageNumber":0,"blocks":[{"pageNumber":0,"x":0.1,"y":0.2,"width":0.8,"height":0.1,"polygon":null,"text":"Line 1","confidence":0.95,"newFutureField":"ignored"}]}
{"type":"done","totalBlocks":1,"skippedPages":0}
""";
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
List<OcrStreamEvent> events = new ArrayList<>();
RestClientOcrClient.parseNdjsonStream(stream, events::add);
assertThat(events).hasSize(3);
var page = (OcrStreamEvent.Page) events.get(1);
assertThat(page.blocks().get(0).text()).isEqualTo("Line 1");
}
@Test
void parseNdjsonStream_parsesPageWithPolygon() {
String ndjson = """
{"type":"start","totalPages":1}
{"type":"page","pageNumber":0,"blocks":[{"pageNumber":0,"x":0.1,"y":0.2,"width":0.8,"height":0.1,"polygon":[[0.1,0.2],[0.9,0.2],[0.9,0.3],[0.1,0.3]],"text":"With polygon"}]}
{"type":"done","totalBlocks":1,"skippedPages":0}
""";
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
List<OcrStreamEvent> events = new ArrayList<>();
RestClientOcrClient.parseNdjsonStream(stream, events::add);
var page = (OcrStreamEvent.Page) events.get(1);
assertThat(page.blocks().get(0).polygon()).hasSize(4);
assertThat(page.blocks().get(0).text()).isEqualTo("With polygon");
}
}

View File

@@ -10,6 +10,7 @@ import org.raddatz.familienarchiv.dto.CreateTranscriptionBlockDTO;
import org.raddatz.familienarchiv.dto.ReorderTranscriptionBlocksDTO;
import org.raddatz.familienarchiv.dto.UpdateTranscriptionBlockDTO;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
@@ -26,8 +27,8 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
import static org.springframework.http.HttpStatus.NOT_FOUND;
@ExtendWith(MockitoExtension.class)
@@ -99,6 +100,50 @@ class TranscriptionServiceTest {
verify(versionRepository).save(any(TranscriptionBlockVersion.class));
}
// ─── createOcrBlock ──────────────────────────────────────────────────────────
@Test
void createOcrBlock_createsBlockWithOcrSourceAndSavesVersion() {
UUID docId = UUID.randomUUID();
UUID annotId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
when(blockRepository.save(any())).thenAnswer(inv -> {
TranscriptionBlock b = inv.getArgument(0);
b.setId(UUID.randomUUID());
return b;
});
TranscriptionBlock result = transcriptionService.createOcrBlock(
docId, annotId, "OCR text", 3, userId);
assertThat(result.getAnnotationId()).isEqualTo(annotId);
assertThat(result.getDocumentId()).isEqualTo(docId);
assertThat(result.getText()).isEqualTo("OCR text");
assertThat(result.getSortOrder()).isEqualTo(3);
assertThat(result.getSource()).isEqualTo(BlockSource.OCR);
assertThat(result.getCreatedBy()).isEqualTo(userId);
verify(versionRepository).save(any(TranscriptionBlockVersion.class));
}
@Test
void createOcrBlock_sanitizesNullText() {
UUID docId = UUID.randomUUID();
UUID annotId = UUID.randomUUID();
UUID userId = UUID.randomUUID();
when(blockRepository.save(any())).thenAnswer(inv -> {
TranscriptionBlock b = inv.getArgument(0);
b.setId(UUID.randomUUID());
return b;
});
TranscriptionBlock result = transcriptionService.createOcrBlock(
docId, annotId, null, 0, userId);
assertThat(result.getText()).isEmpty();
}
// ─── updateBlock ─────────────────────────────────────────────────────────────
@Test
@@ -168,6 +213,39 @@ class TranscriptionServiceTest {
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(NOT_FOUND));
}
// ─── deleteAllBlocksByDocument ─────────────────────────────────────────────
@Test
void deleteAllBlocksByDocument_deletesAllBlocksAndAnnotations() {
UUID docId = UUID.randomUUID();
UUID annId1 = UUID.randomUUID();
UUID annId2 = UUID.randomUUID();
TranscriptionBlock block1 = TranscriptionBlock.builder()
.id(UUID.randomUUID()).documentId(docId).annotationId(annId1).sortOrder(0).build();
TranscriptionBlock block2 = TranscriptionBlock.builder()
.id(UUID.randomUUID()).documentId(docId).annotationId(annId2).sortOrder(1).build();
when(blockRepository.findByDocumentIdOrderBySortOrderAsc(docId))
.thenReturn(List.of(block1, block2));
transcriptionService.deleteAllBlocksByDocument(docId);
verify(blockRepository).deleteAll(List.of(block1, block2));
verify(blockRepository).flush();
verify(annotationRepository).deleteAllById(List.of(annId1, annId2));
}
@Test
void deleteAllBlocksByDocument_doesNothing_whenNoBlocksExist() {
UUID docId = UUID.randomUUID();
when(blockRepository.findByDocumentIdOrderBySortOrderAsc(docId)).thenReturn(List.of());
transcriptionService.deleteAllBlocksByDocument(docId);
verify(blockRepository, never()).deleteAll(any());
}
// ─── reorderBlocks ───────────────────────────────────────────────────────────
@Test
@@ -243,4 +321,47 @@ class TranscriptionServiceTest {
assertThat(transcriptionService.listBlocks(docId)).containsExactly(b);
}
// ─── reviewBlock ─────────────────────────────────────────────────────────
@Test
void reviewBlock_setsReviewedTrue() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).annotationId(UUID.randomUUID())
.text("corrected text").sortOrder(0).reviewed(false).build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
TranscriptionBlock result = transcriptionService.reviewBlock(docId, blockId);
assertThat(result.isReviewed()).isTrue();
verify(blockRepository).save(block);
}
@Test
void reviewBlock_togglesReviewedFalse_whenAlreadyReviewed() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).annotationId(UUID.randomUUID())
.text("corrected text").sortOrder(0).reviewed(true).build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
TranscriptionBlock result = transcriptionService.reviewBlock(docId, blockId);
assertThat(result.isReviewed()).isFalse();
}
@Test
void reviewBlock_throwsNotFound_whenBlockMissing() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.empty());
assertThatThrownBy(() -> transcriptionService.reviewBlock(docId, blockId))
.isInstanceOf(DomainException.class);
}
}

View File

@@ -71,6 +71,35 @@ services:
networks:
- archive-net
# --- OCR: Python microservice (Surya + Kraken) ---
ocr-service:
build:
context: ./ocr-service
dockerfile: Dockerfile
container_name: archive-ocr
restart: unless-stopped
expose:
- "8000"
mem_limit: 8g
memswap_limit: 8g
volumes:
- ocr_models:/app/models
- ocr_cache:/root/.cache
environment:
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
OCR_CONFIDENCE_THRESHOLD: "0.3"
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
RECOGNITION_BATCH_SIZE: "16"
DETECTOR_BATCH_SIZE: "8"
networks:
- archive-net
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 10s
timeout: 5s
retries: 12
start_period: 60s
# --- Backend: Spring Boot ---
backend:
build:
@@ -89,6 +118,8 @@ services:
condition: service_healthy
mailpit:
condition: service_started
ocr-service:
condition: service_started
environment:
SPRING_DATASOURCE_URL: jdbc:postgresql://db:5432/${POSTGRES_DB}
SPRING_DATASOURCE_USERNAME: ${POSTGRES_USER}
@@ -109,6 +140,7 @@ services:
# Mailpit needs no auth or STARTTLS; production SMTP overrides these via .env
SPRING_MAIL_PROPERTIES_MAIL_SMTP_AUTH: ${MAIL_SMTP_AUTH:-false}
SPRING_MAIL_PROPERTIES_MAIL_SMTP_STARTTLS_ENABLE: ${MAIL_STARTTLS_ENABLE:-false}
APP_OCR_BASE_URL: http://ocr-service:8000
ports:
- "${PORT_BACKEND}:8080"
networks:
@@ -155,3 +187,5 @@ networks:
volumes:
frontend_node_modules:
maven_cache:
ocr_models:
ocr_cache:

View File

@@ -0,0 +1,84 @@
# ADR-001: OCR Python Microservice
## Status
Accepted
## Context
The Familienarchiv needs OCR capability to pre-populate transcription blocks from scanned documents. Two OCR engines are required:
- **Surya** — transformer-based, handles typewritten and modern Latin handwriting
- **Kraken** — historical HTR model support, required for pre-1941 German Kurrent/Suetterlin scripts
Both engines exist exclusively in the Python ecosystem. There are no production-quality Java bindings for either engine. Tess4J (Tesseract for Java) was considered but rejected: Tesseract has poor accuracy on degraded historical handwriting and no HTR-United model support.
The server has no GPU. CPU-only inference is the target (16-32 GB system RAM).
## Decision
Introduce a separate Python container (`ocr-service`) that exposes a simple HTTP API. Spring Boot calls this service via `RestClient`. The Python service is stateless — all job tracking and business logic remain in Spring Boot.
**Interface contract:**
Request:
```json
{
"pdfUrl": "http://minio:9000/archive-documents/abc.pdf?presigned...",
"scriptType": "HANDWRITING_KURRENT",
"language": "de"
}
```
Response:
```json
[
{
"pageNumber": 0,
"x": 0.12, "y": 0.08, "width": 0.76, "height": 0.04,
"polygon": [[0.12,0.08],[0.88,0.09],[0.87,0.12],[0.13,0.11]],
"text": "Sehr geehrter Herr ..."
}
]
```
Coordinates are normalized (0-1) relative to page dimensions.
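For orientation, a minimal Java sketch of the block payload as the backend would model it. The field layout mirrors the response contract above and the `OcrBlockResult` record the unit tests construct; the Jackson annotation is an assumption about how unknown future fields are tolerated, not a statement about the actual class.
```java
// Sketch only: field layout mirrors the response contract above.
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.util.List;

@JsonIgnoreProperties(ignoreUnknown = true) // assumed, so the Python service can add fields later
public record OcrBlockResult(
        int pageNumber,              // 0-based page index
        double x, double y,          // top-left corner, normalized to 0-1
        double width, double height, // extent, normalized to 0-1
        List<List<Double>> polygon,  // optional 4-point quadrilateral, null for plain boxes
        String text) {               // recognized line text
}
```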
**Java-side integration:**
- `OcrClient` interface with `extractBlocks()` method — mockable for unit tests (see the sketch after this list)
- `OcrHealthClient` interface with `isHealthy()` — separate concern from block extraction
- `RestClientOcrClient` implements both interfaces
- `OcrService` orchestrates: presigned URL generation, OCR call, block mapping, TranscriptionService delegation
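A minimal sketch of these two seams, shaped after how the unit tests exercise them (a lambda-implementable `extractBlocks`, a default `streamBlocks` that synthesizes `Start`/`Page`/`Done` events, and a separate health probe). The default-method body shown here is an assumption; the repository's implementation may differ.
```java
// Sketch of the integration seams listed above, not the authoritative source.
import java.util.List;
import java.util.function.Consumer;

@FunctionalInterface
interface OcrClient {

    // Blocking call: all blocks for the document behind the presigned URL.
    List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);

    // Streaming shim: synthesizes Start/Page/Done from a single extractBlocks()
    // call so non-streaming implementations keep working.
    default void streamBlocks(String pdfUrl, ScriptType scriptType,
                              Consumer<OcrStreamEvent> handler) {
        List<OcrBlockResult> blocks = extractBlocks(pdfUrl, scriptType);
        List<Integer> pages = blocks.stream()
                .map(OcrBlockResult::pageNumber).distinct().sorted().toList();
        handler.accept(new OcrStreamEvent.Start(pages.size()));
        for (int page : pages) {
            handler.accept(new OcrStreamEvent.Page(page,
                    blocks.stream().filter(b -> b.pageNumber() == page).toList()));
        }
        handler.accept(new OcrStreamEvent.Done(blocks.size(), 0));
    }
}

interface OcrHealthClient {
    boolean isHealthy(); // GET /health on the Python service
}
```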
**Docker networking:**
- `ocr-service` is on the internal Docker network only — no host port mapping
- Spring Boot reaches it via `http://ocr-service:8000`
- Health check with `start_period: 60s` to account for model loading (~30-60s on CPU)
## Alternatives Considered
| Alternative | Why rejected |
|---|---|
| Tess4J (Tesseract in Java) | No HTR-United model support; poor Kurrent accuracy |
| Calling Python via ProcessBuilder | Fragile, no health checks, model reloading on every call |
| Embedding Python via GraalVM | Experimental, complex dependency management for ML libraries |
| External SaaS OCR (Google Vision, AWS Textract) | Data sovereignty concern for private family documents; no Kurrent support |
## Consequences
**Easier:**
- Each engine is used via its native Python API — no bridging complexity
- OCR service can be updated independently of the main application
- Models can be swapped via volume mount without code changes
**Harder:**
- One additional container to operate (memory, health checks, restarts)
- Integration tests require WireMock stub — real OCR service is too slow for CI
- Presigned URL TTL must be managed (15-30 min recommended)
## Future Direction
- LISTEN/NOTIFY from PostgreSQL to push progress events when scaling to multiple instances
- GPU acceleration if the server is upgraded — only the Docker image needs to change

View File

@@ -0,0 +1,52 @@
# ADR-002: Polygon JSONB Storage for Annotations
## Status
Accepted
## Context
Document annotations currently store axis-aligned bounding boxes (`x, y, width, height`). Kraken OCR outputs polygon boundaries for text lines — historical handwriting (Kurrent, Suetterlin) produces rotated and curved text that axis-aligned rectangles approximate poorly.
We need to store an optional quadrilateral (4 corner points) per annotation to represent the precise text region. The polygon is display-only — overlap detection and all server-side geometry logic continue to use the AABB fields.
## Decision
Add a `polygon JSONB` column to `document_annotations`:
```sql
ALTER TABLE document_annotations ADD COLUMN polygon JSONB;
ALTER TABLE document_annotations
ADD CONSTRAINT chk_annotation_polygon_quad
CHECK (polygon IS NULL OR jsonb_array_length(polygon) = 4);
```
- `null` means rectangle — render using existing `x, y, width, height` fields (fully backward compatible)
- Non-null value is a normalized 4-point quadrilateral: `[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]` with coordinates in the 0-1 range relative to page dimensions
The existing AABB fields are always populated (even when a polygon is present) and remain the authoritative geometry for overlap detection.
**Java entity:** `List<List<Double>> polygon` backed by a custom `AttributeConverter<List<List<Double>>, String>`. No new dependency (Hypersistence Utils is not in the project and won't be added for a single column).
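A Jackson-based sketch of such a converter, consistent with the round-trip behavior the unit tests expect (null stays null, an empty column reads back as null, a polygon serializes to a compact JSON array). Error handling here is illustrative; the actual class may differ.
```java
// Illustrative converter sketch; matches the tested round-trip behavior but is
// not copied from the repository.
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.persistence.AttributeConverter;
import jakarta.persistence.Converter;
import java.util.List;

@Converter
public class PolygonConverter implements AttributeConverter<List<List<Double>>, String> {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    @Override
    public String convertToDatabaseColumn(List<List<Double>> polygon) {
        if (polygon == null) return null;              // null polygon -> NULL column
        try {
            return MAPPER.writeValueAsString(polygon); // compact form: [[0.1,0.2],...]
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot serialize polygon", e);
        }
    }

    @Override
    public List<List<Double>> convertToEntityAttribute(String json) {
        if (json == null || json.isBlank()) return null; // no polygon stored -> rectangle only
        try {
            return MAPPER.readValue(json, new TypeReference<List<List<Double>>>() {});
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot parse polygon JSON", e);
        }
    }
}
```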
**Semantic invariant:** `polygon`, if present, is a 4-point quadrilateral with coordinates normalized to [0, 1] relative to page dimensions. It may originate from OCR engine output (Kraken) or from a future manual drawing tool. The AABB fields remain the geometry source of truth for server-side logic.
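One way to enforce that invariant at the DTO layer would be Bean Validation container-element constraints, which lines up with the validation tests above (3- or 5-point polygons rejected, each point exactly two values, coordinates limited to [0, 1]). The field declaration below is a hedged sketch with an illustrative class name; `CreateAnnotationDTO` may declare it differently.
```java
// Hedged sketch of a DTO-level guard for the polygon invariant; names and
// annotations are illustrative, not copied from CreateAnnotationDTO.
import jakarta.validation.constraints.DecimalMax;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.Size;
import java.util.List;

public class PolygonDtoSketch {

    // Optional quadrilateral: exactly 4 points, each with exactly 2 coordinates,
    // every coordinate normalized to [0, 1].
    @Size(min = 4, max = 4)
    private List<@Size(min = 2, max = 2) List<@DecimalMin("0.0") @DecimalMax("1.0") Double>> polygon;

    public List<List<Double>> getPolygon() { return polygon; }
    public void setPolygon(List<List<Double>> polygon) { this.polygon = polygon; }
}
```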
## Alternatives Considered
| Alternative | Why rejected |
|---|---|
| 8 `NUMERIC(8,6)` columns (x1,y1,...,x4,y4) | Verbose, no structural enforcement, awkward to query or extend |
| Separate `annotation_polygons` join table | Unnecessary complexity for a 1:1 optional relationship |
| PostGIS geometry column | Adds a heavyweight extension for a display-only field with no spatial queries |
| `String polygon` on the entity | Requires manual parsing at every callsite; error-prone |
## Consequences
**Easier:**
- Backward compatible — all existing annotations continue to work unchanged
- Frontend renders `<polygon>` or `<rect>` based on a simple null check
- Schema can accommodate N-point polygons in the future (JSONB is flexible), though the CHECK constraint currently enforces exactly 4
**Harder:**
- Cannot express range checks (`0 <= x <= 1`) as database constraints without a PL/pgSQL function — validated at the DTO layer instead
- No server-side geometry queries on polygon coordinates (acceptable — polygon is display-only)
- AttributeConverter adds a small amount of serialization code to maintain

View File

@@ -500,5 +500,37 @@
"person_alias_delete_title": "Alias entfernen?",
"person_alias_delete_body": "Dieser Name wird aus der Suche entfernt.",
"person_alias_btn_delete": "Entfernen",
"error_alias_not_found": "Der Namensalias wurde nicht gefunden."
"error_alias_not_found": "Der Namensalias wurde nicht gefunden.",
"error_ocr_service_unavailable": "Der OCR-Dienst ist nicht verfügbar.",
"error_ocr_job_not_found": "Der OCR-Auftrag wurde nicht gefunden.",
"error_ocr_document_not_uploaded": "Das Dokument hat keine Datei — OCR ist nicht möglich.",
"error_ocr_processing_failed": "Die OCR-Verarbeitung ist fehlgeschlagen.",
"ocr_script_type_typewriter": "Schreibmaschine",
"ocr_script_type_handwriting_latin": "Handschrift (lateinisch)",
"ocr_script_type_handwriting_kurrent": "Handschrift (Kurrent/Sütterlin)",
"ocr_trigger_label": "Schrifttyp",
"ocr_trigger_select_placeholder": "Schrifttyp wählen…",
"ocr_trigger_btn": "OCR starten",
"ocr_trigger_btn_disabled": "Bitte wählen Sie einen Schrifttyp",
"ocr_confirm_title": "Vorhandene Transkription ersetzen?",
"ocr_confirm_body": "Alle {count} vorhandenen Blöcke werden gelöscht und durch die OCR-Ergebnisse ersetzt. Diese Aktion kann nicht rückgängig gemacht werden.",
"ocr_confirm_btn": "Ersetzen",
"ocr_rerun_label": "OCR erneut ausführen…",
"ocr_progress_heading": "OCR läuft",
"ocr_progress_page": "Seite {current} von {total}",
"ocr_error_heading": "OCR fehlgeschlagen",
"ocr_error_retry": "Erneut versuchen",
"ocr_batch_running": "OCR läuft · {processed} von {total} Dokumente abgeschlossen",
"ocr_batch_done": "OCR abgeschlossen · {processed} erfolgreich · {errors} fehlgeschlagen",
"ocr_status_preparing": "Dokument wird vorbereitet…",
"ocr_status_loading": "Lade Modell und Dokument…",
"ocr_status_analyzing": "OCR-Analyse läuft — dies kann einige Minuten dauern…",
"ocr_status_creating_blocks": "{count} Textblöcke erkannt — erstelle Transkription…",
"ocr_status_done_blocks": "{count} Blöcke erstellt",
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
"ocr_status_error": "OCR fehlgeschlagen",
"transcription_block_review": "Als geprüft markieren",
"transcription_block_unreview": "Markierung aufheben",
"transcription_reviewed_count": "{reviewed} von {total} geprüft"
}

View File

@@ -500,5 +500,37 @@
"person_alias_delete_title": "Remove alias?",
"person_alias_delete_body": "This name will be removed from search results.",
"person_alias_btn_delete": "Remove",
"error_alias_not_found": "The name alias was not found."
"error_alias_not_found": "The name alias was not found.",
"error_ocr_service_unavailable": "The OCR service is not available.",
"error_ocr_job_not_found": "The OCR job was not found.",
"error_ocr_document_not_uploaded": "The document has no file — OCR is not possible.",
"error_ocr_processing_failed": "OCR processing failed.",
"ocr_script_type_typewriter": "Typewriter",
"ocr_script_type_handwriting_latin": "Handwriting (Latin)",
"ocr_script_type_handwriting_kurrent": "Handwriting (Kurrent/Sütterlin)",
"ocr_trigger_label": "Script type",
"ocr_trigger_select_placeholder": "Select script type…",
"ocr_trigger_btn": "Start OCR",
"ocr_trigger_btn_disabled": "Please select a script type",
"ocr_confirm_title": "Replace existing transcription?",
"ocr_confirm_body": "All {count} existing blocks will be deleted and replaced with OCR results. This action cannot be undone.",
"ocr_confirm_btn": "Replace",
"ocr_rerun_label": "Re-run OCR…",
"ocr_progress_heading": "OCR running",
"ocr_progress_page": "Page {current} of {total}",
"ocr_error_heading": "OCR failed",
"ocr_error_retry": "Try again",
"ocr_batch_running": "OCR running · {processed} of {total} documents complete",
"ocr_batch_done": "OCR complete · {processed} successful · {errors} failed",
"ocr_status_preparing": "Preparing document…",
"ocr_status_loading": "Loading model and document…",
"ocr_status_analyzing": "OCR analysis running — this may take a few minutes…",
"ocr_status_creating_blocks": "{count} text blocks detected — creating transcription…",
"ocr_status_done_blocks": "{count} blocks created",
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
"ocr_status_error": "OCR failed",
"transcription_block_review": "Mark as reviewed",
"transcription_block_unreview": "Unmark as reviewed",
"transcription_reviewed_count": "{reviewed} of {total} reviewed"
}

View File

@@ -500,5 +500,37 @@
"person_alias_delete_title": "Eliminar alias?",
"person_alias_delete_body": "Este nombre se eliminara de los resultados de busqueda.",
"person_alias_btn_delete": "Eliminar",
"error_alias_not_found": "No se encontro el alias de nombre."
"error_alias_not_found": "No se encontro el alias de nombre.",
"error_ocr_service_unavailable": "El servicio OCR no está disponible.",
"error_ocr_job_not_found": "No se encontró el trabajo OCR.",
"error_ocr_document_not_uploaded": "El documento no tiene archivo — OCR no es posible.",
"error_ocr_processing_failed": "El procesamiento OCR ha fallado.",
"ocr_script_type_typewriter": "Máquina de escribir",
"ocr_script_type_handwriting_latin": "Escritura manuscrita (latina)",
"ocr_script_type_handwriting_kurrent": "Escritura manuscrita (Kurrent/Sütterlin)",
"ocr_trigger_label": "Tipo de escritura",
"ocr_trigger_select_placeholder": "Seleccionar tipo de escritura…",
"ocr_trigger_btn": "Iniciar OCR",
"ocr_trigger_btn_disabled": "Por favor seleccione un tipo de escritura",
"ocr_confirm_title": "¿Reemplazar transcripción existente?",
"ocr_confirm_body": "Los {count} bloques existentes serán eliminados y reemplazados con los resultados del OCR. Esta acción no se puede deshacer.",
"ocr_confirm_btn": "Reemplazar",
"ocr_rerun_label": "Ejecutar OCR de nuevo…",
"ocr_progress_heading": "OCR en curso",
"ocr_progress_page": "Página {current} de {total}",
"ocr_error_heading": "OCR fallido",
"ocr_error_retry": "Intentar de nuevo",
"ocr_batch_running": "OCR en curso · {processed} de {total} documentos completados",
"ocr_batch_done": "OCR completado · {processed} exitosos · {errors} fallidos",
"ocr_status_preparing": "Preparando documento…",
"ocr_status_loading": "Cargando modelo y documento…",
"ocr_status_analyzing": "Análisis OCR en curso — esto puede tardar unos minutos…",
"ocr_status_creating_blocks": "{count} bloques de texto detectados — creando transcripción…",
"ocr_status_done_blocks": "{count} bloques creados",
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
"ocr_status_error": "OCR fallido",
"transcription_block_review": "Marcar como revisado",
"transcription_block_unreview": "Desmarcar como revisado",
"transcription_reviewed_count": "{reviewed} de {total} revisados"
}

View File

@@ -1,5 +1,6 @@
<script lang="ts">
import type { Annotation } from '$lib/types';
import AnnotationShape from './AnnotationShape.svelte';
type DrawRect = {
x: number;
@@ -33,13 +34,6 @@ let {
let drawStart = $state<{ x: number; y: number } | null>(null);
let drawRect = $state<DrawRect | null>(null);
function hexToRgba(hex: string, alpha: number): string {
const r = parseInt(hex.slice(1, 3), 16);
const g = parseInt(hex.slice(3, 5), 16);
const b = parseInt(hex.slice(5, 7), 16);
return `rgba(${r}, ${g}, ${b}, ${alpha})`;
}
function getNormalizedCoords(event: PointerEvent, element: HTMLElement): { x: number; y: number } {
const rect = element.getBoundingClientRect();
return {
@@ -109,58 +103,18 @@ const containerStyle = $derived(
onpointerup={handlePointerUp}
>
{#each annotations as annotation (annotation.id)}
<div
data-testid="annotation-{annotation.id}"
data-annotation
class:annotation-flash={flashAnnotationId === annotation.id}
role="button"
tabindex="0"
aria-label="Block anzeigen"
<AnnotationShape
annotation={annotation}
isHovered={hoveredId === annotation.id}
isActive={annotation.id === activeAnnotationId}
faded={!dimmed && !!activeAnnotationId && annotation.id !== activeAnnotationId}
dimmed={dimmed}
blockNumber={blockNumbers[annotation.id]}
isFlashing={flashAnnotationId === annotation.id}
onclick={() => onAnnotationClick?.(annotation.id)}
onkeydown={(e) => {
if (e.key === 'Enter' || e.key === ' ') onAnnotationClick?.(annotation.id);
}}
onpointerenter={() => (hoveredId = annotation.id)}
onpointerleave={() => (hoveredId = null)}
style="
position: absolute;
left: {annotation.x * 100}%;
top: {annotation.y * 100}%;
width: {annotation.width * 100}%;
height: {annotation.height * 100}%;
background-color: {hexToRgba(annotation.color, dimmed ? 0.3 : (hoveredId === annotation.id || annotation.id === activeAnnotationId ? 0.5 : 0.3))};
box-shadow: {dimmed ? 'none' : (annotation.id === activeAnnotationId ? `inset 0 0 0 2px ${hexToRgba(annotation.color, 0.8)}` : hoveredId === annotation.id ? `inset 0 0 0 2px ${hexToRgba(annotation.color, 0.8)}` : 'none')};
opacity: {dimmed ? 1 : (activeAnnotationId && annotation.id !== activeAnnotationId ? 0.3 : 1)};
pointer-events: auto;
cursor: pointer;
transition: background-color 0.15s ease, box-shadow 0.15s ease, opacity 0.3s ease;
"
>
{#if !dimmed && blockNumbers[annotation.id]}
<div
style="
position: absolute;
top: -8px;
left: -8px;
width: 20px;
height: 20px;
border-radius: 50%;
background-color: {annotation.color};
color: white;
font-size: 11px;
font-family: sans-serif;
font-weight: 700;
display: flex;
align-items: center;
justify-content: center;
pointer-events: none;
box-shadow: 0 1px 3px rgba(0,0,0,0.3);
"
>
{blockNumbers[annotation.id]}
</div>
{/if}
</div>
/>
{/each}
{#if drawRect && drawRect.width > 0}
@@ -178,27 +132,3 @@ const containerStyle = $derived(
></div>
{/if}
</div>
<style>
@keyframes annotation-flash-anim {
0% {
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
outline-offset: 0px;
}
100% {
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 0%, transparent);
outline-offset: 2px;
}
}
.annotation-flash {
animation: annotation-flash-anim 1.5s ease-out;
}
@media (prefers-reduced-motion: reduce) {
.annotation-flash {
animation: none;
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
}
}
</style>

View File

@@ -0,0 +1,136 @@
<script lang="ts">
import type { Annotation } from '$lib/types';
let {
annotation,
isHovered,
isActive,
faded = false,
dimmed = false,
blockNumber = undefined,
isFlashing = false,
onclick,
onpointerenter,
onpointerleave
}: {
annotation: Annotation;
isHovered: boolean;
isActive: boolean;
faded?: boolean;
dimmed?: boolean;
blockNumber?: number | undefined;
isFlashing?: boolean;
onclick: () => void;
onpointerenter: () => void;
onpointerleave: () => void;
} = $props();
function hexToRgba(hex: string, alpha: number): string {
const r = parseInt(hex.slice(1, 3), 16);
const g = parseInt(hex.slice(3, 5), 16);
const b = parseInt(hex.slice(5, 7), 16);
return `rgba(${r}, ${g}, ${b}, ${alpha})`;
}
let clipPath = $derived.by(() => {
if (!annotation.polygon || annotation.polygon.length !== 4) return 'none';
const points = annotation.polygon
.map(([px, py]) => {
const cx = ((px - annotation.x) / annotation.width) * 100;
const cy = ((py - annotation.y) / annotation.height) * 100;
return `${cx}% ${cy}%`;
})
.join(', ');
return `polygon(${points})`;
});
let bgAlpha = $derived(dimmed ? 0.3 : isHovered || isActive ? 0.5 : 0.3);
let boxShadow = $derived.by(() => {
if (dimmed) return 'none';
if (isActive || isHovered) return `inset 0 0 0 2px ${hexToRgba(annotation.color, 0.8)}`;
return 'none';
});
let opacity = $derived(dimmed ? 1 : faded ? 0.3 : 1);
let shapeStyle = $derived(
`position: absolute;` +
` left: ${annotation.x * 100}%;` +
` top: ${annotation.y * 100}%;` +
` width: ${annotation.width * 100}%;` +
` height: ${annotation.height * 100}%;` +
` background-color: ${hexToRgba(annotation.color, bgAlpha)};` +
` box-shadow: ${boxShadow};` +
` opacity: ${opacity};` +
` pointer-events: auto;` +
` cursor: pointer;` +
` transition: background-color 0.15s ease, box-shadow 0.15s ease, opacity 0.3s ease;` +
(clipPath !== 'none' ? ` clip-path: ${clipPath};` : '')
);
</script>
<div
data-testid="annotation-{annotation.id}"
data-annotation
class:annotation-flash={isFlashing}
role="button"
tabindex="0"
aria-label="Block anzeigen"
onclick={onclick}
onkeydown={(e) => {
if (e.key === 'Enter' || e.key === ' ') onclick();
}}
onpointerenter={onpointerenter}
onpointerleave={onpointerleave}
style={shapeStyle}
>
{#if !dimmed && blockNumber}
<div
style="
position: absolute;
top: -8px;
left: -8px;
width: 20px;
height: 20px;
border-radius: 50%;
background-color: {annotation.color};
color: white;
font-size: 12px;
font-family: sans-serif;
font-weight: 700;
display: flex;
align-items: center;
justify-content: center;
pointer-events: none;
box-shadow: 0 1px 3px rgba(0,0,0,0.3);
"
>
{blockNumber}
</div>
{/if}
</div>
<style>
@keyframes annotation-flash-anim {
0% {
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
outline-offset: 0px;
}
100% {
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 0%, transparent);
outline-offset: 2px;
}
}
.annotation-flash {
animation: annotation-flash-anim 1.5s ease-out;
}
@media (prefers-reduced-motion: reduce) {
.annotation-flash {
animation: none;
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
}
}
</style>

View File

@@ -0,0 +1,91 @@
<script lang="ts">
import { m } from '$lib/paraglide/messages.js';
interface Props {
jobId: string;
onDone: () => void;
}
let { jobId, onDone }: Props = $props();
let status: 'running' | 'done' | 'error' = $state('running');
let processed: number = $state(0);
let total: number = $state(0);
let currentPage: number = $state(0);
let totalPages: number = $state(0);
let retryCount: number = $state(0);
let progressPercent = $derived(total > 0 ? Math.round((processed / total) * 100) : 0);
$effect(() => {
void retryCount; // track dependency to re-create EventSource on retry
const source = new EventSource(`/api/ocr/jobs/${jobId}/progress`);
source.addEventListener('document', (e) => {
const data = JSON.parse(e.data);
processed = data.processed;
total = data.total;
});
source.addEventListener('page', (e) => {
const data = JSON.parse(e.data);
currentPage = data.page;
totalPages = data.totalPages;
});
source.addEventListener('done', () => {
status = 'done';
source.close();
onDone();
});
source.addEventListener('error', () => {
status = 'error';
source.close();
});
source.onerror = () => {
status = 'error';
source.close();
};
return () => {
source.close();
};
});
</script>
{#if status === 'running'}
<div class="border-brand-sand rounded-sm border bg-white p-4">
<h3 class="mb-3 text-xs font-bold tracking-widest text-gray-400 uppercase">
{m.ocr_progress_heading()}
</h3>
<div class="bg-brand-sand h-2 w-full overflow-hidden rounded-full">
<div
class="h-full bg-brand-mint transition-all duration-300"
style="width: {progressPercent}%"
role="progressbar"
aria-label={m.ocr_progress_heading()}
aria-valuenow={progressPercent}
aria-valuemin={0}
aria-valuemax={100}
></div>
</div>
<p class="mt-2 text-right text-sm text-gray-500">
{m.ocr_progress_page({ current: String(currentPage), total: String(totalPages) })}
</p>
</div>
{:else if status === 'error'}
<div class="border-brand-sand rounded-sm border border-l-4 border-l-red-500 bg-white p-4">
<h3 class="mb-2 text-sm font-semibold text-red-700">
{m.ocr_error_heading()}
</h3>
<button
type="button"
onclick={() => { retryCount++; status = 'running'; }}
class="text-sm font-medium text-brand-navy transition-colors hover:text-brand-navy/80"
>
{m.ocr_error_retry()}
</button>
</div>
{/if}

View File

@@ -0,0 +1,49 @@
<script lang="ts">
import { untrack } from 'svelte';
import { m } from '$lib/paraglide/messages.js';
import { getConfirmService } from '$lib/services/confirm.svelte';
import ScriptTypeSelect from './ScriptTypeSelect.svelte';
interface Props {
existingBlockCount: number;
storedScriptType: string;
onTrigger: (scriptType: string) => void;
}
let { existingBlockCount, storedScriptType, onTrigger }: Props = $props();
const { confirm } = getConfirmService();
let selectedScriptType: string = $state(
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
);
async function handleClick() {
if (!selectedScriptType) return;
if (existingBlockCount > 0) {
const confirmed = await confirm({
title: m.ocr_confirm_title(),
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
confirmLabel: m.ocr_confirm_btn(),
destructive: true
});
if (!confirmed) return;
}
onTrigger(selectedScriptType);
}
</script>
<div class="flex flex-col gap-3">
<ScriptTypeSelect bind:value={selectedScriptType} />
<button
type="button"
disabled={!selectedScriptType}
title={!selectedScriptType ? m.ocr_trigger_btn_disabled() : undefined}
onclick={handleClick}
class="min-h-[44px] w-full rounded-sm bg-brand-navy font-sans text-sm font-medium text-white transition-colors hover:bg-brand-navy/90 disabled:cursor-not-allowed disabled:opacity-50"
>
{m.ocr_trigger_btn()}
</button>
</div>

View File

@@ -145,6 +145,7 @@ async function renderPage(doc: PDFDocumentProxy, pageNum: number) {
// Text layer
const textDiv = textLayerEl;
if (!textDiv) return;
textDiv.innerHTML = '';
textDiv.style.width = `${viewport.width / dpr}px`;
textDiv.style.height = `${viewport.height / dpr}px`;

View File

@@ -0,0 +1,27 @@
<script lang="ts">
import { m } from '$lib/paraglide/messages.js';
interface Props {
value: string;
disabled?: boolean;
}
let { value = $bindable(), disabled = false }: Props = $props();
</script>
<div>
<label for="script-type-select" class="text-sm font-bold tracking-widest text-gray-400 uppercase">
{m.ocr_trigger_label()}
</label>
<select
id="script-type-select"
bind:value={value}
disabled={disabled}
class="border-brand-sand min-h-[44px] w-full rounded-sm border bg-white px-3 py-2 font-serif text-sm text-brand-navy focus:ring-2 focus:ring-brand-mint focus:outline-none"
>
<option value="" disabled>{m.ocr_trigger_select_placeholder()}</option>
<option value="TYPEWRITER">{m.ocr_script_type_typewriter()}</option>
<option value="HANDWRITING_LATIN">{m.ocr_script_type_handwriting_latin()}</option>
<option value="HANDWRITING_KURRENT">{m.ocr_script_type_handwriting_kurrent()}</option>
</select>
</div>

View File

@@ -14,6 +14,7 @@ type Props = {
text: string;
label: string | null;
active: boolean;
reviewed: boolean;
saveState: SaveState;
canComment: boolean;
currentUserId: string | null;
@@ -21,6 +22,7 @@ type Props = {
onFocus: () => void;
onDeleteClick: () => void;
onRetry: () => void;
onReviewToggle: () => void;
onMoveUp?: () => void;
onMoveDown?: () => void;
isFirst?: boolean;
@@ -34,6 +36,7 @@ let {
text,
label = null,
active,
reviewed,
saveState,
canComment,
currentUserId,
@@ -41,6 +44,7 @@ let {
onFocus,
onDeleteClick,
onRetry,
onReviewToggle,
onMoveUp,
onMoveDown,
isFirst = false,
@@ -239,6 +243,29 @@ function handleTextareaMouseUp() {
</span>
{/if}
<!-- Review toggle -->
<button
type="button"
class="cursor-pointer transition-colors {reviewed ? 'text-turquoise hover:text-turquoise/70' : 'text-ink-3 hover:text-turquoise'}"
aria-label={reviewed ? m.transcription_block_unreview() : m.transcription_block_review()}
title={reviewed ? m.transcription_block_unreview() : m.transcription_block_review()}
onclick={onReviewToggle}
>
<svg
class="h-4 w-4"
fill={reviewed ? 'currentColor' : 'none'}
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="1.5"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
</button>
<!-- Delete button -->
<button
type="button"

View File

@@ -2,6 +2,7 @@
import { m } from '$lib/paraglide/messages.js';
import { SvelteMap } from 'svelte/reactivity';
import TranscriptionBlock from './TranscriptionBlock.svelte';
import OcrTrigger from './OcrTrigger.svelte';
import type { TranscriptionBlockData } from '$lib/types';
type SaveState = 'idle' | 'saving' | 'saved' | 'fading' | 'error';
@@ -12,9 +13,13 @@ type Props = {
canComment: boolean;
currentUserId: string | null;
activeAnnotationId?: string | null;
storedScriptType?: string;
canRunOcr?: boolean;
onBlockFocus: (blockId: string) => void;
onSaveBlock: (blockId: string, text: string) => Promise<void>;
onDeleteBlock: (blockId: string) => Promise<void>;
onReviewToggle: (blockId: string) => Promise<void>;
onTriggerOcr?: (scriptType: string) => void;
};
let {
@@ -23,9 +28,13 @@ let {
canComment,
currentUserId,
activeAnnotationId = null,
storedScriptType = '',
canRunOcr = false,
onBlockFocus,
onSaveBlock,
onDeleteBlock
onDeleteBlock,
onReviewToggle,
onTriggerOcr
}: Props = $props();
let activeBlockId: string | null = $state(null);
@@ -282,6 +291,7 @@ $effect(() => {
text={block.text}
label={block.label}
active={activeBlockId === block.id}
reviewed={block.reviewed ?? false}
saveState={getSaveState(block.id)}
canComment={canComment}
currentUserId={currentUserId}
@@ -289,6 +299,7 @@ $effect(() => {
onFocus={() => handleFocus(block.id)}
onDeleteClick={() => handleDelete(block.id)}
onRetry={() => handleRetry(block.id)}
onReviewToggle={() => onReviewToggle(block.id)}
onMoveUp={() => handleMoveUp(block.id)}
onMoveDown={() => handleMoveDown(block.id)}
isFirst={i === 0}
@@ -307,6 +318,23 @@ $effect(() => {
>
{m.transcription_next_block_cta({ number: sortedBlocks.length + 1 })}
</div>
{#if canRunOcr && onTriggerOcr}
<details class="mt-6">
<summary
class="cursor-pointer font-sans text-xs font-medium text-ink-3 transition-colors hover:text-brand-navy"
>
{m.ocr_rerun_label()}
</summary>
<div class="mt-3 max-w-xs">
<OcrTrigger
existingBlockCount={blocks.length}
storedScriptType={storedScriptType}
onTrigger={onTriggerOcr}
/>
</div>
</details>
{/if}
</div>
{:else}
<div class="flex flex-1 flex-col items-center justify-center px-6 py-12 text-center">
@@ -323,9 +351,26 @@ $effect(() => {
d="M19.5 14.25v-2.625a3.375 3.375 0 00-3.375-3.375h-1.5A1.125 1.125 0 0113.5 7.125v-1.5a3.375 3.375 0 00-3.375-3.375H8.25m0 12.75h7.5m-7.5 3H12M10.5 2.25H5.625c-.621 0-1.125.504-1.125 1.125v17.25c0 .621.504 1.125 1.125 1.125h12.75c.621 0 1.125-.504 1.125-1.125V11.25a9 9 0 00-9-9z"
/>
</svg>
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
{m.transcription_empty_cta()}
</p>
{#if canRunOcr && onTriggerOcr}
<p class="mb-6 max-w-xs text-sm leading-relaxed text-ink-3">
{m.transcription_empty_title()}
</p>
<div class="w-full max-w-xs">
<OcrTrigger
existingBlockCount={0}
storedScriptType={storedScriptType}
onTrigger={onTriggerOcr}
/>
</div>
<p class="mt-4 text-xs text-ink-3">
{m.transcription_empty_desc()}
</p>
{:else}
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
{m.transcription_empty_cta()}
</p>
{/if}
</div>
{/if}
</div>

View File

@@ -22,6 +22,10 @@ export type ErrorCode =
| 'TRANSCRIPTION_BLOCK_NOT_FOUND'
| 'TRANSCRIPTION_BLOCK_CONFLICT'
| 'COMMENT_NOT_FOUND'
| 'OCR_SERVICE_UNAVAILABLE'
| 'OCR_JOB_NOT_FOUND'
| 'OCR_DOCUMENT_NOT_UPLOADED'
| 'OCR_PROCESSING_FAILED'
| 'UNAUTHORIZED'
| 'FORBIDDEN'
| 'VALIDATION_ERROR'
@@ -85,6 +89,14 @@ export function getErrorMessage(code: ErrorCode | string | undefined): string {
return m.error_transcription_block_conflict();
case 'COMMENT_NOT_FOUND':
return m.error_comment_not_found();
case 'OCR_SERVICE_UNAVAILABLE':
return m.error_ocr_service_unavailable();
case 'OCR_JOB_NOT_FOUND':
return m.error_ocr_job_not_found();
case 'OCR_DOCUMENT_NOT_UPLOADED':
return m.error_ocr_document_not_uploaded();
case 'OCR_PROCESSING_FAILED':
return m.error_ocr_processing_failed();
case 'UNAUTHORIZED':
return m.error_unauthorized();
case 'FORBIDDEN':

View File

@@ -0,0 +1,78 @@
import { describe, it, expect, vi } from 'vitest';
vi.mock('$lib/paraglide/messages.js', () => ({
m: {
ocr_progress_heading: () => 'OCR-Analyse',
ocr_status_preparing: () => 'Dokument wird vorbereitet…',
ocr_status_loading: () => 'Lade Modell und Dokument…',
ocr_status_analyzing: () => 'OCR-Analyse läuft…',
ocr_status_creating_blocks: ({ count }: { count: string }) => `${count} Textblöcke erkannt`,
ocr_status_done_blocks: ({ count }: { count: string }) => `${count} Blöcke erstellt`,
ocr_status_done_skipped: ({ count, skipped }: { count: string; skipped: string }) =>
`${count} Blöcke erstellt, ${skipped} Seite(n) übersprungen`,
ocr_status_analyzing_page: ({ current, total }: { current: string; total: string }) =>
`Seite ${current} von ${total} wird analysiert…`,
ocr_status_error: () => 'OCR fehlgeschlagen'
}
}));
import { translateOcrProgress } from './translateOcrProgress';
describe('translateOcrProgress', () => {
it('returns heading for empty code', () => {
const result = translateOcrProgress('');
expect(result.message).toBe('OCR-Analyse');
});
it('translates PREPARING', () => {
const result = translateOcrProgress('PREPARING');
expect(result.message).toBe('Dokument wird vorbereitet…');
expect(result.currentPage).toBeUndefined();
expect(result.totalPages).toBeUndefined();
});
it('translates LOADING', () => {
expect(translateOcrProgress('LOADING').message).toBe('Lade Modell und Dokument…');
});
it('translates ANALYZING', () => {
expect(translateOcrProgress('ANALYZING').message).toBe('OCR-Analyse läuft…');
});
it('translates CREATING_BLOCKS with count', () => {
expect(translateOcrProgress('CREATING_BLOCKS:42').message).toBe('42 Textblöcke erkannt');
});
it('translates DONE without skipped pages', () => {
const result = translateOcrProgress('DONE:15');
expect(result.message).toBe('15 Blöcke erstellt');
expect(result.skippedPages).toBeUndefined();
});
it('translates DONE with zero skipped', () => {
const result = translateOcrProgress('DONE:15:0');
expect(result.message).toBe('15 Blöcke erstellt');
expect(result.skippedPages).toBeUndefined();
});
it('translates DONE with skipped pages', () => {
const result = translateOcrProgress('DONE:12:2');
expect(result.message).toBe('12 Blöcke erstellt, 2 Seite(n) übersprungen');
expect(result.skippedPages).toBe(2);
});
it('translates ANALYZING_PAGE with current, total, and blocks', () => {
const result = translateOcrProgress('ANALYZING_PAGE:2:5:10');
expect(result.message).toBe('Seite 2 von 5 wird analysiert…');
expect(result.currentPage).toBe(2);
expect(result.totalPages).toBe(5);
});
it('translates ERROR', () => {
expect(translateOcrProgress('ERROR').message).toBe('OCR fehlgeschlagen');
});
it('returns raw code for unknown codes', () => {
expect(translateOcrProgress('UNKNOWN_CODE').message).toBe('UNKNOWN_CODE');
});
});

View File

@@ -0,0 +1,56 @@
import { m } from '$lib/paraglide/messages.js';
export interface OcrProgressResult {
message: string;
currentPage?: number;
totalPages?: number;
skippedPages?: number;
}
export function translateOcrProgress(code: string): OcrProgressResult {
if (!code) return { message: m.ocr_progress_heading() };
const parts = code.split(':');
const key = parts[0];
switch (key) {
case 'PREPARING':
return { message: m.ocr_status_preparing() };
case 'LOADING':
return { message: m.ocr_status_loading() };
case 'ANALYZING':
return { message: m.ocr_status_analyzing() };
case 'CREATING_BLOCKS':
return { message: m.ocr_status_creating_blocks({ count: parts[1] ?? '0' }) };
case 'DONE': {
const count = parts[1] ?? '0';
const skipped = parts[2] ? parseInt(parts[2], 10) : 0;
if (skipped > 0) {
return {
message: m.ocr_status_done_skipped({
count,
skipped: String(skipped)
}),
skippedPages: skipped
};
}
return { message: m.ocr_status_done_blocks({ count }) };
}
case 'ANALYZING_PAGE': {
const current = parseInt(parts[1] ?? '0', 10);
const total = parseInt(parts[2] ?? '0', 10);
return {
message: m.ocr_status_analyzing_page({
current: String(current),
total: String(total)
}),
currentPage: current,
totalPages: total
};
}
case 'ERROR':
return { message: m.ocr_status_error() };
default:
return { message: code };
}
}

View File

@@ -35,6 +35,8 @@ export type TranscriptionBlockData = {
label: string | null;
sortOrder: number;
version: number;
source: 'MANUAL' | 'OCR';
reviewed: boolean;
updatedAt?: string | null;
};
@@ -49,4 +51,5 @@ export type Annotation = {
color: string;
createdAt: string;
fileHash?: string | null;
polygon?: [number, number][] | null;
};

View File

@@ -7,6 +7,8 @@ import TranscriptionEditView from '$lib/components/TranscriptionEditView.svelte'
import TranscriptionReadView from '$lib/components/TranscriptionReadView.svelte';
import TranscriptionPanelHeader from '$lib/components/TranscriptionPanelHeader.svelte';
import type { TranscriptionBlockData } from '$lib/types';
import { getErrorMessage } from '$lib/errors';
import { translateOcrProgress } from '$lib/ocr/translateOcrProgress';
let { data } = $props();
@@ -118,6 +120,78 @@ async function deleteBlock(blockId: string) {
annotationReloadKey++;
}
async function reviewToggle(blockId: string) {
const res = await fetch(`/api/documents/${doc.id}/transcription-blocks/${blockId}/review`, {
method: 'PUT'
});
if (!res.ok) return;
const updated = await res.json();
transcriptionBlocks = transcriptionBlocks.map((b) => (b.id === blockId ? updated : b));
}
let ocrRunning = $state(false);
let ocrProgressMessage = $state('');
let ocrErrorMessage = $state('');
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
let ocrSkippedPages = $state(0);
async function triggerOcr(scriptType: string) {
ocrRunning = true;
ocrErrorMessage = '';
try {
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ scriptType })
});
if (res.ok) {
const data = await res.json();
pollOcrJob(data.jobId);
} else {
ocrRunning = false;
const body = await res.json().catch(() => null);
const code = (body as { code?: string } | null)?.code;
ocrErrorMessage = code ? getErrorMessage(code) : m.ocr_status_error();
}
} catch (e) {
console.error('Failed to trigger OCR:', e);
ocrRunning = false;
ocrErrorMessage = m.ocr_status_error();
}
}
function pollOcrJob(jobId: string) {
if (ocrPollTimer) clearInterval(ocrPollTimer);
ocrPollTimer = setInterval(async () => {
try {
const res = await fetch(`/api/ocr/jobs/${jobId}`);
if (!res.ok) return;
const job = await res.json();
const rawCode = job.progressMessage ?? '';
const progress = translateOcrProgress(rawCode);
ocrProgressMessage = progress.message;
if (progress.skippedPages !== undefined) ocrSkippedPages = progress.skippedPages;
if (job.status === 'DONE' || job.status === 'FAILED') {
if (ocrPollTimer) clearInterval(ocrPollTimer);
ocrPollTimer = null;
setTimeout(() => {
ocrRunning = false;
ocrProgressMessage = '';
ocrSkippedPages = 0;
}, 1000);
if (job.status === 'FAILED') {
ocrErrorMessage = m.ocr_status_error();
}
await loadTranscriptionBlocks();
annotationReloadKey++;
panelMode = transcriptionBlocks.length > 0 ? 'read' : 'edit';
}
} catch {
// polling is best-effort
}
}, 2000);
}
async function createBlockFromDraw(rect: {
x: number;
y: number;
@@ -198,12 +272,28 @@ function handleParagraphClick(annotationId: string) {
);
}
// Load blocks when transcribe mode is entered and set default panel mode
async function checkOcrStatus() {
if (!doc?.id) return;
try {
const res = await fetch(`/api/documents/${doc.id}/ocr-status`);
if (!res.ok) return;
const status = await res.json();
if ((status.status === 'PENDING' || status.status === 'RUNNING') && status.jobId) {
ocrRunning = true;
pollOcrJob(status.jobId);
}
} catch {
// best-effort
}
}
// Load blocks and check OCR status when transcribe mode is entered
$effect(() => {
if (transcribeMode) {
loadTranscriptionBlocks().then(() => {
panelMode = transcriptionBlocks.length > 0 ? 'read' : 'edit';
});
checkOcrStatus();
}
});
@@ -227,7 +317,10 @@ onMount(() => {
}
}
document.addEventListener('keydown', onKeyDown);
return () => document.removeEventListener('keydown', onKeyDown);
return () => {
document.removeEventListener('keydown', onKeyDown);
if (ocrPollTimer) clearInterval(ocrPollTimer);
};
});
</script>
@@ -258,7 +351,7 @@ onMount(() => {
fileUrl={fileUrl}
isLoading={isLoading}
error={fileError}
transcribeMode={transcribeMode}
transcribeMode={transcribeMode && !ocrRunning}
blockNumbers={blockNumbers}
annotationReloadKey={annotationReloadKey}
annotationsDimmed={transcribeMode && panelMode === 'read'}
@@ -303,7 +396,45 @@ onMount(() => {
onClose={() => (transcribeMode = false)}
/>
<div class="flex-1 overflow-y-auto">
{#if panelMode === 'read'}
{#if ocrErrorMessage}
<div class="mx-4 mt-4 rounded-sm border border-red-200 bg-red-50 px-4 py-3">
<p class="text-sm text-red-700">{ocrErrorMessage}</p>
</div>
{/if}
{#if ocrRunning}
<div class="flex flex-1 flex-col items-center justify-center px-6 py-12 text-center">
<svg
class="mb-4 h-8 w-8 animate-spin text-brand-mint"
viewBox="0 0 24 24"
fill="none"
>
<circle
class="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
stroke-width="4"
></circle>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z"
></path>
</svg>
<p class="text-xs font-bold tracking-widest text-gray-400 uppercase">
{m.ocr_progress_heading()}
</p>
<p class="mt-2 text-sm text-ink-2">
{ocrProgressMessage}
</p>
{#if ocrSkippedPages > 0}
<p class="mt-1 text-xs text-amber-600">
{ocrSkippedPages} Seiten übersprungen
</p>
{/if}
</div>
{:else if panelMode === 'read'}
<TranscriptionReadView
blocks={transcriptionBlocks}
highlightBlockId={highlightBlockId}
@@ -316,9 +447,13 @@ onMount(() => {
canComment={canWrite}
currentUserId={currentUserId}
activeAnnotationId={activeAnnotationId}
storedScriptType={doc.scriptType ?? ''}
canRunOcr={canWrite && !!doc.filePath}
onBlockFocus={handleBlockFocus}
onSaveBlock={saveBlock}
onDeleteBlock={deleteBlock}
onReviewToggle={reviewToggle}
onTriggerOcr={triggerOcr}
/>
{/if}
</div>

View File

@@ -0,0 +1,6 @@
__pycache__
*.pyc
.venv
models/
.git
.pytest_cache

26
ocr-service/Dockerfile Normal file
View File

@@ -0,0 +1,26 @@
FROM python:3.11-slim
WORKDIR /app
# curl for healthcheck; libgomp1 for PyTorch CPU threading; libvips for kraken PDF support
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
libgomp1 \
libvips42 \
&& rm -rf /var/lib/apt/lists/*
# PyTorch CPU-only — separate layer; the whl/cpu index strips all CUDA variants (~2 GB saved)
# torchvision must also come from the CPU index to match torch's operator registrations
RUN pip install --no-cache-dir \
torch==2.7.1 \
torchvision==0.22.1 \
--index-url https://download.pytorch.org/whl/cpu
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

90
ocr-service/confidence.py Normal file
View File

@@ -0,0 +1,90 @@
"""Confidence-based [unleserlich] marker insertion for OCR output."""
import os
THRESHOLD_DEFAULT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD", "0.3"))
THRESHOLD_KURRENT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD_KURRENT", "0.5"))
ILLEGIBLE_MARKER = "[unleserlich]"
def get_threshold(script_type: str) -> float:
if script_type and script_type.upper() == "HANDWRITING_KURRENT":
return THRESHOLD_KURRENT
return THRESHOLD_DEFAULT
def apply_confidence_markers(words: list[dict], threshold: float | None = None) -> str:
"""Replace low-confidence words with [unleserlich], collapsing adjacent markers.
Args:
words: list of {"text": str, "confidence": float} dicts
threshold: confidence threshold (uses THRESHOLD_DEFAULT if None)
Returns:
Reconstructed text string with [unleserlich] substitutions.
"""
if not words:
return ""
if threshold is None:
threshold = THRESHOLD_DEFAULT
result: list[str] = []
prev_was_marker = False
for word in words:
if word["confidence"] < threshold:
if not prev_was_marker:
result.append(ILLEGIBLE_MARKER)
prev_was_marker = True
else:
result.append(word["text"])
prev_was_marker = False
return " ".join(result)
def words_from_characters(prediction: str, confidences: list[float]) -> list[dict]:
"""Reconstruct word-level confidence from character-level data.
Splits prediction on whitespace, maps characters to their confidences,
computes mean confidence per word.
Args:
prediction: full line text from Kraken
confidences: per-character confidence list (same length as prediction)
Returns:
list of {"text": str, "confidence": float} dicts
"""
if not prediction or not prediction.strip():
return []
if len(confidences) != len(prediction):
return [{"text": prediction, "confidence": 1.0}]
result: list[dict] = []
current_word: list[str] = []
current_confs: list[float] = []
for char, conf in zip(prediction, confidences):
if char == " ":
if current_word:
result.append({
"text": "".join(current_word),
"confidence": sum(current_confs) / len(current_confs),
})
current_word = []
current_confs = []
else:
current_word.append(char)
current_confs.append(conf)
if current_word:
result.append({
"text": "".join(current_word),
"confidence": sum(current_confs) / len(current_confs),
})
return result
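
A minimal usage sketch of how the two helpers above compose (the prediction string and per-character confidences are invented for illustration; assumes it is run from the ocr-service directory):

```python
from confidence import apply_confidence_markers, words_from_characters

# One confidence value per character, same length as the prediction string.
prediction = "Lieber xkqz Freund"
confidences = [0.9] * 6 + [0.8] + [0.1] * 4 + [0.8] + [0.9] * 6

words = words_from_characters(prediction, confidences)
# -> [{"text": "Lieber", "confidence": 0.9},
#     {"text": "xkqz", "confidence": 0.1},
#     {"text": "Freund", "confidence": 0.9}]

print(apply_confidence_markers(words, threshold=0.3))
# Lieber [unleserlich] Freund
```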

View File

View File

@@ -0,0 +1,200 @@
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
import logging
import os
logger = logging.getLogger(__name__)
_model = None
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
def load_models():
"""Load the Kraken model at startup. Skips if model file is not present."""
global _model
if not os.path.exists(_model_path):
logger.warning("Kraken model not found at %s — Kurrent OCR will not be available", _model_path)
return
logger.info("Loading Kraken model from %s...", _model_path)
from kraken.lib import models as kraken_models
_model = kraken_models.load_any(_model_path)
logger.info("Kraken model loaded successfully")
def is_available() -> bool:
return _model is not None
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
"""Run Kraken segmentation + recognition on a single PIL image.
Returns block dicts for that page. Coordinates are normalized to [0, 1].
"""
from kraken import blla, rpred
from confidence import words_from_characters
if _model is None:
raise RuntimeError("Kraken model is not loaded")
page_w, page_h = image.size
blocks = []
baseline_seg = blla.segment(image)
pred_it = rpred.rpred(_model, image, baseline_seg)
for record in pred_it:
polygon_pts = record.cuts if hasattr(record, "cuts") else []
if polygon_pts:
xs = [p[0] for p in polygon_pts]
ys = [p[1] for p in polygon_pts]
x1, y1 = min(xs), min(ys)
x2, y2 = max(xs), max(ys)
else:
xs = [p[0] for p in record.line]
ys = [p[1] for p in record.line]
x1, y1 = min(xs), min(ys) - 5
x2, y2 = max(xs), max(ys) + 5
quad = _approximate_to_quad(polygon_pts, page_w, page_h) if polygon_pts else None
char_confidences = getattr(record, "confidences", [])
words = words_from_characters(record.prediction, char_confidences)
blocks.append({
"pageNumber": page_idx,
"x": x1 / page_w,
"y": y1 / page_h,
"width": (x2 - x1) / page_w,
"height": (y2 - y1) / page_h,
"polygon": quad,
"text": record.prediction,
"words": words,
})
return blocks
def extract_blocks(images: list, language: str = "de") -> list[dict]:
"""Run Kraken segmentation + recognition on a list of PIL images.
Returns block dicts with pageNumber, x, y, width, height, polygon, text.
Polygon is a 4-point quadrilateral approximation of the baseline polygon.
Coordinates are normalized to [0, 1].
"""
all_blocks = []
for page_idx, image in enumerate(images, start=1):
all_blocks.extend(extract_page_blocks(image, page_idx, language))
return all_blocks
def _approximate_to_quad(points: list[tuple], page_w: float, page_h: float) -> list[list[float]] | None:
"""Approximate a polygon to a 4-point quadrilateral using the minimum bounding rectangle.
Uses gift-wrapping (Jarvis march) for convex hull, then rotating calipers
for the minimum area bounding rectangle. Pure Python, no scipy/numpy.
"""
if len(points) < 3:
return None
try:
hull = _convex_hull(points)
if len(hull) < 3:
return None
rect = _min_bounding_rect(hull)
# Normalize to [0, 1]
return [[p[0] / page_w, p[1] / page_h] for p in rect]
except Exception:
logger.debug("Failed to approximate polygon to quad, returning None")
return None
def _convex_hull(points: list[tuple]) -> list[tuple]:
"""Jarvis march (gift wrapping) algorithm for 2D convex hull."""
pts = list(set(points))
if len(pts) < 3:
return pts
# Start from leftmost point
start = min(pts, key=lambda p: (p[0], p[1]))
hull = []
current = start
while True:
hull.append(current)
candidate = pts[0]
for p in pts[1:]:
if candidate == current:
candidate = p
continue
cross = _cross(current, candidate, p)
if cross < 0:
candidate = p
elif cross == 0:
# Collinear — pick the farther point
if _dist_sq(current, p) > _dist_sq(current, candidate):
candidate = p
current = candidate
if current == start:
break
return hull
def _min_bounding_rect(hull: list[tuple]) -> list[tuple]:
"""Find the minimum area bounding rectangle of a convex hull using rotating calipers."""
n = len(hull)
if n < 2:
return hull
min_area = float("inf")
best_rect = None
for i in range(n):
# Edge vector
edge_x = hull[(i + 1) % n][0] - hull[i][0]
edge_y = hull[(i + 1) % n][1] - hull[i][1]
edge_len = (edge_x ** 2 + edge_y ** 2) ** 0.5
if edge_len == 0:
continue
# Unit vectors along and perpendicular to the edge
ux, uy = edge_x / edge_len, edge_y / edge_len
vx, vy = -uy, ux
# Project all hull points onto the edge coordinate system
projs_u = [p[0] * ux + p[1] * uy for p in hull]
projs_v = [p[0] * vx + p[1] * vy for p in hull]
min_u, max_u = min(projs_u), max(projs_u)
min_v, max_v = min(projs_v), max(projs_v)
area = (max_u - min_u) * (max_v - min_v)
if area < min_area:
min_area = area
# Reconstruct 4 corners in original coordinates
best_rect = [
(min_u * ux + min_v * vx, min_u * uy + min_v * vy),
(max_u * ux + min_v * vx, max_u * uy + min_v * vy),
(max_u * ux + max_v * vx, max_u * uy + max_v * vy),
(min_u * ux + max_v * vx, min_u * uy + max_v * vy),
]
return best_rect if best_rect else hull[:4]
def _cross(o: tuple, a: tuple, b: tuple) -> float:
return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])
def _dist_sq(a: tuple, b: tuple) -> float:
return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2
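
For intuition, a small sketch of the quadrilateral approximation above (the points are invented; _approximate_to_quad is the private helper defined in this file, so the snippet assumes it is run from the ocr-service directory):

```python
from engines.kraken import _approximate_to_quad

# A slightly tilted text-line polygon in pixel coordinates on a 1000 x 800 page.
points = [(100, 200), (500, 210), (900, 220), (900, 260), (500, 250), (100, 240)]

quad = _approximate_to_quad(points, page_w=1000, page_h=800)
# quad is a list of four [x, y] corners normalized to [0, 1]: the minimum-area
# rectangle around the convex hull, which hugs the slanted line more tightly
# than an axis-aligned bounding box would.
print(quad)
```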

View File

@@ -0,0 +1,97 @@
"""Surya OCR engine wrapper — transformer-based, handles typewritten and modern Latin handwriting."""
import logging
logger = logging.getLogger(__name__)
_recognition_predictor = None
_detection_predictor = None
_loaded = False
def load_models():
"""Lazy-load Surya models on first use to save RAM at idle.
Called automatically by extract_blocks(). Can also be called explicitly
to pre-warm if desired.
"""
global _recognition_predictor, _detection_predictor, _loaded
if _loaded:
return
logger.info("Loading Surya models (lazy, first OCR request)...")
from surya.foundation import FoundationPredictor
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor
foundation_predictor = FoundationPredictor()
_recognition_predictor = RecognitionPredictor(foundation_predictor)
_detection_predictor = DetectionPredictor()
_loaded = True
logger.info("Surya models loaded successfully")
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
"""Run Surya OCR on a single PIL image and return block dicts for that page.
Coordinates are normalized to [0, 1].
"""
load_models()
page_w, page_h = image.size
blocks = []
predictions = _recognition_predictor([image], det_predictor=_detection_predictor)
page_pred = predictions[0]
for line in page_pred.text_lines:
bbox = line.bbox
x1, y1, x2, y2 = bbox
polygon = None
if hasattr(line, "polygon") and line.polygon and len(line.polygon) == 4:
polygon = [
[p[0] / page_w, p[1] / page_h]
for p in line.polygon
]
words = []
if hasattr(line, "words") and line.words:
for word in line.words:
words.append({
"text": word.text,
"confidence": word.confidence,
})
else:
words = [{"text": line.text, "confidence": getattr(line, "confidence", 1.0)}]
blocks.append({
"pageNumber": page_idx,
"x": x1 / page_w,
"y": y1 / page_h,
"width": (x2 - x1) / page_w,
"height": (y2 - y1) / page_h,
"polygon": polygon,
"text": line.text,
"words": words,
})
return blocks
def extract_blocks(images: list, language: str = "de") -> list[dict]:
"""Run Surya OCR on a list of PIL images (one per page).
Processes one page at a time to limit peak memory usage.
Returns a flat list of block dicts with pageNumber, x, y, width, height,
polygon, text, words. Coordinates are normalized to [0, 1].
"""
all_blocks = []
for page_idx, image in enumerate(images, start=1):
all_blocks.extend(extract_page_blocks(image, page_idx, language))
del image
return all_blocks
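
As a worked example of the coordinate normalization above (values mirror test_engines.py): a text line with pixel bbox [10, 20, 90, 40] on a 100 x 200 px page yields the following normalized block fields:

```python
page_w, page_h = 100, 200          # page size in pixels
x1, y1, x2, y2 = 10, 20, 90, 40    # Surya line bbox in pixels

block = {
    "pageNumber": 1,
    "x": x1 / page_w,              # 0.1
    "y": y1 / page_h,              # 0.1
    "width": (x2 - x1) / page_w,   # 0.8
    "height": (y2 - y1) / page_h,  # 0.1
}
```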

192
ocr-service/main.py Normal file
View File

@@ -0,0 +1,192 @@
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
import asyncio
import io
import json
import logging
import os
from contextlib import asynccontextmanager
from urllib.parse import urlparse
import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from confidence import apply_confidence_markers, get_threshold
from engines import kraken as kraken_engine
from engines import surya as surya_engine
from models import OcrBlock, OcrRequest
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
_models_ready = False
ALLOWED_PDF_HOSTS = set(
h.strip() for h in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
)
def _validate_url(url: str) -> None:
"""Validate that the PDF URL points to an allowed host (SSRF protection)."""
parsed = urlparse(url)
hostname = parsed.hostname or ""
if hostname not in ALLOWED_PDF_HOSTS:
raise HTTPException(status_code=400, detail=f"PDF host not allowed: {hostname}")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load lightweight models at startup. Surya loads lazily on first request."""
global _models_ready
logger.info("Loading Kraken model at startup (Surya loads lazily on first OCR request)...")
kraken_engine.load_models()
_models_ready = True
logger.info("Startup complete — ready to accept requests")
yield
logger.info("Shutting down OCR service")
app = FastAPI(title="Familienarchiv OCR Service", lifespan=lifespan)
@app.get("/health")
def health():
"""Health endpoint — returns 200 only after models are loaded."""
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
return {"status": "ok", "surya": True, "kraken": kraken_engine.is_available()}
@app.post("/ocr", response_model=list[OcrBlock])
async def run_ocr(request: OcrRequest):
"""Run OCR on a PDF document.
Downloads the PDF from the provided URL, converts pages to images,
and runs the appropriate OCR engine based on scriptType.
OCR engines run in a thread pool so the event loop stays free for /health.
"""
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
images = await _download_and_convert_pdf(request.pdfUrl)
script_type = request.scriptType.upper()
if script_type == "HANDWRITING_KURRENT":
if not kraken_engine.is_available():
raise HTTPException(
status_code=400,
detail="Kraken model not available — cannot process Kurrent script",
)
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
else:
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
threshold = get_threshold(script_type)
for block in blocks:
if block.get("words"):
block["text"] = apply_confidence_markers(block["words"], threshold)
block.pop("words", None)
return [OcrBlock(**b) for b in blocks]
@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page."""
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
images = await _download_and_convert_pdf(request.pdfUrl)
script_type = request.scriptType.upper()
threshold = get_threshold(script_type)
use_kraken = script_type == "HANDWRITING_KURRENT"
if use_kraken and not kraken_engine.is_available():
raise HTTPException(
status_code=400,
detail="Kraken model not available — cannot process Kurrent script",
)
async def generate():
total_pages = len(images)
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
total_blocks = 0
skipped_pages = 0
for page_idx, image in enumerate(images, start=1):
try:
engine = kraken_engine if use_kraken else surya_engine
blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language
)
for block in blocks:
if block.get("words"):
block["text"] = apply_confidence_markers(block["words"], threshold)
block.pop("words", None)
total_blocks += len(blocks)
yield json.dumps({
"type": "page",
"pageNumber": page_idx,
"blocks": blocks,
}) + "\n"
except Exception:
logger.exception("OCR failed on page %d", page_idx)
skipped_pages += 1
yield json.dumps({
"type": "error",
"pageNumber": page_idx,
"message": f"OCR processing failed on page {page_idx}",
}) + "\n"
finally:
del image
yield json.dumps({
"type": "done",
"totalBlocks": total_blocks,
"skippedPages": skipped_pages,
}) + "\n"
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
},
)
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
_validate_url(url)
async with httpx.AsyncClient(
timeout=httpx.Timeout(300.0), follow_redirects=False
) as client:
response = await client.get(url)
response.raise_for_status()
pdf = pdfium.PdfDocument(io.BytesIO(response.content))
images = []
for page_idx in range(len(pdf)):
page = pdf[page_idx]
# Render at 200 DPI — balances OCR quality vs memory usage
# (Surya 0.17 models use ~5GB idle; 300 DPI causes OOM on multi-page docs)
bitmap = page.render(scale=200 / 72)
pil_image = bitmap.to_pil()
images.append(pil_image)
return images
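
An illustrative client for the streaming endpoint (the real consumer is the Spring backend; the localhost URL, bucket path, and scriptType value below are assumptions for a local run against the service on port 8000):

```python
import json
import httpx

payload = {
    "pdfUrl": "http://minio/archive/example.pdf",  # hypothetical presigned URL on an allowed host
    "scriptType": "TYPEWRITER",
    "language": "de",
}

# Read the NDJSON stream line by line as pages complete.
with httpx.stream("POST", "http://localhost:8000/ocr/stream", json=payload, timeout=None) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        if not line.strip():
            continue
        event = json.loads(line)
        if event["type"] == "start":
            print(f"{event['totalPages']} pages to process")
        elif event["type"] == "page":
            print(f"page {event['pageNumber']}: {len(event['blocks'])} blocks")
        elif event["type"] == "error":
            print(f"page {event['pageNumber']} skipped: {event['message']}")
        elif event["type"] == "done":
            print(f"done: {event['totalBlocks']} blocks, {event['skippedPages']} pages skipped")
```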

21
ocr-service/models.py Normal file
View File

@@ -0,0 +1,21 @@
from pydantic import BaseModel, ConfigDict
class OcrRequest(BaseModel):
model_config = ConfigDict(populate_by_name=True)
pdfUrl: str
scriptType: str = "UNKNOWN"
language: str = "de"
class OcrBlock(BaseModel):
model_config = ConfigDict(populate_by_name=True)
pageNumber: int
x: float
y: float
width: float
height: float
polygon: list[list[float]] | None = None
text: str

View File

@@ -0,0 +1,10 @@
fastapi[standard]==0.115.6
surya-ocr==0.17.1
kraken==7.0
torch==2.7.1
torchvision==0.22.1
transformers>=4.56.1,<5.0.0
pillow>=10.2.0,<11.0.0
pypdfium2==4.30.0
pyvips>=2.2.0
httpx==0.28.1

View File

@@ -0,0 +1,160 @@
"""Tests for confidence-based [unleserlich] marker insertion."""
import os
import pytest
from confidence import apply_confidence_markers, words_from_characters, get_threshold
# ─── apply_confidence_markers ─────────────────────────────────────────────────
def test_all_words_above_threshold_passes_through():
words = [
{"text": "Lieber", "confidence": 0.95},
{"text": "Freund", "confidence": 0.88},
]
assert apply_confidence_markers(words) == "Lieber Freund"
def test_single_low_confidence_word_replaced():
words = [
{"text": "Lieber", "confidence": 0.95},
{"text": "xkqz", "confidence": 0.1},
{"text": "Freund", "confidence": 0.88},
]
assert apply_confidence_markers(words) == "Lieber [unleserlich] Freund"
def test_adjacent_low_confidence_words_collapsed():
words = [
{"text": "Lieber", "confidence": 0.95},
{"text": "xkqz", "confidence": 0.1},
{"text": "abc", "confidence": 0.05},
{"text": "yyy", "confidence": 0.2},
{"text": "Freund", "confidence": 0.88},
]
assert apply_confidence_markers(words) == "Lieber [unleserlich] Freund"
def test_mixed_high_low_each_group_gets_marker():
words = [
{"text": "Lieber", "confidence": 0.95},
{"text": "xkqz", "confidence": 0.1},
{"text": "wie", "confidence": 0.9},
{"text": "abc", "confidence": 0.05},
{"text": "dir", "confidence": 0.88},
]
assert apply_confidence_markers(words) == "Lieber [unleserlich] wie [unleserlich] dir"
def test_all_below_threshold_returns_single_marker():
words = [
{"text": "xkq", "confidence": 0.1},
{"text": "zzz", "confidence": 0.05},
]
assert apply_confidence_markers(words) == "[unleserlich]"
def test_empty_list_returns_empty_string():
assert apply_confidence_markers([]) == ""
def test_single_word_above_threshold():
words = [{"text": "Hallo", "confidence": 0.9}]
assert apply_confidence_markers(words) == "Hallo"
def test_exact_threshold_passes_through():
"""Confidence exactly at threshold should NOT be replaced (strict <)."""
words = [{"text": "Wort", "confidence": 0.3}]
assert apply_confidence_markers(words) == "Wort"
def test_just_below_threshold_replaced():
words = [{"text": "Wort", "confidence": 0.29}]
assert apply_confidence_markers(words) == "[unleserlich]"
def test_custom_threshold_via_parameter():
words = [
{"text": "Lieber", "confidence": 0.95},
{"text": "Freund", "confidence": 0.5},
]
assert apply_confidence_markers(words, threshold=0.8) == "Lieber [unleserlich]"
assert apply_confidence_markers(words, threshold=0.3) == "Lieber Freund"
def test_kurrent_threshold_is_higher_than_default():
default = get_threshold("TYPEWRITER")
kurrent = get_threshold("HANDWRITING_KURRENT")
assert kurrent > default
def test_get_threshold_kurrent():
assert get_threshold("HANDWRITING_KURRENT") == 0.5
def test_get_threshold_default():
assert get_threshold("TYPEWRITER") == 0.3
assert get_threshold("HANDWRITING_LATIN") == 0.3
assert get_threshold("UNKNOWN") == 0.3
def test_low_confidence_at_start():
words = [
{"text": "xkq", "confidence": 0.1},
{"text": "Freund", "confidence": 0.88},
]
assert apply_confidence_markers(words) == "[unleserlich] Freund"
def test_low_confidence_at_end():
words = [
{"text": "Lieber", "confidence": 0.95},
{"text": "xkq", "confidence": 0.1},
]
assert apply_confidence_markers(words) == "Lieber [unleserlich]"
# ─── words_from_characters ────────────────────────────────────────────────────
def test_single_word_matching_confidences():
words = words_from_characters("Hallo", [0.9, 0.8, 0.85, 0.7, 0.95])
assert len(words) == 1
assert words[0]["text"] == "Hallo"
assert abs(words[0]["confidence"] - 0.84) < 0.01
def test_multi_word_with_spaces():
prediction = "Sehr geehrter"
confidences = [0.9, 0.8, 0.7, 0.6, 0.5, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]
words = words_from_characters(prediction, confidences)
assert len(words) == 2
assert words[0]["text"] == "Sehr"
assert words[1]["text"] == "geehrter"
def test_length_mismatch_falls_back_safely():
words = words_from_characters("Hallo Welt", [0.9, 0.8])
assert len(words) == 1
assert words[0]["text"] == "Hallo Welt"
assert words[0]["confidence"] == 1.0
def test_empty_prediction_returns_empty():
assert words_from_characters("", []) == []
def test_single_character_word():
words = words_from_characters("A B", [0.9, 0.5, 0.3])
assert len(words) == 2
assert words[0]["text"] == "A"
assert words[0]["confidence"] == 0.9
assert words[1]["text"] == "B"
assert words[1]["confidence"] == 0.3
def test_whitespace_only_prediction():
words = words_from_characters(" ", [0.5, 0.5, 0.5])
assert words == []

178
ocr-service/test_engines.py Normal file
View File

@@ -0,0 +1,178 @@
"""Tests for per-page block extraction in OCR engines."""
from unittest.mock import MagicMock, patch
from PIL import Image
from engines import surya, kraken
# ─── Surya extract_page_blocks ───────────────────────────────────────────────
def _make_surya_line(text, bbox, polygon=None, words=None):
"""Create a mock Surya text line with the expected attributes."""
line = MagicMock()
line.text = text
line.bbox = bbox
line.polygon = polygon
line.words = words or []
if not words:
line.confidence = 0.95
# hasattr check needs words to be falsy
del line.words
return line
def test_surya_extract_page_blocks_returns_blocks_for_single_image():
image = Image.new("RGB", (100, 200))
mock_line = _make_surya_line("Hallo Welt", [10, 20, 90, 40])
mock_pred = MagicMock()
mock_pred.text_lines = [mock_line]
with patch.object(surya, "_recognition_predictor") as mock_rec, \
patch.object(surya, "_loaded", True):
mock_rec.return_value = [mock_pred]
blocks = surya.extract_page_blocks(image, page_idx=2, language="de")
assert len(blocks) == 1
assert blocks[0]["pageNumber"] == 2
assert blocks[0]["text"] == "Hallo Welt"
assert blocks[0]["x"] == 10 / 100
assert blocks[0]["y"] == 20 / 200
assert blocks[0]["width"] == 80 / 100
assert blocks[0]["height"] == 20 / 200
def test_surya_extract_page_blocks_extracts_polygon_when_present():
image = Image.new("RGB", (100, 200))
mock_line = _make_surya_line("Text", [10, 20, 90, 40])
mock_line.polygon = [(10, 20), (90, 20), (90, 40), (10, 40)]
mock_pred = MagicMock()
mock_pred.text_lines = [mock_line]
with patch.object(surya, "_recognition_predictor") as mock_rec, \
patch.object(surya, "_loaded", True):
mock_rec.return_value = [mock_pred]
blocks = surya.extract_page_blocks(image, page_idx=1, language="de")
assert blocks[0]["polygon"] is not None
assert len(blocks[0]["polygon"]) == 4
assert blocks[0]["polygon"][0] == [10 / 100, 20 / 200]
def test_surya_extract_page_blocks_extracts_word_level_confidence():
image = Image.new("RGB", (100, 200))
word1 = MagicMock()
word1.text = "Hallo"
word1.confidence = 0.95
word2 = MagicMock()
word2.text = "Welt"
word2.confidence = 0.3
mock_line = _make_surya_line("Hallo Welt", [10, 20, 90, 40], words=[word1, word2])
mock_pred = MagicMock()
mock_pred.text_lines = [mock_line]
with patch.object(surya, "_recognition_predictor") as mock_rec, \
patch.object(surya, "_loaded", True):
mock_rec.return_value = [mock_pred]
blocks = surya.extract_page_blocks(image, page_idx=1, language="de")
assert len(blocks[0]["words"]) == 2
assert blocks[0]["words"][0]["text"] == "Hallo"
assert blocks[0]["words"][0]["confidence"] == 0.95
def test_surya_extract_blocks_delegates_to_extract_page_blocks():
"""After refactoring, extract_blocks should produce the same output."""
image1 = Image.new("RGB", (100, 200))
image2 = Image.new("RGB", (100, 200))
mock_line = _make_surya_line("Test", [10, 20, 90, 40])
mock_pred = MagicMock()
mock_pred.text_lines = [mock_line]
with patch.object(surya, "_recognition_predictor") as mock_rec, \
patch.object(surya, "_loaded", True):
mock_rec.return_value = [mock_pred]
blocks = surya.extract_blocks([image1, image2])
assert len(blocks) == 2
assert blocks[0]["pageNumber"] == 1
assert blocks[1]["pageNumber"] == 2
# ─── Kraken extract_page_blocks ──────────────────────────────────────────────
def _make_kraken_record(text, cuts, confidences=None):
record = MagicMock()
record.prediction = text
record.cuts = cuts
record.line = cuts
record.confidences = confidences or [0.9] * len(text)
return record
def _run_kraken_with_mocks(fn, *args):
"""Run a kraken function with blla/rpred mocked via sys.modules."""
import sys
mock_blla = MagicMock()
mock_blla.segment.return_value = MagicMock()
mock_rpred = MagicMock()
mock_record = _make_kraken_record("Kurrent", [(10, 20), (90, 20), (90, 40), (10, 40)])
mock_rpred.rpred.return_value = [mock_record]
saved_kraken = sys.modules.get("kraken")
saved_blla = sys.modules.get("kraken.blla")
saved_rpred = sys.modules.get("kraken.rpred")
sys.modules["kraken"] = MagicMock(blla=mock_blla, rpred=mock_rpred)
sys.modules["kraken.blla"] = mock_blla
sys.modules["kraken.rpred"] = mock_rpred
try:
with patch.object(kraken, "_model", MagicMock()):
return fn(*args)
finally:
if saved_kraken is not None:
sys.modules["kraken"] = saved_kraken
else:
sys.modules.pop("kraken", None)
if saved_blla is not None:
sys.modules["kraken.blla"] = saved_blla
else:
sys.modules.pop("kraken.blla", None)
if saved_rpred is not None:
sys.modules["kraken.rpred"] = saved_rpred
else:
sys.modules.pop("kraken.rpred", None)
def test_kraken_extract_page_blocks_returns_blocks_for_single_image():
image = Image.new("RGB", (100, 200))
blocks = _run_kraken_with_mocks(kraken.extract_page_blocks, image, 3, "de")
assert len(blocks) == 1
assert blocks[0]["pageNumber"] == 3
assert blocks[0]["text"] == "Kurrent"
def test_kraken_extract_blocks_delegates_to_extract_page_blocks():
image1 = Image.new("RGB", (100, 200))
image2 = Image.new("RGB", (100, 200))
blocks = _run_kraken_with_mocks(kraken.extract_blocks, [image1, image2])
assert len(blocks) == 2
assert blocks[0]["pageNumber"] == 1
assert blocks[1]["pageNumber"] == 2

303
ocr-service/test_stream.py Normal file
View File

@@ -0,0 +1,303 @@
"""Tests for the NDJSON streaming OCR endpoint POST /ocr/stream."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from main import app
@pytest.fixture
def mock_images():
"""Return 3 fake PIL images."""
from PIL import Image
return [Image.new("RGB", (100, 200)) for _ in range(3)]
def _make_block(page_idx, text="Test"):
return {
"pageNumber": page_idx,
"x": 0.1, "y": 0.2, "width": 0.8, "height": 0.1,
"polygon": None, "text": text,
"words": [{"text": text, "confidence": 0.95}],
}
# ─── P3: start event with total pages ────────────────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_start_event_with_total_pages(mock_images):
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.surya_engine") as mock_surya:
mock_surya.extract_page_blocks.return_value = [_make_block(0)]
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "TYPEWRITER",
"language": "de",
})
lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
assert lines[0] == {"type": "start", "totalPages": 3}
# ─── P4: page events per completed page ──────────────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_page_event_per_page_with_blocks(mock_images):
def page_blocks(image, page_idx, language="de"):
return [_make_block(page_idx, f"Page {page_idx}")]
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.surya_engine") as mock_surya:
mock_surya.extract_page_blocks.side_effect = page_blocks
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "TYPEWRITER",
})
lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
page_events = [l for l in lines if l["type"] == "page"]
assert len(page_events) == 3
assert page_events[0]["pageNumber"] == 1
assert page_events[1]["pageNumber"] == 2
assert page_events[2]["pageNumber"] == 3
assert len(page_events[0]["blocks"]) == 1
# ─── P5: done event with total blocks and skipped ────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_done_with_total_blocks(mock_images):
def page_blocks(image, page_idx, language="de"):
return [_make_block(page_idx)] * 2 # 2 blocks per page
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.surya_engine") as mock_surya:
mock_surya.extract_page_blocks.side_effect = page_blocks
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "TYPEWRITER",
})
lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
done = [l for l in lines if l["type"] == "done"][0]
assert done["totalBlocks"] == 6
assert done["skippedPages"] == 0
# ─── P6: error event on page failure, continues ──────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_error_event_on_page_failure_and_continues(mock_images):
call_count = 0
def page_blocks(image, page_idx, language="de"):
nonlocal call_count
call_count += 1
if page_idx == 2:
raise RuntimeError("Engine crashed on page 2")
return [_make_block(page_idx)]
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.surya_engine") as mock_surya:
mock_surya.extract_page_blocks.side_effect = page_blocks
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "TYPEWRITER",
})
lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
types = [l["type"] for l in lines]
assert "error" in types
error_event = [l for l in lines if l["type"] == "error"][0]
assert error_event["pageNumber"] == 2
# Error message must be generic, not the raw traceback
assert "Engine crashed" not in error_event["message"]
        # Processing continued: pages 1 and 3 still have page events
page_events = [l for l in lines if l["type"] == "page"]
assert len(page_events) == 2
done = [l for l in lines if l["type"] == "done"][0]
assert done["skippedPages"] == 1
# ─── P7: old /ocr endpoint still works ───────────────────────────────────────
@pytest.mark.asyncio
async def test_old_ocr_endpoint_still_returns_flat_list(mock_images):
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.surya_engine") as mock_surya:
mock_surya.extract_blocks.return_value = [_make_block(1), _make_block(2)]
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "TYPEWRITER",
})
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
assert data[0]["pageNumber"] == 1
# ─── Health and error handling ────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_health_returns_ok_when_models_ready():
with patch("main._models_ready", True), \
patch("main.kraken_engine") as mock_kraken:
mock_kraken.is_available.return_value = True
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert data["surya"] is True
assert data["kraken"] is True
@pytest.mark.asyncio
async def test_health_returns_503_when_models_not_ready():
with patch("main._models_ready", False):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.get("/health")
assert response.status_code == 503
@pytest.mark.asyncio
async def test_ocr_stream_returns_503_when_models_not_ready():
with patch("main._models_ready", False):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "TYPEWRITER",
})
assert response.status_code == 503
@pytest.mark.asyncio
async def test_ocr_stream_uses_kraken_for_kurrent_script(mock_images):
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.kraken_engine") as mock_kraken:
mock_kraken.is_available.return_value = True
mock_kraken.extract_page_blocks.return_value = [_make_block(1, "Kurrent text")]
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "HANDWRITING_KURRENT",
})
lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
page_events = [l for l in lines if l["type"] == "page"]
assert len(page_events) == 3
mock_kraken.extract_page_blocks.assert_called()
@pytest.mark.asyncio
async def test_ocr_stream_returns_400_when_kraken_unavailable_for_kurrent(mock_images):
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.kraken_engine") as mock_kraken:
mock_kraken.is_available.return_value = False
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "HANDWRITING_KURRENT",
})
assert response.status_code == 400
# ─── SSRF protection ─────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_rejects_disallowed_host():
with patch("main._models_ready", True):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://evil.example.com/malicious.pdf",
"scriptType": "TYPEWRITER",
})
assert response.status_code == 400
assert "not allowed" in response.json()["detail"]
@pytest.mark.asyncio
async def test_ocr_stream_allows_minio_host(mock_images):
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main._models_ready", True), \
patch("main.surya_engine") as mock_surya:
mock_surya.extract_page_blocks.return_value = [_make_block(0)]
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/test.pdf",
"scriptType": "TYPEWRITER",
})
assert response.status_code == 200


@pytest.mark.asyncio
async def test_ocr_stream_applies_confidence_markers(mock_images):
    """Low-confidence words should be replaced with [unleserlich] in the stream output."""
    def page_blocks(image, page_idx, language="de"):
        return [{
            "pageNumber": page_idx,
            "x": 0.1, "y": 0.2, "width": 0.8, "height": 0.1,
            "polygon": None,
            "text": "original text",
            "words": [
                {"text": "Lieber", "confidence": 0.95},
                {"text": "xkqz", "confidence": 0.1},
            ],
        }]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images[:1]), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_page_blocks.side_effect = page_blocks
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })
        lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
        page = [l for l in lines if l["type"] == "page"][0]
        assert page["blocks"][0]["text"] == "Lieber [unleserlich]"
        assert "words" not in page["blocks"][0]

scripts/download-kraken-models.sh Executable file
View File

@@ -0,0 +1,203 @@
#!/bin/bash
set -euo pipefail
# Downloads Kraken HTR models, runs an OCR evaluation on sample PDFs, and lets you activate the winning model.
#
# Usage:
# ./scripts/download-kraken-models.sh # download models + run evaluation
# ./scripts/download-kraken-models.sh --activate 1 # activate model 1
# ./scripts/download-kraken-models.sh --activate 2 # activate model 2
# ./scripts/download-kraken-models.sh --eval-only # re-run evaluation (models already downloaded)
COMPOSE_SERVICE="ocr-service"
MODEL_DIR="/app/models"
ACTIVE_MODEL="$MODEL_DIR/german_kurrent.mlmodel"
# Kraken 7 uses DOIs to identify models from HTR-United / Zenodo
MODEL_1_DOI="10.5281/zenodo.7933463"
MODEL_1_NAME="german_handwriting"
MODEL_1_DESC="HTR model for German manuscripts (handwritten text recognition)"
MODEL_1_PATH="$MODEL_DIR/$MODEL_1_NAME.mlmodel"
MODEL_2_DOI="10.5281/zenodo.13788177"
MODEL_2_NAME="mccatmus"
MODEL_2_DESC="McCATMuS — generic model for handwritten, printed & typewritten (16th c. onward)"
MODEL_2_PATH="$MODEL_DIR/$MODEL_2_NAME.mlmodel"
EVAL_PDFS=("Eu-0693.pdf" "Eu-0692.pdf" "W-0150.pdf" "W-0575.pdf")
IMPORT_DIR="./import"
OUTPUT_BASE="./ocr-model-evaluation"
# ─── Functions ────────────────────────────────────────────────────────────────
download_model() {
    local doi="$1"
    local dest="$2"
    local name="$3"

    echo " Downloading $name ($doi)..."

    # kraken get downloads to /root/.local/share/htrmopo/<uuid>/<name>.mlmodel
    # Parse the "Model dir: <path>" line from kraken output to locate the file
    docker compose run --rm "$COMPOSE_SERVICE" sh -c '
        OUTPUT=$(kraken get '"$doi"' 2>&1)
        echo "$OUTPUT"
        MODEL_DIR=$(echo "$OUTPUT" | grep -oP "Model dir: \K[^ ]+")
        if [ -n "$MODEL_DIR" ] && [ -d "$MODEL_DIR" ]; then
            FOUND=$(find "$MODEL_DIR" -name "*.mlmodel" | head -1)
            if [ -n "$FOUND" ]; then
                cp "$FOUND" '"$dest"'
                echo "Saved to '"$dest"'"
            else
                echo "ERROR: No .mlmodel file in $MODEL_DIR"
                ls -la "$MODEL_DIR"
                exit 1
            fi
        else
            echo "ERROR: Could not parse model directory from kraken output"
            exit 1
        fi
    '
}
download_models() {
    echo "Downloading Kraken HTR models into the ocr_models volume..."
    echo ""
    echo "Model 1: $MODEL_1_NAME"
    echo " $MODEL_1_DESC"
    download_model "$MODEL_1_DOI" "$MODEL_1_PATH" "$MODEL_1_NAME"
    echo ""
    echo "Model 2: $MODEL_2_NAME"
    echo " $MODEL_2_DESC"
    download_model "$MODEL_2_DOI" "$MODEL_2_PATH" "$MODEL_2_NAME"
    echo ""
    echo "Both models downloaded."
}
run_evaluation() {
    echo ""
    echo "═══════════════════════════════════════════════════════"
    echo " Running OCR evaluation on ${#EVAL_PDFS[@]} documents"
    echo "═══════════════════════════════════════════════════════"
    echo ""

    local dir_1="$OUTPUT_BASE/$MODEL_1_NAME"
    local dir_2="$OUTPUT_BASE/$MODEL_2_NAME"
    mkdir -p "$dir_1" "$dir_2"

    for pdf in "${EVAL_PDFS[@]}"; do
        local src="$IMPORT_DIR/$pdf"
        local basename="${pdf%.pdf}"

        if [[ ! -f "$src" ]]; then
            echo "SKIP: $src not found"
            continue
        fi

        echo "──── $pdf ────"

        # Run both models inside a single container run:
        # 1. Extract PDF pages as PNGs (pypdfium2 is already installed)
        # 2. Run kraken on each page image for both models
        # 3. Concatenate per-page output into one file per model
        docker compose run --rm \
            -v "$(cd "$IMPORT_DIR" && pwd):/eval-input:ro" \
            -v "$(cd "$dir_1" && pwd):/eval-out-1" \
            -v "$(cd "$dir_2" && pwd):/eval-out-2" \
            "$COMPOSE_SERVICE" \
            python3 -c "
import pypdfium2 as pdfium, subprocess, sys, os
pdf = pdfium.PdfDocument('/eval-input/$pdf')
pages = []
for i in range(len(pdf)):
    bmp = pdf[i].render(scale=300/72)
    path = f'/tmp/page_{i:04d}.png'
    bmp.to_pil().save(path)
    pages.append(path)
print(f'Extracted {len(pages)} pages')
for label, model, outdir in [
    ('Model 1', '$MODEL_1_PATH', '/eval-out-1'),
    ('Model 2', '$MODEL_2_PATH', '/eval-out-2'),
]:
    print(f' {label}...')
    combined = ''
    for p in pages:
        args = ['kraken', '-i', p, '/dev/stdout', 'segment', '-bl', 'ocr', '-m', model]
        r = subprocess.run(args, capture_output=True, text=True)
        combined += r.stdout
        if r.returncode != 0:
            print(f' ⚠ failed on {os.path.basename(p)}: {r.stderr[:200]}', file=sys.stderr)
    with open(f'{outdir}/${basename}.txt', 'w') as f:
        f.write(combined)
    lines = combined.count(chr(10))
    print(f' → {lines} lines')
" || echo " ⚠ Failed on $pdf"
        echo ""
    done

    echo "═══════════════════════════════════════════════════════"
    echo " Evaluation complete. Results:"
    echo ""
    echo " Model 1 — $MODEL_1_NAME ($MODEL_1_DOI):"
    for f in "$dir_1"/*.txt; do
        [[ -f "$f" ]] && echo " $(basename "$f"): $(wc -l < "$f") lines, $(wc -c < "$f") bytes"
    done
    echo ""
    echo " Model 2 — $MODEL_2_NAME ($MODEL_2_DOI):"
    for f in "$dir_2"/*.txt; do
        [[ -f "$f" ]] && echo " $(basename "$f"): $(wc -l < "$f") lines, $(wc -c < "$f") bytes"
    done
    echo ""
    echo " Compare outputs:"
    echo " diff $dir_1/Eu-0693.txt $dir_2/Eu-0693.txt"
    echo " # or view individually:"
    echo " cat $dir_1/Eu-0693.txt"
    echo ""
    echo " Activate the better model:"
    echo " ./scripts/download-kraken-models.sh --activate 1 # $MODEL_1_NAME"
    echo " ./scripts/download-kraken-models.sh --activate 2 # $MODEL_2_NAME"
    echo "═══════════════════════════════════════════════════════"
}
activate_model() {
    local choice="$1"
    case "$choice" in
        1)
            echo "Activating model 1: $MODEL_1_NAME"
            docker compose run --rm "$COMPOSE_SERVICE" cp "$MODEL_1_PATH" "$ACTIVE_MODEL"
            ;;
        2)
            echo "Activating model 2: $MODEL_2_NAME"
            docker compose run --rm "$COMPOSE_SERVICE" cp "$MODEL_2_PATH" "$ACTIVE_MODEL"
            ;;
        *)
            echo "Error: --activate expects 1 or 2"
            exit 1
            ;;
    esac

    echo "Active model is now: $ACTIVE_MODEL"
    echo "Restart the OCR service to load the new model:"
    echo " docker compose restart ocr-service"
}
# ─── Main ─────────────────────────────────────────────────────────────────────
case "${1:-}" in
--activate)
activate_model "${2:-}"
;;
--eval-only)
run_evaluation
;;
*)
download_models
run_evaluation
;;
esac
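
# Convenience check (not part of the flow above): list the models in the ocr_models
# volume and confirm which file is currently active, using the same paths this
# script writes to. Assumes the ocr-service image provides a POSIX `ls`.
#
#   docker compose run --rm ocr-service ls -l /app/models
#   docker compose run --rm ocr-service ls -l /app/models/german_kurrent.mlmodel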