@@ -16,10 +16,10 @@ public class AsyncConfig {
|
||||
@Bean
|
||||
public Executor taskExecutor() {
|
||||
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||
executor.setCorePoolSize(1);
|
||||
executor.setMaxPoolSize(1);
|
||||
executor.setQueueCapacity(1);
|
||||
executor.setThreadNamePrefix("Import-");
|
||||
executor.setCorePoolSize(2);
|
||||
executor.setMaxPoolSize(2);
|
||||
executor.setQueueCapacity(10);
|
||||
executor.setThreadNamePrefix("Async-");
|
||||
executor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
|
||||
return executor;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.s3.S3Client;
|
||||
import software.amazon.awssdk.services.s3.S3Configuration;
|
||||
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -44,6 +45,19 @@ public class MinioConfig {
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
 * Presigner for generating time-limited download URLs against the same
 * MinIO/S3 endpoint and credentials as the S3Client bean.
 * Path-style access is forced because MinIO does not serve
 * virtual-hosted-style bucket URLs by default.
 *
 * @return a configured {@link S3Presigner}
 */
@Bean
public S3Presigner s3Presigner() {
    return S3Presigner.builder()
            .endpointOverride(URI.create(endpoint))
            .serviceConfiguration(S3Configuration.builder()
                    .pathStyleAccessEnabled(true)
                    .build())
            .region(Region.of(region))
            .credentialsProvider(StaticCredentialsProvider.create(
                    AwsBasicCredentials.create(accessKey, secretKey)))
            .build();
}
|
||||
|
||||
@Bean
|
||||
public CommandLineRunner testS3Connection(S3Client s3Client) {
|
||||
return args -> {
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package org.raddatz.familienarchiv.controller;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.dto.BatchOcrDTO;
|
||||
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||
import org.raddatz.familienarchiv.model.AppUser;
|
||||
import org.raddatz.familienarchiv.model.OcrJob;
|
||||
import org.raddatz.familienarchiv.security.Permission;
|
||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||
import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||
import org.raddatz.familienarchiv.service.OcrProgressService;
|
||||
import org.raddatz.familienarchiv.service.OcrService;
|
||||
import org.raddatz.familienarchiv.service.UserService;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import jakarta.validation.Valid;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
 * REST endpoints for triggering and observing OCR processing:
 * single-document OCR, admin batch OCR, job status lookup,
 * SSE progress streaming, and per-document OCR status.
 */
@RestController
@RequiredArgsConstructor
@Slf4j
public class OcrController {

    private final OcrService ocrService;
    private final OcrBatchService ocrBatchService;
    private final OcrProgressService ocrProgressService;
    private final UserService userService;

    /**
     * Starts OCR for one document. Returns 202 with the created job id;
     * processing happens asynchronously.
     *
     * @param documentId     document to OCR
     * @param dto            carries the (optional) script type hint
     * @param authentication current caller; used only to attribute the job
     * @return {"jobId": <uuid>} of the started job
     */
    @PostMapping("/api/documents/{documentId}/ocr")
    @ResponseStatus(HttpStatus.ACCEPTED)
    @RequirePermission(Permission.WRITE_ALL)
    public Map<String, UUID> triggerOcr(
            @PathVariable UUID documentId,
            @RequestBody TriggerOcrDTO dto,
            Authentication authentication) {
        UUID userId = resolveUserId(authentication);
        UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
        return Map.of("jobId", jobId);
    }

    /**
     * Starts OCR over a list of documents (admin only). The DTO is validated
     * (@Valid) — non-empty list, capped at 500 ids.
     *
     * @return {"jobId": <uuid>} of the started batch job
     */
    @PostMapping("/api/ocr/batch")
    @ResponseStatus(HttpStatus.ACCEPTED)
    @RequirePermission(Permission.ADMIN)
    public Map<String, UUID> triggerBatch(
            @RequestBody @Valid BatchOcrDTO dto,
            Authentication authentication) {
        UUID userId = resolveUserId(authentication);
        UUID jobId = ocrBatchService.startBatch(dto.getDocumentIds(), userId);
        return Map.of("jobId", jobId);
    }

    /** Returns the current state of an OCR job (404 via service if unknown). */
    @GetMapping("/api/ocr/jobs/{jobId}")
    @RequirePermission(Permission.READ_ALL)
    public OcrJob getJobStatus(@PathVariable UUID jobId) {
        return ocrService.getJob(jobId);
    }

    /**
     * Opens a server-sent-events stream of progress updates for a job.
     * The getJob() call exists purely to fail fast (404) before an emitter
     * is registered for a nonexistent job.
     */
    @GetMapping(value = "/api/ocr/jobs/{jobId}/progress", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    @RequirePermission(Permission.READ_ALL)
    public SseEmitter streamProgress(@PathVariable UUID jobId) {
        ocrService.getJob(jobId);
        return ocrProgressService.register(jobId);
    }

    /** OCR status summary for a single document. */
    @GetMapping("/api/documents/{documentId}/ocr-status")
    @RequirePermission(Permission.READ_ALL)
    public OcrStatusDTO getDocumentOcrStatus(@PathVariable UUID documentId) {
        return ocrService.getDocumentOcrStatus(documentId);
    }

    /**
     * Maps the authenticated principal to an AppUser id, or null when
     * unauthenticated or when the lookup fails. Failures are logged and
     * swallowed deliberately: user attribution is best-effort and must
     * never block an OCR request.
     */
    private UUID resolveUserId(Authentication authentication) {
        if (authentication == null || !authentication.isAuthenticated()) return null;
        try {
            AppUser user = userService.findByUsername(authentication.getName());
            return user != null ? user.getId() : null;
        } catch (Exception e) {
            log.warn("Failed to resolve user ID for authentication: {}", authentication.getName(), e);
            return null;
        }
    }
}
|
||||
@@ -81,6 +81,14 @@ public class TranscriptionBlockController {
|
||||
return transcriptionService.listBlocks(documentId);
|
||||
}
|
||||
|
||||
/**
 * Marks a transcription block as reviewed; delegates to the service,
 * which owns lookup/validation. Returns the updated block.
 *
 * @param documentId owning document (scopes the lookup)
 * @param blockId    block to mark reviewed
 */
@PutMapping("/{blockId}/review")
@RequirePermission(Permission.WRITE_ALL)
public TranscriptionBlock reviewBlock(
        @PathVariable UUID documentId,
        @PathVariable UUID blockId) {
    return transcriptionService.reviewBlock(documentId, blockId);
}
|
||||
|
||||
@GetMapping("/{blockId}/history")
|
||||
@RequirePermission(Permission.READ_ALL)
|
||||
public List<TranscriptionBlockVersion> getBlockHistory(
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
 * Request body for POST /api/ocr/batch: the set of documents to OCR.
 * Must be non-empty and contain at most 500 ids.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class BatchOcrDTO {
    // Bean-validated on the controller (@Valid): non-empty, capped at 500.
    @NotEmpty
    @Size(max = 500, message = "batch size must not exceed 500 documents")
    private List<UUID> documentIds;
}
|
||||
@@ -1,9 +1,15 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import jakarta.validation.Valid;
|
||||
import jakarta.validation.constraints.DecimalMax;
|
||||
import jakarta.validation.constraints.DecimalMin;
|
||||
import jakarta.validation.constraints.Size;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@@ -14,4 +20,19 @@ public class CreateAnnotationDTO {
|
||||
private double width;
|
||||
private double height;
|
||||
private String color;
|
||||
|
||||
@Size(min = 4, max = 4, message = "polygon must have exactly 4 points")
|
||||
@UniquePoints
|
||||
@Valid
|
||||
private List<@Size(min = 2, max = 2, message = "each point must have exactly 2 coordinates")
|
||||
List<@DecimalMin("0.0") @DecimalMax("1.0") Double>> polygon;
|
||||
|
||||
/**
 * Convenience constructor for rectangle-only annotations: sets the
 * bounding box and color, leaving the optional polygon null (valid —
 * the polygon validators all treat null as acceptable).
 */
public CreateAnnotationDTO(int pageNumber, double x, double y, double width, double height, String color) {
    this.pageNumber = pageNumber;
    this.x = x;
    this.y = y;
    this.width = width;
    this.height = height;
    this.color = color;
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import lombok.Data;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
|
||||
@Data
|
||||
public class DocumentUpdateDTO {
|
||||
@@ -18,4 +19,5 @@ public class DocumentUpdateDTO {
|
||||
private List<UUID> receiverIds;
|
||||
private String tags;
|
||||
private Boolean metadataComplete;
|
||||
private ScriptType scriptType;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
/**
 * OCR status summary for one document, as returned by
 * GET /api/documents/{id}/ocr-status.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrStatusDTO {
    // Status string — presumably mirrors OcrDocumentStatus names; confirm in OcrService.
    private String status;
    // Job currently/last associated with the document, if any.
    private UUID jobId;
    // Page-level progress of the active OCR run.
    private int currentPage;
    private int totalPages;
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
|
||||
/**
 * Request body for POST /api/documents/{id}/ocr: the script type to use
 * for recognition (may be null — service decides the fallback).
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class TriggerOcrDTO {
    private ScriptType scriptType;
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import jakarta.validation.Constraint;
|
||||
import jakarta.validation.Payload;
|
||||
|
||||
import java.lang.annotation.*;
|
||||
|
||||
/**
 * Field-level bean-validation constraint asserting that a polygon
 * (List of points) contains no duplicate points.
 * Enforced by {@link UniquePointsValidator}.
 */
@Documented
@Constraint(validatedBy = UniquePointsValidator.class)
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
public @interface UniquePoints {
    String message() default "polygon must contain 4 unique points";
    Class<?>[] groups() default {};
    Class<? extends Payload>[] payload() default {};
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import jakarta.validation.ConstraintValidator;
|
||||
import jakarta.validation.ConstraintValidatorContext;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
||||
public class UniquePointsValidator implements ConstraintValidator<UniquePoints, List<List<Double>>> {
|
||||
|
||||
@Override
|
||||
public boolean isValid(List<List<Double>> polygon, ConstraintValidatorContext context) {
|
||||
if (polygon == null) return true;
|
||||
return new HashSet<>(polygon).size() == polygon.size();
|
||||
}
|
||||
}
|
||||
@@ -66,6 +66,16 @@ public enum ErrorCode {
|
||||
/** The notification with the given ID does not exist. 404 */
|
||||
NOTIFICATION_NOT_FOUND,
|
||||
|
||||
// --- OCR ---
|
||||
/** The OCR service is not available or not healthy. 503 */
|
||||
OCR_SERVICE_UNAVAILABLE,
|
||||
/** The OCR job with the given ID does not exist. 404 */
|
||||
OCR_JOB_NOT_FOUND,
|
||||
/** The document is not in UPLOADED status and cannot be OCR'd. 400 */
|
||||
OCR_DOCUMENT_NOT_UPLOADED,
|
||||
/** OCR processing failed for the document. 500 */
|
||||
OCR_PROCESSING_FAILED,
|
||||
|
||||
// --- Generic ---
|
||||
/** Request validation failed (missing or malformed fields). 400 */
|
||||
VALIDATION_ERROR,
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
/** Origin of a transcription block: typed by a user, or produced by OCR. */
public enum BlockSource {
    MANUAL,
    OCR
}
|
||||
@@ -91,6 +91,12 @@ public class Document {
|
||||
@Builder.Default
|
||||
private boolean metadataComplete = false;
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(name = "script_type", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@Builder.Default
|
||||
private ScriptType scriptType = ScriptType.UNKNOWN;
|
||||
|
||||
@ManyToMany(fetch = FetchType.EAGER)
|
||||
@JoinTable(name = "document_receivers", joinColumns = @JoinColumn(name = "document_id"), inverseJoinColumns = @JoinColumn(name = "person_id"))
|
||||
@Builder.Default
|
||||
|
||||
@@ -4,8 +4,11 @@ import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.persistence.*;
|
||||
import lombok.*;
|
||||
import org.hibernate.annotations.CreationTimestamp;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
@Entity
|
||||
@@ -52,6 +55,10 @@ public class DocumentAnnotation {
|
||||
@Column(name = "file_hash", length = 64)
|
||||
private String fileHash;
|
||||
|
||||
@JdbcTypeCode(SqlTypes.JSON)
|
||||
@Column(columnDefinition = "jsonb")
|
||||
private List<List<Double>> polygon;
|
||||
|
||||
@Column(name = "created_by")
|
||||
private UUID createdBy;
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
/**
 * Per-document state within an OCR job. SKIPPED is used for documents
 * that cannot be processed (e.g. placeholders in a batch run).
 */
public enum OcrDocumentStatus {
    PENDING,
    RUNNING,
    DONE,
    FAILED,
    SKIPPED
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.persistence.*;
|
||||
import lombok.*;
|
||||
import org.hibernate.annotations.CreationTimestamp;
|
||||
import org.hibernate.annotations.UpdateTimestamp;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
 * Persistent OCR job: one row per single-document or batch OCR run,
 * tracking aggregate counters and a machine-readable progress message.
 */
@Entity
@Table(name = "ocr_jobs")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrJob {

    @Id
    @GeneratedValue(strategy = GenerationType.UUID)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID id;

    // Lifecycle: PENDING -> RUNNING -> DONE | FAILED (see OcrAsyncRunner).
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    @Builder.Default
    private OcrJobStatus status = OcrJobStatus.PENDING;

    // Number of documents enqueued in this job (1 for single-document runs).
    @Column(name = "total_documents", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private int totalDocuments;

    // Successfully finished documents.
    @Column(name = "processed_documents", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    @Builder.Default
    private int processedDocuments = 0;

    // Documents whose processing threw.
    @Column(name = "error_count", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    @Builder.Default
    private int errorCount = 0;

    // Documents skipped (e.g. placeholder status) rather than processed.
    @Column(name = "skipped_count", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    @Builder.Default
    private int skippedCount = 0;

    // Colon-delimited progress token written by the runner
    // (e.g. "ANALYZING_PAGE:<cur>:<total>:<blocks>", "DONE:<blocks>:<skipped>").
    @Column(name = "progress_message")
    private String progressMessage;

    // Id of the user who triggered the job; may be null (best-effort attribution).
    @Column(name = "created_by")
    private UUID createdBy;

    @Column(name = "created_at", nullable = false, updatable = false)
    @CreationTimestamp
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private LocalDateTime createdAt;

    @Column(name = "updated_at", nullable = false)
    @UpdateTimestamp
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private LocalDateTime updatedAt;
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.persistence.*;
|
||||
import lombok.*;
|
||||
import org.hibernate.annotations.CreationTimestamp;
|
||||
import org.hibernate.annotations.UpdateTimestamp;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
 * Join/progress row linking an OCR job to one of its documents, with
 * per-document status and page-level progress. jobId/documentId are plain
 * UUID columns (no JPA relations) — lookups go through the repository.
 */
@Entity
@Table(name = "ocr_job_documents")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class OcrJobDocument {

    @Id
    @GeneratedValue(strategy = GenerationType.UUID)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID id;

    @Column(name = "job_id", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID jobId;

    @Column(name = "document_id", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID documentId;

    // PENDING -> RUNNING -> DONE | FAILED | SKIPPED (see OcrAsyncRunner).
    @Enumerated(EnumType.STRING)
    @Column(nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    @Builder.Default
    private OcrDocumentStatus status = OcrDocumentStatus.PENDING;

    // Exception message captured when status becomes FAILED.
    @Column(name = "error_message")
    private String errorMessage;

    // Page progress of the OCR stream; 0 until the first page completes.
    @Column(name = "current_page")
    @Builder.Default
    private int currentPage = 0;

    // Total pages as reported by the OCR service's Start event; 0 until known.
    @Column(name = "total_pages")
    @Builder.Default
    private int totalPages = 0;

    @Column(name = "created_at", nullable = false, updatable = false)
    @CreationTimestamp
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private LocalDateTime createdAt;

    @Column(name = "updated_at", nullable = false)
    @UpdateTimestamp
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private LocalDateTime updatedAt;
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
/** Overall OCR job lifecycle: PENDING -> RUNNING -> DONE | FAILED. */
public enum OcrJobStatus {
    PENDING,
    RUNNING,
    DONE,
    FAILED
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import jakarta.persistence.AttributeConverter;
|
||||
import jakarta.persistence.Converter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
 * JPA attribute converter serializing a polygon (list of coordinate
 * lists) to/from a JSON string column via Jackson.
 * Null-in/null-out; an empty DB string also maps to null.
 * NOTE(review): DocumentAnnotation.polygon uses @JdbcTypeCode(SqlTypes.JSON)
 * rather than this converter — verify this class is still referenced.
 */
@Converter
public class PolygonConverter implements AttributeConverter<List<List<Double>>, String> {

    // ObjectMapper is thread-safe for read/write; shared as a constant.
    private static final ObjectMapper MAPPER = new ObjectMapper();
    private static final TypeReference<List<List<Double>>> TYPE_REF = new TypeReference<>() {};

    /** Serializes the polygon to JSON; null stays null. */
    @Override
    public String convertToDatabaseColumn(List<List<Double>> polygon) {
        if (polygon == null) return null;
        try {
            return MAPPER.writeValueAsString(polygon);
        } catch (JsonProcessingException e) {
            // Cause preserved so the original Jackson failure is visible.
            throw new IllegalArgumentException("Failed to serialize polygon", e);
        }
    }

    /** Parses the JSON column; null/empty maps to null. */
    @Override
    public List<List<Double>> convertToEntityAttribute(String json) {
        if (json == null || json.isEmpty()) return null;
        try {
            return MAPPER.readValue(json, TYPE_REF);
        } catch (JsonProcessingException e) {
            throw new IllegalArgumentException("Failed to deserialize polygon", e);
        }
    }
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
/** Writing system of a document's text, used to pick the OCR model/mode. */
public enum ScriptType {
    UNKNOWN,
    TYPEWRITER,
    HANDWRITING_LATIN,
    HANDWRITING_KURRENT
}
|
||||
@@ -41,6 +41,17 @@ public class TranscriptionBlock {
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private int sortOrder;
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(nullable = false, length = 10)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@Builder.Default
|
||||
private BlockSource source = BlockSource.MANUAL;
|
||||
|
||||
@Column(nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@Builder.Default
|
||||
private boolean reviewed = false;
|
||||
|
||||
@Version
|
||||
@Column(nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.raddatz.familienarchiv.model.OcrDocumentStatus;
|
||||
import org.raddatz.familienarchiv.model.OcrJobDocument;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
 * Spring Data repository for per-document OCR progress rows.
 * All queries are derived from method names.
 */
public interface OcrJobDocumentRepository extends JpaRepository<OcrJobDocument, UUID> {

    /** All documents of a job in enqueue (creation) order. */
    List<OcrJobDocument> findByJobIdOrderByCreatedAtAsc(UUID jobId);

    /** Documents of a job filtered by status. */
    List<OcrJobDocument> findByJobIdAndStatus(UUID jobId, OcrDocumentStatus status);

    /** The row linking a specific job and document, if present. */
    Optional<OcrJobDocument> findByJobIdAndDocumentId(UUID jobId, UUID documentId);

    /** Any row for the document in one of the given statuses (e.g. an active run). */
    Optional<OcrJobDocument> findFirstByDocumentIdAndStatusIn(UUID documentId, List<OcrDocumentStatus> statuses);
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.raddatz.familienarchiv.model.OcrJob;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
/** Spring Data repository for OCR jobs; inherited CRUD only, no custom queries. */
public interface OcrJobRepository extends JpaRepository<OcrJob, UUID> {
}
|
||||
@@ -48,6 +48,26 @@ public class AnnotationService {
|
||||
return annotationRepository.save(annotation);
|
||||
}
|
||||
|
||||
/**
 * Persists an annotation produced by OCR. Unlike user-created annotations,
 * it additionally stores the recognized polygon outline and the hash of
 * the file version the coordinates refer to.
 *
 * @param documentId document the annotation belongs to
 * @param dto        bounding box, page number and color
 * @param userId     attributed creator (may be null — best-effort)
 * @param fileHash   hash of the PDF the OCR ran against, to detect staleness
 * @param polygon    recognized outline points; may be null
 * @return the saved entity (with generated id)
 */
@Transactional
public DocumentAnnotation createOcrAnnotation(UUID documentId, CreateAnnotationDTO dto,
                                              UUID userId, String fileHash,
                                              List<List<Double>> polygon) {
    DocumentAnnotation annotation = DocumentAnnotation.builder()
            .documentId(documentId)
            .pageNumber(dto.getPageNumber())
            .x(dto.getX())
            .y(dto.getY())
            .width(dto.getWidth())
            .height(dto.getHeight())
            .color(dto.getColor())
            .fileHash(fileHash)
            .createdBy(userId)
            .polygon(polygon)
            .build();

    return annotationRepository.save(annotation);
}
|
||||
|
||||
@Transactional
|
||||
public void deleteAnnotation(UUID documentId, UUID annotationId, UUID userId) {
|
||||
DocumentAnnotation annotation = annotationRepository
|
||||
|
||||
@@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.dto.IncompleteDocumentDTO;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.dto.DocumentSort;
|
||||
import org.raddatz.familienarchiv.model.DocumentStatus;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.raddatz.familienarchiv.model.Person;
|
||||
import org.raddatz.familienarchiv.model.Tag;
|
||||
import org.raddatz.familienarchiv.repository.DocumentRepository;
|
||||
@@ -222,6 +223,10 @@ public class DocumentService {
|
||||
doc.setMetadataComplete(dto.getMetadataComplete());
|
||||
}
|
||||
|
||||
if (dto.getScriptType() != null) {
|
||||
doc.setScriptType(dto.getScriptType());
|
||||
}
|
||||
|
||||
// 4. Datei austauschen (nur wenn eine neue ausgewählt wurde)
|
||||
if (newFile != null && !newFile.isEmpty()) {
|
||||
FileService.UploadResult upload = fileService.uploadFile(newFile, newFile.getOriginalFilename());
|
||||
@@ -373,6 +378,13 @@ public class DocumentService {
|
||||
return documentRepository.findAll(conversation, Sort.by(Sort.Direction.ASC, "documentDate"));
|
||||
}
|
||||
|
||||
/**
 * Sets a document's script type. Throws (via getDocumentById) when the
 * document does not exist; the explicit save() is redundant for a managed
 * entity inside the transaction but harmless.
 */
@Transactional
public void updateScriptType(UUID documentId, ScriptType scriptType) {
    Document doc = getDocumentById(documentId);
    doc.setScriptType(scriptType);
    documentRepository.save(doc);
}
||||
|
||||
public Document getDocumentById(UUID id) {
|
||||
return documentRepository.findById(id)
|
||||
.orElseThrow(() -> DomainException.notFound(ErrorCode.DOCUMENT_NOT_FOUND, "Document not found: " + id));
|
||||
|
||||
@@ -4,6 +4,8 @@ import software.amazon.awssdk.core.ResponseInputStream;
|
||||
import software.amazon.awssdk.core.sync.RequestBody;
|
||||
import software.amazon.awssdk.services.s3.S3Client;
|
||||
import software.amazon.awssdk.services.s3.model.*;
|
||||
import software.amazon.awssdk.services.s3.presigner.S3Presigner;
|
||||
import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -16,6 +18,7 @@ import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.time.Duration;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@@ -24,10 +27,13 @@ public class FileService {
|
||||
private static final Logger log = LoggerFactory.getLogger(FileService.class);
|
||||
|
||||
private final S3Client s3Client;
|
||||
private final S3Presigner s3Presigner;
|
||||
private final String bucketName;
|
||||
|
||||
/**
 * Fix: the superseded two-argument constructor header was left in place
 * directly above its replacement, producing invalid Java — only the
 * three-argument constructor is kept.
 *
 * @param s3Client    low-level S3 client used for object storage calls
 * @param s3Presigner presigner used to mint time-limited download URLs
 * @param bucketName  target bucket, injected from {@code app.s3.bucket}
 */
public FileService(S3Client s3Client, S3Presigner s3Presigner,
                   @Value("${app.s3.bucket}") String bucketName) {
    this.s3Client = s3Client;
    this.s3Presigner = s3Presigner;
    this.bucketName = bucketName;
}
||||
|
||||
@@ -106,6 +112,24 @@ public class FileService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Generates a presigned GET URL for an object in S3/MinIO.
 * Valid for 15 minutes — enough for OCR processing on CPU.
 *
 * @param s3Key object key within the configured bucket
 * @return absolute URL that grants temporary read access to the object
 */
public String generatePresignedUrl(String s3Key) {
    GetObjectRequest getObjectRequest = GetObjectRequest.builder()
            .bucket(bucketName)
            .key(s3Key)
            .build();

    GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder()
            .signatureDuration(Duration.ofMinutes(15))
            .getObjectRequest(getObjectRequest)
            .build();

    return s3Presigner.presignGetObject(presignRequest).url().toString();
}
|
||||
|
||||
// ─── private helpers ──────────────────────────────────────────────────────
|
||||
|
||||
private static String sha256Hex(byte[] bytes) {
|
||||
|
||||
@@ -0,0 +1,217 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.dto.CreateAnnotationDTO;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobRepository;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class OcrAsyncRunner {
|
||||
|
||||
private static final String OCR_ANNOTATION_COLOR = "#00C7B1";
|
||||
|
||||
private final OcrClient ocrClient;
|
||||
private final DocumentService documentService;
|
||||
private final TranscriptionService transcriptionService;
|
||||
private final AnnotationService annotationService;
|
||||
private final FileService fileService;
|
||||
private final OcrJobRepository ocrJobRepository;
|
||||
private final OcrJobDocumentRepository ocrJobDocumentRepository;
|
||||
private final OcrProgressService ocrProgressService;
|
||||
|
||||
@Async
|
||||
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
|
||||
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
|
||||
if (job == null) return;
|
||||
|
||||
job.setStatus(OcrJobStatus.RUNNING);
|
||||
updateProgress(job, "PREPARING");
|
||||
|
||||
OcrJobDocument jobDoc = ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, documentId)
|
||||
.orElse(null);
|
||||
if (jobDoc != null) {
|
||||
jobDoc.setStatus(OcrDocumentStatus.RUNNING);
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
|
||||
Document doc = documentService.getDocumentById(documentId);
|
||||
|
||||
try {
|
||||
updateProgress(job, "LOADING");
|
||||
clearExistingBlocks(documentId);
|
||||
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
||||
|
||||
AtomicInteger blockCounter = new AtomicInteger(0);
|
||||
AtomicInteger currentPage = new AtomicInteger(0);
|
||||
AtomicInteger skippedPages = new AtomicInteger(0);
|
||||
AtomicInteger totalPages = new AtomicInteger(0);
|
||||
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
|
||||
switch (event) {
|
||||
case OcrStreamEvent.Start start -> {
|
||||
totalPages.set(start.totalPages());
|
||||
if (jobDoc != null) {
|
||||
jobDoc.setTotalPages(start.totalPages());
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
}
|
||||
case OcrStreamEvent.Page page -> {
|
||||
for (OcrBlockResult block : page.blocks()) {
|
||||
createSingleBlock(documentId, block, userId,
|
||||
doc.getFileHash(), blockCounter.getAndIncrement());
|
||||
}
|
||||
currentPage.incrementAndGet();
|
||||
if (jobDoc != null) {
|
||||
jobDoc.setCurrentPage(currentPage.get());
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
updateProgress(job, "ANALYZING_PAGE:" + currentPage.get()
|
||||
+ ":" + totalPages.get() + ":" + blockCounter.get());
|
||||
}
|
||||
case OcrStreamEvent.Error error -> {
|
||||
log.warn("OCR page {} failed for document {}: {}",
|
||||
error.pageNumber(), documentId, error.message());
|
||||
skippedPages.incrementAndGet();
|
||||
currentPage.incrementAndGet();
|
||||
if (jobDoc != null) {
|
||||
jobDoc.setCurrentPage(currentPage.get());
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
}
|
||||
case OcrStreamEvent.Done done -> {
|
||||
if (jobDoc != null) {
|
||||
jobDoc.setCurrentPage(totalPages.get());
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
job.setStatus(OcrJobStatus.DONE);
|
||||
job.setProcessedDocuments(1);
|
||||
updateProgress(job, "DONE:" + blockCounter.get() + ":" + skippedPages.get());
|
||||
if (jobDoc != null) {
|
||||
jobDoc.setStatus(OcrDocumentStatus.DONE);
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("OCR processing failed for document {}", documentId, e);
|
||||
job.setStatus(OcrJobStatus.FAILED);
|
||||
job.setErrorCount(1);
|
||||
updateProgress(job, "ERROR");
|
||||
if (jobDoc != null) {
|
||||
jobDoc.setStatus(OcrDocumentStatus.FAILED);
|
||||
jobDoc.setErrorMessage(e.getMessage());
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void updateProgress(OcrJob job, String message) {
|
||||
job.setProgressMessage(message);
|
||||
ocrJobRepository.save(job);
|
||||
}
|
||||
|
||||
@Async
|
||||
public void runBatch(UUID jobId, UUID userId) {
|
||||
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
|
||||
if (job == null) return;
|
||||
|
||||
job.setStatus(OcrJobStatus.RUNNING);
|
||||
ocrJobRepository.save(job);
|
||||
|
||||
List<OcrJobDocument> jobDocs = ocrJobDocumentRepository.findByJobIdOrderByCreatedAtAsc(jobId);
|
||||
|
||||
for (OcrJobDocument jobDoc : jobDocs) {
|
||||
Document doc = documentService.getDocumentById(jobDoc.getDocumentId());
|
||||
|
||||
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
|
||||
jobDoc.setStatus(OcrDocumentStatus.SKIPPED);
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
job.setSkippedCount(job.getSkippedCount() + 1);
|
||||
ocrJobRepository.save(job);
|
||||
ocrProgressService.emit(jobId, "document", Map.of(
|
||||
"documentId", jobDoc.getDocumentId(),
|
||||
"status", "SKIPPED",
|
||||
"processed", job.getProcessedDocuments(),
|
||||
"total", job.getTotalDocuments()));
|
||||
continue;
|
||||
}
|
||||
|
||||
jobDoc.setStatus(OcrDocumentStatus.RUNNING);
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
|
||||
try {
|
||||
processDocument(jobDoc.getDocumentId(), doc, userId);
|
||||
jobDoc.setStatus(OcrDocumentStatus.DONE);
|
||||
job.setProcessedDocuments(job.getProcessedDocuments() + 1);
|
||||
} catch (Exception e) {
|
||||
log.error("OCR batch: failed document {}", jobDoc.getDocumentId(), e);
|
||||
jobDoc.setStatus(OcrDocumentStatus.FAILED);
|
||||
jobDoc.setErrorMessage(e.getMessage());
|
||||
job.setErrorCount(job.getErrorCount() + 1);
|
||||
}
|
||||
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
ocrJobRepository.save(job);
|
||||
|
||||
ocrProgressService.emit(jobId, "document", Map.of(
|
||||
"documentId", jobDoc.getDocumentId(),
|
||||
"status", jobDoc.getStatus().name(),
|
||||
"processed", job.getProcessedDocuments(),
|
||||
"total", job.getTotalDocuments()));
|
||||
}
|
||||
|
||||
job.setStatus(OcrJobStatus.DONE);
|
||||
ocrJobRepository.save(job);
|
||||
|
||||
ocrProgressService.emit(jobId, "done", Map.of(
|
||||
"processed", job.getProcessedDocuments(),
|
||||
"errors", job.getErrorCount(),
|
||||
"skipped", job.getSkippedCount()));
|
||||
ocrProgressService.complete(jobId);
|
||||
}
|
||||
|
||||
void processDocument(UUID documentId, Document doc, UUID userId) {
|
||||
clearExistingBlocks(documentId);
|
||||
|
||||
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
||||
List<OcrBlockResult> blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType());
|
||||
createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash());
|
||||
}
|
||||
|
||||
// Delegates removal of all existing transcription blocks (and, per the service,
// their annotations) to TranscriptionService before a fresh OCR run.
private void clearExistingBlocks(UUID documentId) {
    transcriptionService.deleteAllBlocksByDocument(documentId);
}
|
||||
|
||||
private void createTranscriptionBlocks(UUID documentId, List<OcrBlockResult> blocks,
|
||||
UUID userId, String fileHash) {
|
||||
for (int i = 0; i < blocks.size(); i++) {
|
||||
createSingleBlock(documentId, blocks.get(i), userId, fileHash, i);
|
||||
}
|
||||
}
|
||||
|
||||
// Creates the annotation (bounding box + optional polygon) and the linked
// transcription block for a single OCR result. sortOrder preserves reading order.
void createSingleBlock(UUID documentId, OcrBlockResult block,
                       UUID userId, String fileHash, int sortOrder) {
    CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
            block.pageNumber(), block.x(), block.y(),
            block.width(), block.height(), OCR_ANNOTATION_COLOR);

    // fileHash ties the annotation to the exact file version OCR ran against.
    DocumentAnnotation annotation = annotationService.createOcrAnnotation(
            documentId, annotationDTO, userId, fileHash, block.polygon());

    transcriptionService.createOcrBlock(documentId, annotation.getId(),
            block.text(), sortOrder, userId);
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobRepository;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class OcrBatchService {
|
||||
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final OcrJobRepository ocrJobRepository;
|
||||
private final OcrJobDocumentRepository ocrJobDocumentRepository;
|
||||
private final OcrAsyncRunner ocrAsyncRunner;
|
||||
|
||||
public UUID startBatch(List<UUID> documentIds, UUID userId) {
|
||||
if (!ocrHealthClient.isHealthy()) {
|
||||
throw DomainException.internal(ErrorCode.OCR_SERVICE_UNAVAILABLE,
|
||||
"OCR service is not available");
|
||||
}
|
||||
|
||||
OcrJob job = OcrJob.builder()
|
||||
.totalDocuments(documentIds.size())
|
||||
.createdBy(userId)
|
||||
.status(OcrJobStatus.PENDING)
|
||||
.build();
|
||||
job = ocrJobRepository.save(job);
|
||||
|
||||
for (UUID docId : documentIds) {
|
||||
OcrJobDocument jobDoc = OcrJobDocument.builder()
|
||||
.jobId(job.getId())
|
||||
.documentId(docId)
|
||||
.status(OcrDocumentStatus.PENDING)
|
||||
.build();
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
|
||||
ocrAsyncRunner.runBatch(job.getId(), userId);
|
||||
return job.getId();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
 * One OCR-detected text block with its position on a page.
 * Coordinates are presumably normalized fractions of the page (the polygon
 * validator elsewhere rejects values outside [0, 1]) — confirm against the
 * OCR service contract.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
public record OcrBlockResult(
        int pageNumber,
        double x,                   // top-left corner
        double y,
        double width,
        double height,
        List<List<Double>> polygon, // optional quadrilateral outline; DB constraint requires exactly 4 points when present
        String text                 // recognized text content
) {}
|
||||
@@ -0,0 +1,35 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
public interface OcrClient {

    /** Runs blocking OCR on the PDF at the given URL and returns all detected blocks. */
    List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);

    /**
     * Stream OCR results page-by-page via NDJSON. Implementations should override
     * this method. The default exists only for backward compatibility during migration
     * — it calls extractBlocks() and synthesizes events from the collected result.
     */
    default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
        List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);

        // LinkedHashMap keeps pages in first-seen order, which matches the block
        // order extractBlocks() returned — not necessarily ascending page number.
        LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
        for (OcrBlockResult block : allBlocks) {
            byPage.computeIfAbsent(block.pageNumber(), k -> new ArrayList<>()).add(block);
        }

        // NOTE(review): "max page index + 1" assumes 0-based page numbers — confirm
        // against the OCR service; with 1-based pages this overcounts by one.
        int totalPages = byPage.isEmpty() ? 0 : byPage.keySet().stream().mapToInt(i -> i).max().orElse(0) + 1;
        handler.accept(new OcrStreamEvent.Start(totalPages));

        for (var entry : byPage.entrySet()) {
            handler.accept(new OcrStreamEvent.Page(entry.getKey(), entry.getValue()));
        }

        // The blocking path has no skip information, so skippedPages is always 0 here.
        handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
    }
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
/** Health probe for the external OCR service; consulted before OCR jobs are created. */
public interface OcrHealthClient {
    /** @return {@code true} when the OCR service currently responds to its health check. */
    boolean isHealthy();
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OcrProgressService {
|
||||
|
||||
private static final long SSE_TIMEOUT = 5 * 60 * 1000L;
|
||||
|
||||
private final ConcurrentHashMap<UUID, List<SseEmitter>> emitters = new ConcurrentHashMap<>();
|
||||
|
||||
public SseEmitter register(UUID jobId) {
|
||||
SseEmitter emitter = new SseEmitter(SSE_TIMEOUT);
|
||||
emitters.computeIfAbsent(jobId, k -> new CopyOnWriteArrayList<>()).add(emitter);
|
||||
|
||||
emitter.onCompletion(() -> removeEmitter(jobId, emitter));
|
||||
emitter.onTimeout(() -> removeEmitter(jobId, emitter));
|
||||
emitter.onError(e -> removeEmitter(jobId, emitter));
|
||||
|
||||
return emitter;
|
||||
}
|
||||
|
||||
public void emit(UUID jobId, String eventType, Object data) {
|
||||
List<SseEmitter> jobEmitters = emitters.get(jobId);
|
||||
if (jobEmitters == null) return;
|
||||
|
||||
for (SseEmitter emitter : jobEmitters) {
|
||||
try {
|
||||
emitter.send(SseEmitter.event().name(eventType).data(data));
|
||||
} catch (IOException e) {
|
||||
log.debug("SSE send failed for job {} — removing emitter", jobId);
|
||||
removeEmitter(jobId, emitter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void complete(UUID jobId) {
|
||||
List<SseEmitter> jobEmitters = emitters.remove(jobId);
|
||||
if (jobEmitters == null) return;
|
||||
|
||||
for (SseEmitter emitter : jobEmitters) {
|
||||
try {
|
||||
emitter.complete();
|
||||
} catch (Exception e) {
|
||||
log.debug("SSE complete failed for job {}", jobId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void removeEmitter(UUID jobId, SseEmitter emitter) {
|
||||
List<SseEmitter> jobEmitters = emitters.get(jobId);
|
||||
if (jobEmitters != null) {
|
||||
jobEmitters.remove(emitter);
|
||||
if (jobEmitters.isEmpty()) {
|
||||
emitters.remove(jobId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobRepository;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class OcrService {
|
||||
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final DocumentService documentService;
|
||||
private final OcrJobRepository ocrJobRepository;
|
||||
private final OcrJobDocumentRepository ocrJobDocumentRepository;
|
||||
private final OcrAsyncRunner ocrAsyncRunner;
|
||||
|
||||
public OcrJob getJob(UUID jobId) {
|
||||
return ocrJobRepository.findById(jobId)
|
||||
.orElseThrow(() -> DomainException.notFound(
|
||||
ErrorCode.OCR_JOB_NOT_FOUND, "OCR job not found: " + jobId));
|
||||
}
|
||||
|
||||
public OcrStatusDTO getDocumentOcrStatus(UUID documentId) {
|
||||
List<OcrDocumentStatus> activeStatuses = List.of(
|
||||
OcrDocumentStatus.PENDING, OcrDocumentStatus.RUNNING);
|
||||
|
||||
Optional<OcrJobDocument> activeJobDoc = ocrJobDocumentRepository
|
||||
.findFirstByDocumentIdAndStatusIn(documentId, activeStatuses);
|
||||
|
||||
if (activeJobDoc.isEmpty()) {
|
||||
return OcrStatusDTO.builder().status("NONE").build();
|
||||
}
|
||||
|
||||
OcrJobDocument jobDoc = activeJobDoc.get();
|
||||
return OcrStatusDTO.builder()
|
||||
.status(jobDoc.getStatus().name())
|
||||
.jobId(jobDoc.getJobId())
|
||||
.currentPage(jobDoc.getCurrentPage())
|
||||
.totalPages(jobDoc.getTotalPages())
|
||||
.build();
|
||||
}
|
||||
|
||||
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
|
||||
Document doc = documentService.getDocumentById(documentId);
|
||||
|
||||
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
|
||||
throw DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED,
|
||||
"Document has no file attached: " + documentId);
|
||||
}
|
||||
|
||||
if (!ocrHealthClient.isHealthy()) {
|
||||
throw DomainException.internal(ErrorCode.OCR_SERVICE_UNAVAILABLE,
|
||||
"OCR service is not available");
|
||||
}
|
||||
|
||||
if (scriptTypeOverride != null) {
|
||||
documentService.updateScriptType(documentId, scriptTypeOverride);
|
||||
}
|
||||
|
||||
OcrJob job = OcrJob.builder()
|
||||
.totalDocuments(1)
|
||||
.createdBy(userId)
|
||||
.status(OcrJobStatus.PENDING)
|
||||
.build();
|
||||
job = ocrJobRepository.save(job);
|
||||
|
||||
OcrJobDocument jobDoc = OcrJobDocument.builder()
|
||||
.jobId(job.getId())
|
||||
.documentId(documentId)
|
||||
.status(OcrDocumentStatus.PENDING)
|
||||
.build();
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
|
||||
return job.getId();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/** Events produced while streaming OCR results page-by-page (NDJSON protocol). */
public sealed interface OcrStreamEvent {

    /** Stream opened; announces the expected page count. */
    record Start(int totalPages) implements OcrStreamEvent {}

    /** One finished page with all blocks detected on it. */
    record Page(int pageNumber, List<OcrBlockResult> blocks) implements OcrStreamEvent {}

    /** A page that failed OCR; the parser continues with subsequent events. */
    record Error(int pageNumber, String message) implements OcrStreamEvent {}

    /** Stream finished; totals for emitted blocks and skipped pages. */
    record Done(int totalBlocks, int skippedPages) implements OcrStreamEvent {}
}
|
||||
@@ -0,0 +1,187 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.client.JdkClientHttpRequestFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
|
||||
private static final ObjectMapper NDJSON_MAPPER = new ObjectMapper()
|
||||
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, true);
|
||||
|
||||
private final RestClient restClient;
|
||||
private final HttpClient streamingHttpClient;
|
||||
private final String baseUrl;
|
||||
|
||||
public RestClientOcrClient(@Value("${app.ocr.base-url:http://ocr-service:8000}") String baseUrl) {
|
||||
this.baseUrl = baseUrl;
|
||||
|
||||
HttpClient httpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.connectTimeout(Duration.ofSeconds(10))
|
||||
.build();
|
||||
JdkClientHttpRequestFactory requestFactory = new JdkClientHttpRequestFactory(httpClient);
|
||||
requestFactory.setReadTimeout(Duration.ofMinutes(10));
|
||||
|
||||
this.restClient = RestClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.requestFactory(requestFactory)
|
||||
.build();
|
||||
|
||||
this.streamingHttpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.connectTimeout(Duration.ofSeconds(10))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
||||
Map<String, String> body = Map.of(
|
||||
"pdfUrl", pdfUrl,
|
||||
"scriptType", scriptType.name(),
|
||||
"language", "de");
|
||||
|
||||
List<OcrBlockJson> response = restClient.post()
|
||||
.uri("/ocr")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(new ParameterizedTypeReference<>() {});
|
||||
|
||||
if (response == null) return List.of();
|
||||
|
||||
return response.stream()
|
||||
.map(OcrBlockJson::toResult)
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isHealthy() {
|
||||
try {
|
||||
restClient.get()
|
||||
.uri("/health")
|
||||
.retrieve()
|
||||
.toBodilessEntity();
|
||||
return true;
|
||||
} catch (Exception e) {
|
||||
log.warn("OCR service health check failed: {}", e.getMessage());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
|
||||
String body;
|
||||
try {
|
||||
body = NDJSON_MAPPER.writeValueAsString(Map.of(
|
||||
"pdfUrl", pdfUrl,
|
||||
"scriptType", scriptType.name(),
|
||||
"language", "de"));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to serialize OCR request", e);
|
||||
}
|
||||
|
||||
HttpRequest request = HttpRequest.newBuilder()
|
||||
.uri(URI.create(baseUrl + "/ocr/stream"))
|
||||
.header("Content-Type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(body))
|
||||
.timeout(Duration.ofMinutes(5))
|
||||
.build();
|
||||
|
||||
try {
|
||||
HttpResponse<InputStream> response = streamingHttpClient.send(
|
||||
request, HttpResponse.BodyHandlers.ofInputStream());
|
||||
|
||||
if (response.statusCode() == 404) {
|
||||
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
|
||||
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
|
||||
return;
|
||||
}
|
||||
|
||||
try (InputStream inputStream = response.body()) {
|
||||
parseNdjsonStream(inputStream, handler);
|
||||
}
|
||||
} catch (IOException | InterruptedException e) {
|
||||
if (e instanceof InterruptedException) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
throw new RuntimeException("NDJSON stream failed: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
static void parseNdjsonStream(InputStream inputStream, Consumer<OcrStreamEvent> handler) {
|
||||
try (BufferedReader reader = new BufferedReader(
|
||||
new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.isBlank()) continue;
|
||||
|
||||
JsonNode node = NDJSON_MAPPER.readTree(line);
|
||||
String type = node.path("type").asText();
|
||||
|
||||
switch (type) {
|
||||
case "start" -> handler.accept(
|
||||
new OcrStreamEvent.Start(node.path("totalPages").asInt()));
|
||||
case "page" -> {
|
||||
int pageNumber = node.path("pageNumber").asInt();
|
||||
List<OcrBlockResult> blocks = NDJSON_MAPPER.convertValue(
|
||||
node.path("blocks"),
|
||||
new TypeReference<>() {});
|
||||
handler.accept(new OcrStreamEvent.Page(pageNumber, blocks));
|
||||
}
|
||||
case "error" -> handler.accept(
|
||||
new OcrStreamEvent.Error(
|
||||
node.path("pageNumber").asInt(),
|
||||
node.path("message").asText()));
|
||||
case "done" -> handler.accept(
|
||||
new OcrStreamEvent.Done(
|
||||
node.path("totalBlocks").asInt(),
|
||||
node.path("skippedPages").asInt()));
|
||||
default -> log.debug("Ignoring unknown NDJSON event type: {}", type);
|
||||
}
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to parse NDJSON stream: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
record OcrBlockJson(
|
||||
@JsonProperty("pageNumber") int pageNumber,
|
||||
double x,
|
||||
double y,
|
||||
double width,
|
||||
double height,
|
||||
List<List<Double>> polygon,
|
||||
String text
|
||||
) {
|
||||
OcrBlockResult toResult() {
|
||||
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.dto.ReorderTranscriptionBlocksDTO;
|
||||
import org.raddatz.familienarchiv.dto.UpdateTranscriptionBlockDTO;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.BlockSource;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
@@ -75,6 +76,24 @@ public class TranscriptionService {
|
||||
return saved;
|
||||
}
|
||||
|
||||
/**
 * Creates a transcription block from OCR output, linked to its annotation.
 * The text is sanitized first; the block is tagged with source = OCR and an
 * initial version entry is written (saveVersion) alongside the save.
 */
@Transactional
public TranscriptionBlock createOcrBlock(UUID documentId, UUID annotationId,
                                         String text, int sortOrder, UUID userId) {
    String sanitized = sanitizeText(text);
    TranscriptionBlock block = TranscriptionBlock.builder()
            .annotationId(annotationId)
            .documentId(documentId)
            .text(sanitized)
            .sortOrder(sortOrder)
            .source(BlockSource.OCR)
            .createdBy(userId)
            .updatedBy(userId)
            .build();
    TranscriptionBlock saved = blockRepository.save(block);
    saveVersion(saved, userId);
    return saved;
}
|
||||
|
||||
@Transactional
|
||||
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
|
||||
UpdateTranscriptionBlockDTO dto, UUID userId) {
|
||||
@@ -106,6 +125,21 @@ public class TranscriptionService {
|
||||
blockId, annotationId, documentId);
|
||||
}
|
||||
|
||||
/**
 * Bulk-deletes every transcription block of a document plus the annotations
 * those blocks point to. Used when OCR is re-run from scratch.
 */
@Transactional
public void deleteAllBlocksByDocument(UUID documentId) {
    List<TranscriptionBlock> blocks = blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId);
    if (blocks.isEmpty()) return;

    // Collect annotation ids before the blocks are gone.
    List<UUID> annotationIds = blocks.stream()
            .map(TranscriptionBlock::getAnnotationId)
            .toList();

    blockRepository.deleteAll(blocks);
    // Flush so block deletes hit the database before the annotation deletes —
    // presumably to satisfy a blocks→annotations FK; confirm against the schema.
    blockRepository.flush();
    annotationRepository.deleteAllById(annotationIds);
    log.info("Bulk-deleted {} transcription blocks for document {}", blocks.size(), documentId);
}
|
||||
|
||||
@Transactional
|
||||
public void reorderBlocks(UUID documentId, ReorderTranscriptionBlocksDTO dto) {
|
||||
List<UUID> blockIds = dto.getBlockIds();
|
||||
@@ -116,6 +150,13 @@ public class TranscriptionService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Toggles the reviewed flag of a block (review ↔ unreview) and persists it.
 * getBlock also serves as the document/block existence check.
 */
@Transactional
public TranscriptionBlock reviewBlock(UUID documentId, UUID blockId) {
    TranscriptionBlock block = getBlock(documentId, blockId);
    block.setReviewed(!block.isReviewed());
    return blockRepository.save(block);
}
|
||||
|
||||
public List<TranscriptionBlockVersion> getBlockHistory(UUID documentId, UUID blockId) {
|
||||
getBlock(documentId, blockId);
|
||||
return versionRepository.findByBlockIdOrderByChangedAtDesc(blockId);
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
-- Add optional polygon field for quadrilateral annotation shapes (Kraken OCR output).
-- See ADR-002 for the design decision.

ALTER TABLE document_annotations ADD COLUMN polygon JSONB;

-- A polygon, when present, must have exactly 4 points (a quadrilateral).
-- NULL remains valid: plain rectangle annotations carry no polygon.
ALTER TABLE document_annotations
    ADD CONSTRAINT chk_annotation_polygon_quad
    CHECK (polygon IS NULL OR jsonb_array_length(polygon) = 4);
|
||||
@@ -0,0 +1 @@
|
||||
-- Script type of a document (stored as the ScriptType enum name); passed to the
-- OCR service so it can choose an appropriate recognition model.
ALTER TABLE documents ADD COLUMN script_type VARCHAR(30) NOT NULL DEFAULT 'UNKNOWN';
|
||||
@@ -0,0 +1,26 @@
|
||||
-- One row per OCR run (single-document or batch). Counters are updated as the
-- async runner processes each document.
CREATE TABLE ocr_jobs (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    status VARCHAR(20) NOT NULL DEFAULT 'PENDING',
    total_documents INT NOT NULL,
    processed_documents INT NOT NULL DEFAULT 0,
    error_count INT NOT NULL DEFAULT 0,
    skipped_count INT NOT NULL DEFAULT 0,
    created_by UUID,  -- user who started the job (nullable)
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

-- One row per document inside a job; removed automatically with its job or document.
CREATE TABLE ocr_job_documents (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    job_id UUID NOT NULL REFERENCES ocr_jobs(id) ON DELETE CASCADE,
    document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
    status VARCHAR(20) NOT NULL DEFAULT 'PENDING',
    error_message TEXT,
    current_page INT DEFAULT 0,  -- per-document streaming progress
    total_pages INT DEFAULT 0,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

-- Lookup indexes for job progress pages and per-document status queries.
CREATE INDEX idx_ocr_job_documents_job_id ON ocr_job_documents(job_id);
CREATE INDEX idx_ocr_job_documents_document_id ON ocr_job_documents(document_id);
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Distinguish OCR-generated blocks from manually entered ones, and track
-- whether a block has been human-reviewed.
ALTER TABLE transcription_blocks ADD COLUMN source VARCHAR(10) NOT NULL DEFAULT 'MANUAL';
ALTER TABLE transcription_blocks ADD COLUMN reviewed BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
@@ -0,0 +1 @@
|
||||
-- Free-text progress note for a running job. NOTE(review): not referenced by the
-- visible service code — confirm where it is written/read.
ALTER TABLE ocr_jobs ADD COLUMN progress_message TEXT;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- A document may appear at most once within any single OCR job.
ALTER TABLE ocr_job_documents
    ADD CONSTRAINT uq_ocr_job_document UNIQUE (job_id, document_id);
|
||||
@@ -0,0 +1,135 @@
|
||||
package org.raddatz.familienarchiv.controller;
|
||||
|
||||
import tools.jackson.databind.ObjectMapper;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.config.SecurityConfig;
|
||||
import org.raddatz.familienarchiv.dto.BatchOcrDTO;
|
||||
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.security.PermissionAspect;
|
||||
import org.raddatz.familienarchiv.service.*;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.autoconfigure.aop.AopAutoConfiguration;
|
||||
import org.springframework.boot.webmvc.test.autoconfigure.WebMvcTest;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.security.test.context.support.WithMockUser;
|
||||
import org.springframework.test.context.bean.override.mockito.MockitoBean;
|
||||
import org.springframework.test.web.servlet.MockMvc;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
|
||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
|
||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
|
||||
|
||||
/**
 * Web-layer tests for OcrController: verifies HTTP status codes, JSON payloads,
 * and that security annotations admit the expected authorities. Services are
 * mocked; only the controller + security slice is loaded.
 *
 * NOTE(review): the file imports tools.jackson.databind.ObjectMapper (Jackson 3
 * package) — confirm this matches the Jackson major version on the test classpath.
 */
@WebMvcTest(OcrController.class)
@Import({SecurityConfig.class, PermissionAspect.class, AopAutoConfiguration.class})
class OcrControllerTest {

    @Autowired MockMvc mockMvc;
    private final ObjectMapper objectMapper = new ObjectMapper();

    @MockitoBean OcrService ocrService;
    @MockitoBean OcrBatchService ocrBatchService;
    @MockitoBean OcrProgressService ocrProgressService;
    @MockitoBean UserService userService;
    @MockitoBean CustomUserDetailsService customUserDetailsService;

    // Happy path: starting OCR returns 202 Accepted with the job id in the body.
    @Test
    @WithMockUser(authorities = "WRITE_ALL")
    void triggerOcr_returns202_withJobId() throws Exception {
        UUID docId = UUID.randomUUID();
        UUID jobId = UUID.randomUUID();
        TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);

        when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);

        mockMvc.perform(post("/api/documents/{id}/ocr", docId)
                .contentType(MediaType.APPLICATION_JSON)
                .content(objectMapper.writeValueAsString(dto)))
            .andExpect(status().isAccepted())
            .andExpect(jsonPath("$.jobId").value(jobId.toString()));
    }

    // A placeholder document (no file uploaded) is rejected with 400.
    @Test
    @WithMockUser(authorities = "WRITE_ALL")
    void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
        UUID docId = UUID.randomUUID();
        when(ocrService.startOcr(eq(docId), any(), any()))
                .thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));

        mockMvc.perform(post("/api/documents/{id}/ocr", docId)
                .contentType(MediaType.APPLICATION_JSON)
                .content("{}"))
            .andExpect(status().isBadRequest());
    }

    // Unknown job ids map to 404 via the domain exception handler.
    @Test
    @WithMockUser(authorities = "READ_ALL")
    void getJobStatus_returns404_whenJobNotFound() throws Exception {
        UUID jobId = UUID.randomUUID();
        when(ocrService.getJob(jobId))
                .thenThrow(DomainException.notFound(ErrorCode.OCR_JOB_NOT_FOUND, "OCR job not found"));

        mockMvc.perform(get("/api/ocr/jobs/{jobId}", jobId))
            .andExpect(status().isNotFound());
    }

    // Job status payload mirrors the OcrJob fields.
    @Test
    @WithMockUser(authorities = "READ_ALL")
    void getJobStatus_returnsJobInfo_whenFound() throws Exception {
        UUID jobId = UUID.randomUUID();
        OcrJob job = OcrJob.builder()
                .id(jobId)
                .status(OcrJobStatus.RUNNING)
                .totalDocuments(10)
                .processedDocuments(3)
                .errorCount(1)
                .skippedCount(0)
                .build();
        when(ocrService.getJob(jobId)).thenReturn(job);

        mockMvc.perform(get("/api/ocr/jobs/{jobId}", jobId))
            .andExpect(status().isOk())
            .andExpect(jsonPath("$.status").value("RUNNING"))
            .andExpect(jsonPath("$.totalDocuments").value(10))
            .andExpect(jsonPath("$.processedDocuments").value(3));
    }

    // Batch endpoint: admin authority starts a multi-document job, 202 + job id.
    @Test
    @WithMockUser(authorities = "ADMIN")
    void triggerBatch_returns202_withJobId() throws Exception {
        UUID jobId = UUID.randomUUID();
        List<UUID> docIds = List.of(UUID.randomUUID(), UUID.randomUUID());
        BatchOcrDTO dto = new BatchOcrDTO(docIds);

        when(ocrBatchService.startBatch(eq(docIds), any())).thenReturn(jobId);

        mockMvc.perform(post("/api/ocr/batch")
                .contentType(MediaType.APPLICATION_JSON)
                .content(objectMapper.writeValueAsString(dto)))
            .andExpect(status().isAccepted())
            .andExpect(jsonPath("$.jobId").value(jobId.toString()));
    }

    // Per-document status falls back to "NONE" when no active OCR job exists.
    @Test
    @WithMockUser(authorities = "READ_ALL")
    void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
        UUID docId = UUID.randomUUID();
        when(ocrService.getDocumentOcrStatus(docId))
                .thenReturn(OcrStatusDTO.builder().status("NONE").build());

        mockMvc.perform(get("/api/documents/{id}/ocr-status", docId))
            .andExpect(status().isOk())
            .andExpect(jsonPath("$.status").value("NONE"));
    }
}
|
||||
@@ -356,4 +356,20 @@ class TranscriptionBlockControllerTest {
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(jsonPath("$").isEmpty());
|
||||
}
|
||||
|
||||
// ─── PUT .../review ──────────────────────────────────────────────────────
|
||||
|
||||
// PUT .../review: the controller should surface the toggled block as JSON with
// the updated reviewed flag.
@Test
@WithMockUser(authorities = "WRITE_ALL")
void reviewBlock_returns200_withToggledBlock() throws Exception {
    TranscriptionBlock reviewed = TranscriptionBlock.builder()
            .id(BLOCK_ID).documentId(DOC_ID).annotationId(UUID.randomUUID())
            .text("text").sortOrder(0).reviewed(true).build();
    when(transcriptionService.reviewBlock(DOC_ID, BLOCK_ID)).thenReturn(reviewed);

    mockMvc.perform(put("/api/documents/{documentId}/transcription-blocks/{blockId}/review",
            DOC_ID, BLOCK_ID))
        .andExpect(status().isOk())
        .andExpect(jsonPath("$.reviewed").value(true));
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import jakarta.validation.ConstraintViolation;
|
||||
import jakarta.validation.Validation;
|
||||
import jakarta.validation.Validator;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class UniquePointsValidatorTest {
|
||||
|
||||
private final Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
|
||||
|
||||
// A null polygon is valid: plain rectangle annotations carry no polygon.
@Test
void shouldAcceptNull() {
    var dto = new CreateAnnotationDTO();
    dto.setPolygon(null);

    Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);

    assertThat(violations).noneMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
|
||||
|
||||
// Happy path: four distinct points inside [0, 1] pass validation.
@Test
void shouldAcceptFourUniquePoints() {
    var dto = new CreateAnnotationDTO();
    dto.setPolygon(List.of(
            List.of(0.1, 0.1),
            List.of(0.9, 0.1),
            List.of(0.9, 0.9),
            List.of(0.1, 0.9)));

    Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);

    assertThat(violations).noneMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
|
||||
|
||||
// Repeated vertices would degenerate the quadrilateral — must be rejected.
@Test
void shouldRejectDuplicatePoints() {
    var dto = new CreateAnnotationDTO();
    dto.setPolygon(List.of(
            List.of(0.1, 0.1),
            List.of(0.1, 0.1),
            List.of(0.9, 0.9),
            List.of(0.1, 0.9)));

    Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);

    assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
|
||||
|
||||
// Exactly four points are required — three is too few.
@Test
void shouldRejectPolygonWithThreePoints() {
    var dto = new CreateAnnotationDTO();
    dto.setPolygon(List.of(
            List.of(0.1, 0.1),
            List.of(0.9, 0.1),
            List.of(0.9, 0.9)));

    Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);

    assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().equals("polygon"));
}
|
||||
|
||||
@Test
|
||||
void shouldRejectPolygonWithFivePoints() {
|
||||
var dto = new CreateAnnotationDTO();
|
||||
dto.setPolygon(List.of(
|
||||
List.of(0.1, 0.1),
|
||||
List.of(0.5, 0.1),
|
||||
List.of(0.9, 0.1),
|
||||
List.of(0.9, 0.9),
|
||||
List.of(0.1, 0.9)));
|
||||
|
||||
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
|
||||
|
||||
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().equals("polygon"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldRejectCoordinateOutOfRange() {
|
||||
var dto = new CreateAnnotationDTO();
|
||||
dto.setPolygon(List.of(
|
||||
List.of(1.5, 0.1),
|
||||
List.of(0.9, 0.1),
|
||||
List.of(0.9, 0.9),
|
||||
List.of(0.1, 0.9)));
|
||||
|
||||
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
|
||||
|
||||
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().contains("polygon"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldRejectNegativeCoordinate() {
|
||||
var dto = new CreateAnnotationDTO();
|
||||
dto.setPolygon(List.of(
|
||||
List.of(-0.1, 0.1),
|
||||
List.of(0.9, 0.1),
|
||||
List.of(0.9, 0.9),
|
||||
List.of(0.1, 0.9)));
|
||||
|
||||
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
|
||||
|
||||
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().contains("polygon"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldRejectPointWithOneCoordinate() {
|
||||
var dto = new CreateAnnotationDTO();
|
||||
dto.setPolygon(List.of(
|
||||
List.of(0.1),
|
||||
List.of(0.9, 0.1),
|
||||
List.of(0.9, 0.9),
|
||||
List.of(0.1, 0.9)));
|
||||
|
||||
Set<ConstraintViolation<CreateAnnotationDTO>> violations = validator.validate(dto);
|
||||
|
||||
assertThat(violations).anyMatch(v -> v.getPropertyPath().toString().contains("polygon"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class PolygonConverterTest {
|
||||
|
||||
private final PolygonConverter converter = new PolygonConverter();
|
||||
|
||||
@Test
|
||||
void convertToDatabaseColumn_returnsNull_whenPolygonIsNull() {
|
||||
assertThat(converter.convertToDatabaseColumn(null)).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void convertToDatabaseColumn_returnsJsonArray_whenPolygonIsValid() {
|
||||
List<List<Double>> polygon = List.of(
|
||||
List.of(0.1, 0.2),
|
||||
List.of(0.9, 0.2),
|
||||
List.of(0.9, 0.8),
|
||||
List.of(0.1, 0.8));
|
||||
|
||||
String json = converter.convertToDatabaseColumn(polygon);
|
||||
|
||||
assertThat(json).isEqualTo("[[0.1,0.2],[0.9,0.2],[0.9,0.8],[0.1,0.8]]");
|
||||
}
|
||||
|
||||
@Test
|
||||
void convertToEntityAttribute_returnsNull_whenJsonIsNull() {
|
||||
assertThat(converter.convertToEntityAttribute(null)).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void convertToEntityAttribute_returnsNull_whenJsonIsEmpty() {
|
||||
assertThat(converter.convertToEntityAttribute("")).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void convertToEntityAttribute_returnsPolygon_whenJsonIsValid() {
|
||||
String json = "[[0.1,0.2],[0.9,0.2],[0.9,0.8],[0.1,0.8]]";
|
||||
|
||||
List<List<Double>> polygon = converter.convertToEntityAttribute(json);
|
||||
|
||||
assertThat(polygon).hasSize(4);
|
||||
assertThat(polygon.get(0)).containsExactly(0.1, 0.2);
|
||||
assertThat(polygon.get(3)).containsExactly(0.1, 0.8);
|
||||
}
|
||||
|
||||
@Test
|
||||
void roundTrip_preservesValues() {
|
||||
List<List<Double>> original = List.of(
|
||||
List.of(0.12, 0.08),
|
||||
List.of(0.88, 0.09),
|
||||
List.of(0.87, 0.14),
|
||||
List.of(0.11, 0.13));
|
||||
|
||||
String json = converter.convertToDatabaseColumn(original);
|
||||
List<List<Double>> restored = converter.convertToEntityAttribute(json);
|
||||
|
||||
assertThat(restored).isEqualTo(original);
|
||||
}
|
||||
}
|
||||
@@ -260,6 +260,55 @@ class AnnotationServiceTest {
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
// ─── createOcrAnnotation ──────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_skipsOverlapCheck_andSavesWithPolygon() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
|
||||
List<List<Double>> polygon = List.of(
|
||||
List.of(0.1, 0.1), List.of(0.9, 0.11),
|
||||
List.of(0.89, 0.14), List.of(0.11, 0.13));
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createOcrAnnotation(
|
||||
docId, dto, userId, "filehash", polygon);
|
||||
|
||||
assertThat(result.getPolygon()).isEqualTo(polygon);
|
||||
assertThat(result.getDocumentId()).isEqualTo(docId);
|
||||
verify(annotationRepository).save(any());
|
||||
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_savesWithNullPolygon_whenPolygonNotProvided() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.8, 0.04, "#00C7B1");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
DocumentAnnotation result = annotationService.createOcrAnnotation(
|
||||
docId, dto, userId, "filehash", null);
|
||||
|
||||
assertThat(result.getPolygon()).isNull();
|
||||
verify(annotationRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createOcrAnnotation_doesNotCheckOverlap_evenWhenOverlappingAnnotationExists() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
CreateAnnotationDTO dto = new CreateAnnotationDTO(1, 0.1, 0.1, 0.3, 0.3, "#00C7B1");
|
||||
when(annotationRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
annotationService.createOcrAnnotation(docId, dto, userId, "hash", null);
|
||||
|
||||
verify(annotationRepository, never()).findByDocumentIdAndPageNumber(any(), any(int.class));
|
||||
}
|
||||
|
||||
// ─── overlaps — partial overlap cases ────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createAnnotation_noConflict_whenAnnotationIsAbove() {
|
||||
// x ranges overlap, y ranges don't — existing is ABOVE the new annotation
|
||||
|
||||
@@ -32,7 +32,7 @@ class FileServiceTest {
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
s3Client = mock(S3Client.class);
|
||||
fileService = new FileService(s3Client, "test-bucket");
|
||||
fileService = new FileService(s3Client, null, "test-bucket");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.raddatz.familienarchiv.dto.CreateAnnotationDTO;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobRepository;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
/**
 * Unit tests for {@code OcrAsyncRunner}, the asynchronous OCR worker.
 *
 * <p>Two entry points are exercised:
 * <ul>
 *   <li>{@code processDocument} — clears previous transcription blocks, then
 *       creates one annotation + one transcription block per OCR result;</li>
 *   <li>{@code runSingleDocument} — job-level orchestration around a streaming
 *       OCR client: status transitions (DONE/FAILED), per-page progress
 *       messages, and handling of per-page stream errors.</li>
 * </ul>
 *
 * <p>Streaming behaviour is simulated with Mockito {@code doAnswer} stubs that
 * feed a scripted sequence of {@code OcrStreamEvent}s into the handler the
 * runner passes to {@code streamBlocks}.
 */
@ExtendWith(MockitoExtension.class)
class OcrAsyncRunnerTest {

    // Collaborators of OcrAsyncRunner, all mocked for isolation.
    @Mock OcrClient ocrClient;
    @Mock DocumentService documentService;
    @Mock TranscriptionService transcriptionService;
    @Mock AnnotationService annotationService;
    @Mock FileService fileService;
    @Mock OcrJobRepository ocrJobRepository;
    @Mock OcrJobDocumentRepository ocrJobDocumentRepository;
    @Mock OcrProgressService ocrProgressService;

    // System under test with the mocks above injected.
    @InjectMocks OcrAsyncRunner ocrAsyncRunner;

    /** Re-running OCR must first delete any transcription blocks from a previous run. */
    @Test
    void processDocument_clearsExistingBlocks() {
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
        when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of());

        ocrAsyncRunner.processDocument(docId, doc, userId);

        verify(transcriptionService).deleteAllBlocksByDocument(docId);
    }

    /** Each OCR result yields exactly one annotation and one transcription block. */
    @Test
    void processDocument_createsAnnotationAndBlock_forEachResult() {
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        UUID annId = UUID.randomUUID();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
        // Two OCR results on page 0.
        when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
                new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
                new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
        DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
        when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);

        ocrAsyncRunner.processDocument(docId, doc, userId);

        // One annotation (with the document's file hash) and one block per result.
        verify(annotationService, times(2)).createOcrAnnotation(
                eq(docId), any(CreateAnnotationDTO.class), eq(userId), eq("hash"), any());
        verify(transcriptionService, times(2)).createOcrBlock(
                eq(docId), eq(annId), any(), anyInt(), eq(userId));
    }

    /** The recognized text, page number, and annotation id are passed through verbatim. */
    @Test
    void processDocument_delegatesBlockCreationToTranscriptionService() {
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        UUID annId = UUID.randomUUID();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
        when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
                new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
        DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
        when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);

        ocrAsyncRunner.processDocument(docId, doc, userId);

        verify(transcriptionService).createOcrBlock(docId, annId, "Test", 0, userId);
    }

    /** A clean Start/Page/Done stream moves the job from PENDING to DONE. */
    @Test
    void runSingleDocument_setsJobDone_onSuccess() {
        UUID jobId = UUID.randomUUID();
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
        OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
                .jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

        when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
        when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
                .thenReturn(Optional.of(jobDoc));
        when(documentService.getDocumentById(docId)).thenReturn(doc);

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
        // Script a minimal successful stream: one page, no blocks, no skips.
        doAnswer(inv -> {
            Consumer<OcrStreamEvent> handler = inv.getArgument(2);
            handler.accept(new OcrStreamEvent.Start(1));
            handler.accept(new OcrStreamEvent.Page(0, List.of()));
            handler.accept(new OcrStreamEvent.Done(0, 0));
            return null;
        }).when(ocrClient).streamBlocks(any(), any(), any());

        ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

        assertThat(job.getStatus()).isEqualTo(OcrJobStatus.DONE);
    }

    /** An exception thrown by the OCR client fails the job and bumps its error count. */
    @Test
    void runSingleDocument_setsJobFailed_onError() {
        UUID jobId = UUID.randomUUID();
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
        OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
                .jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

        when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
        when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
                .thenReturn(Optional.of(jobDoc));
        when(documentService.getDocumentById(docId)).thenReturn(doc);

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
        doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());

        ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

        assertThat(job.getStatus()).isEqualTo(OcrJobStatus.FAILED);
        assertThat(job.getErrorCount()).isEqualTo(1);
    }

    /**
     * Progress messages are updated after every Page event. The observed format is
     * "ANALYZING_PAGE:&lt;pagesProcessed&gt;:&lt;totalPages&gt;:&lt;blocksSoFar&gt;"
     * and "DONE:&lt;totalBlocks&gt;:&lt;skippedPages&gt;" at the end.
     * Messages are sampled inside the doAnswer stub, i.e. mid-stream.
     */
    @Test
    void runSingleDocument_updatesProgressPerPage() {
        UUID jobId = UUID.randomUUID();
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
        OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
                .jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
        DocumentAnnotation ann = DocumentAnnotation.builder().id(UUID.randomUUID()).build();

        when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
        when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
                .thenReturn(Optional.of(jobDoc));
        when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(documentService.getDocumentById(docId)).thenReturn(doc);

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
        when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);

        List<String> progressMessages = new ArrayList<>();
        doAnswer(inv -> {
            Consumer<OcrStreamEvent> handler = inv.getArgument(2);
            handler.accept(new OcrStreamEvent.Start(3));
            // Page 0 carries two blocks; capture the message right after it.
            handler.accept(new OcrStreamEvent.Page(0, List.of(
                    new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
                    new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
            progressMessages.add(job.getProgressMessage());
            // Page 1 carries one more block.
            handler.accept(new OcrStreamEvent.Page(1, List.of(
                    new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
            progressMessages.add(job.getProgressMessage());
            handler.accept(new OcrStreamEvent.Done(3, 0));
            return null;
        }).when(ocrClient).streamBlocks(any(), any(), any());

        ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

        assertThat(progressMessages.get(0)).isEqualTo("ANALYZING_PAGE:1:3:2");
        assertThat(progressMessages.get(1)).isEqualTo("ANALYZING_PAGE:2:3:3");
        assertThat(job.getProgressMessage()).isEqualTo("DONE:3:0");
    }

    /** A per-page Error event is counted as a skipped page but still completes the job. */
    @Test
    void runSingleDocument_includesSkippedPagesInDoneMessage() {
        UUID jobId = UUID.randomUUID();
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
        OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
                .jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

        when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
        when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
                .thenReturn(Optional.of(jobDoc));
        when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(documentService.getDocumentById(docId)).thenReturn(doc);

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");

        // Page 1 errors out; pages 0 and 2 succeed (empty).
        doAnswer(inv -> {
            Consumer<OcrStreamEvent> handler = inv.getArgument(2);
            handler.accept(new OcrStreamEvent.Start(3));
            handler.accept(new OcrStreamEvent.Page(0, List.of()));
            handler.accept(new OcrStreamEvent.Error(1, "failed"));
            handler.accept(new OcrStreamEvent.Page(2, List.of()));
            handler.accept(new OcrStreamEvent.Done(0, 1));
            return null;
        }).when(ocrClient).streamBlocks(any(), any(), any());

        ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

        assertThat(job.getStatus()).isEqualTo(OcrJobStatus.DONE);
        assertThat(job.getProgressMessage()).isEqualTo("DONE:0:1");
    }

    /** Per-page stream errors neither fail the job nor leak raw backend details to progress. */
    @Test
    void runSingleDocument_logsStreamErrorAtWarnWithoutSettingJobFailed() {
        UUID jobId = UUID.randomUUID();
        UUID docId = UUID.randomUUID();
        UUID userId = UUID.randomUUID();
        OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
        OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
                .jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
        Document doc = Document.builder().id(docId).filePath("test.pdf")
                .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

        when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
        when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
                .thenReturn(Optional.of(jobDoc));
        when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(documentService.getDocumentById(docId)).thenReturn(doc);

        when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");

        doAnswer(inv -> {
            Consumer<OcrStreamEvent> handler = inv.getArgument(2);
            handler.accept(new OcrStreamEvent.Start(2));
            handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
            handler.accept(new OcrStreamEvent.Page(1, List.of()));
            handler.accept(new OcrStreamEvent.Done(0, 1));
            return null;
        }).when(ocrClient).streamBlocks(any(), any(), any());

        ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

        // Job should still be DONE, not FAILED (per-page errors don't fail the whole job)
        assertThat(job.getStatus()).isEqualTo(OcrJobStatus.DONE);
        // Raw error message should not leak to progress
        assertThat(job.getProgressMessage()).doesNotContain("python traceback");
    }
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobRepository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class OcrBatchServiceTest {
|
||||
|
||||
@Mock OcrHealthClient ocrHealthClient;
|
||||
@Mock OcrJobRepository ocrJobRepository;
|
||||
@Mock OcrJobDocumentRepository ocrJobDocumentRepository;
|
||||
@Mock OcrAsyncRunner ocrAsyncRunner;
|
||||
|
||||
@InjectMocks OcrBatchService ocrBatchService;
|
||||
|
||||
@Test
|
||||
void startBatch_throwsServiceUnavailable_whenOcrServiceIsDown() {
|
||||
when(ocrHealthClient.isHealthy()).thenReturn(false);
|
||||
|
||||
assertThatThrownBy(() -> ocrBatchService.startBatch(List.of(UUID.randomUUID()), UUID.randomUUID()))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.satisfies(e -> assertThat(((DomainException) e).getCode())
|
||||
.isEqualTo(ErrorCode.OCR_SERVICE_UNAVAILABLE));
|
||||
}
|
||||
|
||||
@Test
|
||||
void startBatch_createsJobAndDispatchesAsync() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
UUID jobId = UUID.randomUUID();
|
||||
|
||||
when(ocrHealthClient.isHealthy()).thenReturn(true);
|
||||
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrJob job = inv.getArgument(0);
|
||||
job.setId(jobId);
|
||||
return job;
|
||||
});
|
||||
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
UUID result = ocrBatchService.startBatch(List.of(docId), userId);
|
||||
|
||||
assertThat(result).isEqualTo(jobId);
|
||||
verify(ocrAsyncRunner).runBatch(jobId, userId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void startBatch_createsJobDocumentForEachId() {
|
||||
UUID doc1 = UUID.randomUUID();
|
||||
UUID doc2 = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
|
||||
when(ocrHealthClient.isHealthy()).thenReturn(true);
|
||||
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrJob job = inv.getArgument(0);
|
||||
job.setId(UUID.randomUUID());
|
||||
return job;
|
||||
});
|
||||
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
ocrBatchService.startBatch(List.of(doc1, doc2), userId);
|
||||
|
||||
verify(ocrJobDocumentRepository, times(2)).save(any());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class OcrClientDefaultStreamTest {
|
||||
|
||||
@Test
|
||||
void defaultStreamBlocksSynthesizesEventsFromExtractBlocks() {
|
||||
OcrClient client = (pdfUrl, scriptType) -> List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
|
||||
assertThat(events).hasSize(4);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
assertThat(((OcrStreamEvent.Start) events.get(0)).totalPages()).isEqualTo(2);
|
||||
|
||||
assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Page.class);
|
||||
var page0 = (OcrStreamEvent.Page) events.get(1);
|
||||
assertThat(page0.pageNumber()).isEqualTo(0);
|
||||
assertThat(page0.blocks()).hasSize(2);
|
||||
|
||||
assertThat(events.get(2)).isInstanceOf(OcrStreamEvent.Page.class);
|
||||
var page1 = (OcrStreamEvent.Page) events.get(2);
|
||||
assertThat(page1.pageNumber()).isEqualTo(1);
|
||||
assertThat(page1.blocks()).hasSize(1);
|
||||
|
||||
assertThat(events.get(3)).isInstanceOf(OcrStreamEvent.Done.class);
|
||||
var done = (OcrStreamEvent.Done) events.get(3);
|
||||
assertThat(done.totalBlocks()).isEqualTo(3);
|
||||
assertThat(done.skippedPages()).isEqualTo(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
void defaultStreamBlocksHandlesEmptyResults() {
|
||||
OcrClient client = (pdfUrl, scriptType) -> List.of();
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
|
||||
assertThat(events).hasSize(2);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
assertThat(((OcrStreamEvent.Start) events.get(0)).totalPages()).isEqualTo(0);
|
||||
assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Done.class);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
|
||||
class OcrProgressServiceTest {
|
||||
|
||||
private final OcrProgressService progressService = new OcrProgressService();
|
||||
|
||||
@Test
|
||||
void register_returnsNonNullEmitter() {
|
||||
UUID jobId = UUID.randomUUID();
|
||||
SseEmitter emitter = progressService.register(jobId);
|
||||
assertThat(emitter).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void emit_doesNotThrow_whenNoEmittersRegistered() {
|
||||
assertThatCode(() -> progressService.emit(UUID.randomUUID(), "test", "data"))
|
||||
.doesNotThrowAnyException();
|
||||
}
|
||||
|
||||
@Test
|
||||
void complete_doesNotThrow_whenNoEmittersRegistered() {
|
||||
assertThatCode(() -> progressService.complete(UUID.randomUUID()))
|
||||
.doesNotThrowAnyException();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.OcrJobRepository;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.springframework.http.HttpStatus.*;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class OcrServiceTest {
|
||||
|
||||
// Collaborators of OcrService, mocked for isolation.
@Mock OcrHealthClient ocrHealthClient;
@Mock DocumentService documentService;
@Mock OcrJobRepository ocrJobRepository;
@Mock OcrJobDocumentRepository ocrJobDocumentRepository;
@Mock OcrAsyncRunner ocrAsyncRunner;

// System under test with the mocks above injected.
@InjectMocks OcrService ocrService;
|
||||
|
||||
// ─── getJob ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void getJob_returnsJob_whenFound() {
|
||||
UUID jobId = UUID.randomUUID();
|
||||
OcrJob job = OcrJob.builder().id(jobId).status(OcrJobStatus.RUNNING).build();
|
||||
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
|
||||
|
||||
OcrJob result = ocrService.getJob(jobId);
|
||||
|
||||
assertThat(result).isEqualTo(job);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getJob_throwsNotFound_whenJobDoesNotExist() {
|
||||
UUID jobId = UUID.randomUUID();
|
||||
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.empty());
|
||||
|
||||
assertThatThrownBy(() -> ocrService.getJob(jobId))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.satisfies(e -> {
|
||||
DomainException de = (DomainException) e;
|
||||
assertThat(de.getStatus()).isEqualTo(NOT_FOUND);
|
||||
assertThat(de.getCode()).isEqualTo(ErrorCode.OCR_JOB_NOT_FOUND);
|
||||
});
|
||||
}
|
||||
|
||||
// ─── getDocumentOcrStatus ───────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void getDocumentOcrStatus_returnsNone_whenNoActiveJob() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
when(ocrJobDocumentRepository.findFirstByDocumentIdAndStatusIn(any(), any()))
|
||||
.thenReturn(Optional.empty());
|
||||
|
||||
OcrStatusDTO result = ocrService.getDocumentOcrStatus(docId);
|
||||
|
||||
assertThat(result.getStatus()).isEqualTo("NONE");
|
||||
assertThat(result.getJobId()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void getDocumentOcrStatus_returnsActiveStatus_whenJobExists() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID jobId = UUID.randomUUID();
|
||||
OcrJobDocument jobDoc = OcrJobDocument.builder()
|
||||
.jobId(jobId).documentId(docId)
|
||||
.status(OcrDocumentStatus.RUNNING)
|
||||
.currentPage(2).totalPages(5)
|
||||
.build();
|
||||
when(ocrJobDocumentRepository.findFirstByDocumentIdAndStatusIn(any(), any()))
|
||||
.thenReturn(Optional.of(jobDoc));
|
||||
|
||||
OcrStatusDTO result = ocrService.getDocumentOcrStatus(docId);
|
||||
|
||||
assertThat(result.getStatus()).isEqualTo("RUNNING");
|
||||
assertThat(result.getJobId()).isEqualTo(jobId);
|
||||
assertThat(result.getCurrentPage()).isEqualTo(2);
|
||||
assertThat(result.getTotalPages()).isEqualTo(5);
|
||||
}
|
||||
|
||||
// ─── startOcr ───────────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void startOcr_throwsBadRequest_whenDocumentIsPlaceholder() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
Document doc = Document.builder().id(docId).status(DocumentStatus.PLACEHOLDER).build();
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
|
||||
assertThatThrownBy(() -> ocrService.startOcr(docId, null, UUID.randomUUID()))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.satisfies(e -> {
|
||||
DomainException de = (DomainException) e;
|
||||
assertThat(de.getStatus()).isEqualTo(BAD_REQUEST);
|
||||
assertThat(de.getCode()).isEqualTo(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED);
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void startOcr_throwsServiceUnavailable_whenOcrServiceIsDown() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED)
|
||||
.filePath("test.pdf").build();
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
when(ocrHealthClient.isHealthy()).thenReturn(false);
|
||||
|
||||
assertThatThrownBy(() -> ocrService.startOcr(docId, null, UUID.randomUUID()))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.satisfies(e -> assertThat(((DomainException) e).getCode())
|
||||
.isEqualTo(ErrorCode.OCR_SERVICE_UNAVAILABLE));
|
||||
}
|
||||
|
||||
@Test
|
||||
void startOcr_createsJobAndDispatchesAsync() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
UUID jobId = UUID.randomUUID();
|
||||
Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED)
|
||||
.filePath("test.pdf").scriptType(ScriptType.TYPEWRITER).build();
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
when(ocrHealthClient.isHealthy()).thenReturn(true);
|
||||
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrJob job = inv.getArgument(0);
|
||||
job.setId(jobId);
|
||||
return job;
|
||||
});
|
||||
|
||||
UUID result = ocrService.startOcr(docId, null, userId);
|
||||
|
||||
assertThat(result).isEqualTo(jobId);
|
||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void startOcr_updatesScriptType_whenProvided() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED)
|
||||
.filePath("test.pdf").scriptType(ScriptType.UNKNOWN).build();
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
when(ocrHealthClient.isHealthy()).thenReturn(true);
|
||||
when(ocrJobRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrJob job = inv.getArgument(0);
|
||||
job.setId(UUID.randomUUID());
|
||||
return job;
|
||||
});
|
||||
|
||||
ocrService.startOcr(docId, ScriptType.HANDWRITING_LATIN, UUID.randomUUID());
|
||||
|
||||
verify(documentService).updateScriptType(docId, ScriptType.HANDWRITING_LATIN);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class OcrStreamEventTest {
|
||||
|
||||
@Test
|
||||
void startRecordHoldsTotalPages() {
|
||||
var start = new OcrStreamEvent.Start(5);
|
||||
assertThat(start.totalPages()).isEqualTo(5);
|
||||
assertThat(start).isInstanceOf(OcrStreamEvent.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void pageRecordHoldsBlocksAndPageNumber() {
|
||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
|
||||
var page = new OcrStreamEvent.Page(0, List.of(block));
|
||||
assertThat(page.pageNumber()).isEqualTo(0);
|
||||
assertThat(page.blocks()).hasSize(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void errorRecordHoldsPageAndMessage() {
|
||||
var error = new OcrStreamEvent.Error(2, "failed");
|
||||
assertThat(error.pageNumber()).isEqualTo(2);
|
||||
assertThat(error.message()).isEqualTo("failed");
|
||||
}
|
||||
|
||||
@Test
|
||||
void doneRecordHoldsTotalBlocksAndSkippedPages() {
|
||||
var done = new OcrStreamEvent.Done(12, 2);
|
||||
assertThat(done.totalBlocks()).isEqualTo(12);
|
||||
assertThat(done.skippedPages()).isEqualTo(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void patternMatchingWorksOnSealedInterface() {
|
||||
OcrStreamEvent event = new OcrStreamEvent.Start(3);
|
||||
String result = switch (event) {
|
||||
case OcrStreamEvent.Start s -> "start:" + s.totalPages();
|
||||
case OcrStreamEvent.Page p -> "page:" + p.pageNumber();
|
||||
case OcrStreamEvent.Error e -> "error:" + e.pageNumber();
|
||||
case OcrStreamEvent.Done d -> "done:" + d.totalBlocks();
|
||||
};
|
||||
assertThat(result).isEqualTo("start:3");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
 * Tests for {@code RestClientOcrClient.parseNdjsonStream}: parsing the
 * newline-delimited JSON (NDJSON) event stream from the OCR service into
 * {@code OcrStreamEvent} values, dispatched to a consumer in stream order.
 */
class RestClientOcrClientStreamTest {

    @Test
    void parseNdjsonStream_dispatchesStartPageDoneInOrder() {
        // Happy path: one JSON object per line — start, page 0, page 1, done.
        String ndjson = """
                {"type":"start","totalPages":2}
                {"type":"page","pageNumber":0,"blocks":[{"pageNumber":0,"x":0.1,"y":0.2,"width":0.8,"height":0.1,"polygon":null,"text":"Line 1"}]}
                {"type":"page","pageNumber":1,"blocks":[{"pageNumber":1,"x":0.1,"y":0.3,"width":0.8,"height":0.1,"polygon":null,"text":"Line 2"}]}
                {"type":"done","totalBlocks":2,"skippedPages":0}
                """;
        InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));

        // Collect events in dispatch order so ordering can be asserted.
        List<OcrStreamEvent> events = new ArrayList<>();
        RestClientOcrClient.parseNdjsonStream(stream, events::add);

        assertThat(events).hasSize(4);
        assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
        assertThat(((OcrStreamEvent.Start) events.get(0)).totalPages()).isEqualTo(2);

        assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Page.class);
        var page0 = (OcrStreamEvent.Page) events.get(1);
        assertThat(page0.pageNumber()).isEqualTo(0);
        assertThat(page0.blocks()).hasSize(1);
        assertThat(page0.blocks().get(0).text()).isEqualTo("Line 1");

        assertThat(events.get(2)).isInstanceOf(OcrStreamEvent.Page.class);
        var page1 = (OcrStreamEvent.Page) events.get(2);
        assertThat(page1.pageNumber()).isEqualTo(1);

        assertThat(events.get(3)).isInstanceOf(OcrStreamEvent.Done.class);
        var done = (OcrStreamEvent.Done) events.get(3);
        assertThat(done.totalBlocks()).isEqualTo(2);
        assertThat(done.skippedPages()).isEqualTo(0);
    }

    @Test
    void parseNdjsonStream_parsesErrorEvents() {
        // An error line mid-stream is parsed as an Error event; parsing continues after it.
        String ndjson = """
                {"type":"start","totalPages":3}
                {"type":"page","pageNumber":0,"blocks":[]}
                {"type":"error","pageNumber":1,"message":"OCR processing failed on page 1"}
                {"type":"page","pageNumber":2,"blocks":[]}
                {"type":"done","totalBlocks":0,"skippedPages":1}
                """;
        InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));

        List<OcrStreamEvent> events = new ArrayList<>();
        RestClientOcrClient.parseNdjsonStream(stream, events::add);

        assertThat(events).hasSize(5);
        assertThat(events.get(2)).isInstanceOf(OcrStreamEvent.Error.class);
        var error = (OcrStreamEvent.Error) events.get(2);
        assertThat(error.pageNumber()).isEqualTo(1);
        assertThat(error.message()).contains("OCR processing failed");
    }

    @Test
    void parseNdjsonStream_skipsBlankLines() {
        // Blank lines between events are tolerated and produce no events.
        String ndjson = """
                {"type":"start","totalPages":1}

                {"type":"page","pageNumber":0,"blocks":[]}

                {"type":"done","totalBlocks":0,"skippedPages":0}
                """;
        InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));

        List<OcrStreamEvent> events = new ArrayList<>();
        RestClientOcrClient.parseNdjsonStream(stream, events::add);

        assertThat(events).hasSize(3);
    }

    @Test
    void parseNdjsonStream_ignoresUnknownEventTypes() {
        // Forward compatibility: unrecognized "type" values are dropped, not errors.
        String ndjson = """
                {"type":"start","totalPages":1}
                {"type":"unknown","foo":"bar"}
                {"type":"done","totalBlocks":0,"skippedPages":0}
                """;
        InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));

        List<OcrStreamEvent> events = new ArrayList<>();
        RestClientOcrClient.parseNdjsonStream(stream, events::add);

        assertThat(events).hasSize(2);
    }

    @Test
    void parseNdjsonStream_handlesUnknownFieldsInBlocks() {
        // Extra fields inside a block object (e.g. future "confidence") must not break parsing.
        String ndjson = """
                {"type":"start","totalPages":1}
                {"type":"page","pageNumber":0,"blocks":[{"pageNumber":0,"x":0.1,"y":0.2,"width":0.8,"height":0.1,"polygon":null,"text":"Line 1","confidence":0.95,"newFutureField":"ignored"}]}
                {"type":"done","totalBlocks":1,"skippedPages":0}
                """;
        InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));

        List<OcrStreamEvent> events = new ArrayList<>();
        RestClientOcrClient.parseNdjsonStream(stream, events::add);

        assertThat(events).hasSize(3);
        var page = (OcrStreamEvent.Page) events.get(1);
        assertThat(page.blocks().get(0).text()).isEqualTo("Line 1");
    }

    @Test
    void parseNdjsonStream_parsesPageWithPolygon() {
        // A non-null polygon (4-point quadrilateral) round-trips into the block result.
        String ndjson = """
                {"type":"start","totalPages":1}
                {"type":"page","pageNumber":0,"blocks":[{"pageNumber":0,"x":0.1,"y":0.2,"width":0.8,"height":0.1,"polygon":[[0.1,0.2],[0.9,0.2],[0.9,0.3],[0.1,0.3]],"text":"With polygon"}]}
                {"type":"done","totalBlocks":1,"skippedPages":0}
                """;
        InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));

        List<OcrStreamEvent> events = new ArrayList<>();
        RestClientOcrClient.parseNdjsonStream(stream, events::add);

        var page = (OcrStreamEvent.Page) events.get(1);
        assertThat(page.blocks().get(0).polygon()).hasSize(4);
        assertThat(page.blocks().get(0).text()).isEqualTo("With polygon");
    }
}
|
||||
@@ -10,6 +10,7 @@ import org.raddatz.familienarchiv.dto.CreateTranscriptionBlockDTO;
|
||||
import org.raddatz.familienarchiv.dto.ReorderTranscriptionBlocksDTO;
|
||||
import org.raddatz.familienarchiv.dto.UpdateTranscriptionBlockDTO;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.BlockSource;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
@@ -26,8 +27,8 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
import static org.springframework.http.HttpStatus.NOT_FOUND;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
@@ -99,6 +100,50 @@ class TranscriptionServiceTest {
|
||||
verify(versionRepository).save(any(TranscriptionBlockVersion.class));
|
||||
}
|
||||
|
||||
// ─── createOcrBlock ──────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void createOcrBlock_createsBlockWithOcrSourceAndSavesVersion() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID annotId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> {
|
||||
TranscriptionBlock b = inv.getArgument(0);
|
||||
b.setId(UUID.randomUUID());
|
||||
return b;
|
||||
});
|
||||
|
||||
TranscriptionBlock result = transcriptionService.createOcrBlock(
|
||||
docId, annotId, "OCR text", 3, userId);
|
||||
|
||||
assertThat(result.getAnnotationId()).isEqualTo(annotId);
|
||||
assertThat(result.getDocumentId()).isEqualTo(docId);
|
||||
assertThat(result.getText()).isEqualTo("OCR text");
|
||||
assertThat(result.getSortOrder()).isEqualTo(3);
|
||||
assertThat(result.getSource()).isEqualTo(BlockSource.OCR);
|
||||
assertThat(result.getCreatedBy()).isEqualTo(userId);
|
||||
verify(versionRepository).save(any(TranscriptionBlockVersion.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void createOcrBlock_sanitizesNullText() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID annotId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> {
|
||||
TranscriptionBlock b = inv.getArgument(0);
|
||||
b.setId(UUID.randomUUID());
|
||||
return b;
|
||||
});
|
||||
|
||||
TranscriptionBlock result = transcriptionService.createOcrBlock(
|
||||
docId, annotId, null, 0, userId);
|
||||
|
||||
assertThat(result.getText()).isEmpty();
|
||||
}
|
||||
|
||||
// ─── updateBlock ─────────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
@@ -168,6 +213,39 @@ class TranscriptionServiceTest {
|
||||
.satisfies(e -> assertThat(((DomainException) e).getStatus()).isEqualTo(NOT_FOUND));
|
||||
}
|
||||
|
||||
// ─── deleteAllBlocksByDocument ─────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void deleteAllBlocksByDocument_deletesAllBlocksAndAnnotations() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID annId1 = UUID.randomUUID();
|
||||
UUID annId2 = UUID.randomUUID();
|
||||
|
||||
TranscriptionBlock block1 = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).annotationId(annId1).sortOrder(0).build();
|
||||
TranscriptionBlock block2 = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID()).documentId(docId).annotationId(annId2).sortOrder(1).build();
|
||||
|
||||
when(blockRepository.findByDocumentIdOrderBySortOrderAsc(docId))
|
||||
.thenReturn(List.of(block1, block2));
|
||||
|
||||
transcriptionService.deleteAllBlocksByDocument(docId);
|
||||
|
||||
verify(blockRepository).deleteAll(List.of(block1, block2));
|
||||
verify(blockRepository).flush();
|
||||
verify(annotationRepository).deleteAllById(List.of(annId1, annId2));
|
||||
}
|
||||
|
||||
@Test
|
||||
void deleteAllBlocksByDocument_doesNothing_whenNoBlocksExist() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
when(blockRepository.findByDocumentIdOrderBySortOrderAsc(docId)).thenReturn(List.of());
|
||||
|
||||
transcriptionService.deleteAllBlocksByDocument(docId);
|
||||
|
||||
verify(blockRepository, never()).deleteAll(any());
|
||||
}
|
||||
|
||||
// ─── reorderBlocks ───────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
@@ -243,4 +321,47 @@ class TranscriptionServiceTest {
|
||||
|
||||
assertThat(transcriptionService.listBlocks(docId)).containsExactly(b);
|
||||
}
|
||||
|
||||
// ─── reviewBlock ─────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void reviewBlock_setsReviewedTrue() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID blockId = UUID.randomUUID();
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.id(blockId).documentId(docId).annotationId(UUID.randomUUID())
|
||||
.text("corrected text").sortOrder(0).reviewed(false).build();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
TranscriptionBlock result = transcriptionService.reviewBlock(docId, blockId);
|
||||
|
||||
assertThat(result.isReviewed()).isTrue();
|
||||
verify(blockRepository).save(block);
|
||||
}
|
||||
|
||||
@Test
|
||||
void reviewBlock_togglesReviewedFalse_whenAlreadyReviewed() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID blockId = UUID.randomUUID();
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.id(blockId).documentId(docId).annotationId(UUID.randomUUID())
|
||||
.text("corrected text").sortOrder(0).reviewed(true).build();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
TranscriptionBlock result = transcriptionService.reviewBlock(docId, blockId);
|
||||
|
||||
assertThat(result.isReviewed()).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
void reviewBlock_throwsNotFound_whenBlockMissing() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID blockId = UUID.randomUUID();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.empty());
|
||||
|
||||
assertThatThrownBy(() -> transcriptionService.reviewBlock(docId, blockId))
|
||||
.isInstanceOf(DomainException.class);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +71,35 @@ services:
|
||||
networks:
|
||||
- archive-net
|
||||
|
||||
# --- OCR: Python microservice (Surya + Kraken) ---
|
||||
ocr-service:
|
||||
build:
|
||||
context: ./ocr-service
|
||||
dockerfile: Dockerfile
|
||||
container_name: archive-ocr
|
||||
restart: unless-stopped
|
||||
expose:
|
||||
- "8000"
|
||||
mem_limit: 8g
|
||||
memswap_limit: 8g
|
||||
volumes:
|
||||
- ocr_models:/app/models
|
||||
- ocr_cache:/root/.cache
|
||||
environment:
|
||||
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
|
||||
OCR_CONFIDENCE_THRESHOLD: "0.3"
|
||||
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
|
||||
RECOGNITION_BATCH_SIZE: "16"
|
||||
DETECTOR_BATCH_SIZE: "8"
|
||||
networks:
|
||||
- archive-net
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 12
|
||||
start_period: 60s
|
||||
|
||||
# --- Backend: Spring Boot ---
|
||||
backend:
|
||||
build:
|
||||
@@ -89,6 +118,8 @@ services:
|
||||
condition: service_healthy
|
||||
mailpit:
|
||||
condition: service_started
|
||||
ocr-service:
|
||||
condition: service_started
|
||||
environment:
|
||||
SPRING_DATASOURCE_URL: jdbc:postgresql://db:5432/${POSTGRES_DB}
|
||||
SPRING_DATASOURCE_USERNAME: ${POSTGRES_USER}
|
||||
@@ -109,6 +140,7 @@ services:
|
||||
# Mailpit needs no auth or STARTTLS; production SMTP overrides these via .env
|
||||
SPRING_MAIL_PROPERTIES_MAIL_SMTP_AUTH: ${MAIL_SMTP_AUTH:-false}
|
||||
SPRING_MAIL_PROPERTIES_MAIL_SMTP_STARTTLS_ENABLE: ${MAIL_STARTTLS_ENABLE:-false}
|
||||
APP_OCR_BASE_URL: http://ocr-service:8000
|
||||
ports:
|
||||
- "${PORT_BACKEND}:8080"
|
||||
networks:
|
||||
@@ -155,3 +187,5 @@ networks:
|
||||
volumes:
|
||||
frontend_node_modules:
|
||||
maven_cache:
|
||||
ocr_models:
|
||||
ocr_cache:
|
||||
|
||||
84
docs/adr/001-ocr-python-microservice.md
Normal file
84
docs/adr/001-ocr-python-microservice.md
Normal file
@@ -0,0 +1,84 @@
|
||||
# ADR-001: OCR Python Microservice
|
||||
|
||||
## Status
|
||||
|
||||
Accepted
|
||||
|
||||
## Context
|
||||
|
||||
The Familienarchiv needs OCR capability to pre-populate transcription blocks from scanned documents. Two OCR engines are required:
|
||||
|
||||
- **Surya** — transformer-based, handles typewritten and modern Latin handwriting
|
||||
- **Kraken** — historical HTR model support, required for pre-1941 German Kurrent/Suetterlin scripts
|
||||
|
||||
Both engines exist exclusively in the Python ecosystem. There are no production-quality Java bindings for either engine. Tess4J (Tesseract for Java) was considered but rejected: Tesseract has poor accuracy on degraded historical handwriting and no HTR-United model support.
|
||||
|
||||
The server has no GPU. CPU-only inference is the target (16-32 GB system RAM).
|
||||
|
||||
## Decision
|
||||
|
||||
Introduce a separate Python container (`ocr-service`) that exposes a simple HTTP API. Spring Boot calls this service via `RestClient`. The Python service is stateless — all job tracking and business logic remain in Spring Boot.
|
||||
|
||||
**Interface contract:**
|
||||
|
||||
Request:
|
||||
```json
|
||||
{
|
||||
"pdfUrl": "http://minio:9000/archive-documents/abc.pdf?presigned...",
|
||||
"scriptType": "HANDWRITING_KURRENT",
|
||||
"language": "de"
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"pageNumber": 0,
|
||||
"x": 0.12, "y": 0.08, "width": 0.76, "height": 0.04,
|
||||
"polygon": [[0.12,0.08],[0.88,0.09],[0.87,0.12],[0.13,0.11]],
|
||||
"text": "Sehr geehrter Herr ..."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Coordinates are normalized (0-1) relative to page dimensions.
|
||||
|
||||
**Java-side integration:**
|
||||
|
||||
- `OcrClient` interface with `extractBlocks()` method — mockable for unit tests
|
||||
- `OcrHealthClient` interface with `isHealthy()` — separate concern from block extraction
|
||||
- `RestClientOcrClient` implements both interfaces
|
||||
- `OcrService` orchestrates: presigned URL generation, OCR call, block mapping, TranscriptionService delegation
|
||||
|
||||
**Docker networking:**
|
||||
|
||||
- `ocr-service` is on the internal Docker network only — no host port mapping
|
||||
- Spring Boot reaches it via `http://ocr-service:8000`
|
||||
- Health check with `start_period: 60s` to account for model loading (~30-60s on CPU)
|
||||
|
||||
## Alternatives Considered
|
||||
|
||||
| Alternative | Why rejected |
|
||||
|---|---|
|
||||
| Tess4J (Tesseract in Java) | No HTR-United model support; poor Kurrent accuracy |
|
||||
| Calling Python via ProcessBuilder | Fragile, no health checks, model reloading on every call |
|
||||
| Embedding Python via GraalVM | Experimental, complex dependency management for ML libraries |
|
||||
| External SaaS OCR (Google Vision, AWS Textract) | Data sovereignty concern for private family documents; no Kurrent support |
|
||||
|
||||
## Consequences
|
||||
|
||||
**Easier:**
|
||||
- Each engine is used via its native Python API — no bridging complexity
|
||||
- OCR service can be updated independently of the main application
|
||||
- Models can be swapped via volume mount without code changes
|
||||
|
||||
**Harder:**
|
||||
- One additional container to operate (memory, health checks, restarts)
|
||||
- Integration tests require WireMock stub — real OCR service is too slow for CI
|
||||
- Presigned URL TTL must be managed (15-30 min recommended)
|
||||
|
||||
## Future Direction
|
||||
|
||||
- LISTEN/NOTIFY from PostgreSQL to push progress events when scaling to multiple instances
|
||||
- GPU acceleration if the server is upgraded — only the Docker image needs to change
|
||||
52
docs/adr/002-polygon-jsonb-storage.md
Normal file
52
docs/adr/002-polygon-jsonb-storage.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# ADR-002: Polygon JSONB Storage for Annotations
|
||||
|
||||
## Status
|
||||
|
||||
Accepted
|
||||
|
||||
## Context
|
||||
|
||||
Document annotations currently store axis-aligned bounding boxes (`x, y, width, height`). Kraken OCR outputs polygon boundaries for text lines — historical handwriting (Kurrent, Suetterlin) produces rotated and curved text that axis-aligned rectangles approximate poorly.
|
||||
|
||||
We need to store an optional quadrilateral (4 corner points) per annotation to represent the precise text region. The polygon is display-only — overlap detection and all server-side geometry logic continues to use the AABB fields.
|
||||
|
||||
## Decision
|
||||
|
||||
Add a `polygon JSONB` column to `document_annotations`:
|
||||
|
||||
```sql
|
||||
ALTER TABLE document_annotations ADD COLUMN polygon JSONB;
|
||||
ALTER TABLE document_annotations
|
||||
ADD CONSTRAINT chk_annotation_polygon_quad
|
||||
CHECK (polygon IS NULL OR jsonb_array_length(polygon) = 4);
|
||||
```
|
||||
|
||||
- `null` means rectangle — render using existing `x, y, width, height` fields (fully backward compatible)
|
||||
- Non-null value is a normalized 4-point quadrilateral: `[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]` with coordinates in the 0-1 range relative to page dimensions
|
||||
|
||||
The existing AABB fields are always populated (even when a polygon is present) and remain the authoritative geometry for overlap detection.
|
||||
|
||||
**Java entity:** `List<List<Double>> polygon` backed by a custom `AttributeConverter<List<List<Double>>, String>`. No new dependency (Hypersistence Utils is not in the project and won't be added for a single column).
|
||||
|
||||
**Semantic invariant:** `polygon`, if present, is a 4-point quadrilateral with coordinates normalized to [0, 1] relative to page dimensions. It may originate from OCR engine output (Kraken) or from a future manual drawing tool. The AABB fields remain the geometry source of truth for server-side logic.
|
||||
|
||||
## Alternatives Considered
|
||||
|
||||
| Alternative | Why rejected |
|
||||
|---|---|
|
||||
| 8 `NUMERIC(8,6)` columns (x1,y1,...,x4,y4) | Verbose, no structural enforcement, awkward to query or extend |
|
||||
| Separate `annotation_polygons` join table | Unnecessary complexity for a 1:1 optional relationship |
|
||||
| PostGIS geometry column | Adds a heavyweight extension for a display-only field with no spatial queries |
|
||||
| `String polygon` on the entity | Requires manual parsing at every callsite; error-prone |
|
||||
|
||||
## Consequences
|
||||
|
||||
**Easier:**
|
||||
- Backward compatible — all existing annotations continue to work unchanged
|
||||
- Frontend renders `<polygon>` or `<rect>` based on a simple null check
|
||||
- Schema can accommodate N-point polygons in the future (JSONB is flexible), though the CHECK constraint currently enforces exactly 4
|
||||
|
||||
**Harder:**
|
||||
- Cannot express range checks (`0 <= x <= 1`) as database constraints without a PL/pgSQL function — validated at the DTO layer instead
|
||||
- No server-side geometry queries on polygon coordinates (acceptable — polygon is display-only)
|
||||
- AttributeConverter adds a small amount of serialization code to maintain
|
||||
@@ -500,5 +500,37 @@
|
||||
"person_alias_delete_title": "Alias entfernen?",
|
||||
"person_alias_delete_body": "Dieser Name wird aus der Suche entfernt.",
|
||||
"person_alias_btn_delete": "Entfernen",
|
||||
"error_alias_not_found": "Der Namensalias wurde nicht gefunden."
|
||||
"error_alias_not_found": "Der Namensalias wurde nicht gefunden.",
|
||||
"error_ocr_service_unavailable": "Der OCR-Dienst ist nicht verfügbar.",
|
||||
"error_ocr_job_not_found": "Der OCR-Auftrag wurde nicht gefunden.",
|
||||
"error_ocr_document_not_uploaded": "Das Dokument hat keine Datei — OCR ist nicht möglich.",
|
||||
"error_ocr_processing_failed": "Die OCR-Verarbeitung ist fehlgeschlagen.",
|
||||
"ocr_script_type_typewriter": "Schreibmaschine",
|
||||
"ocr_script_type_handwriting_latin": "Handschrift (lateinisch)",
|
||||
"ocr_script_type_handwriting_kurrent": "Handschrift (Kurrent/Sütterlin)",
|
||||
"ocr_trigger_label": "Schrifttyp",
|
||||
"ocr_trigger_select_placeholder": "Schrifttyp wählen…",
|
||||
"ocr_trigger_btn": "OCR starten",
|
||||
"ocr_trigger_btn_disabled": "Bitte wählen Sie einen Schrifttyp",
|
||||
"ocr_confirm_title": "Vorhandene Transkription ersetzen?",
|
||||
"ocr_confirm_body": "Alle {count} vorhandenen Blöcke werden gelöscht und durch die OCR-Ergebnisse ersetzt. Diese Aktion kann nicht rückgängig gemacht werden.",
|
||||
"ocr_confirm_btn": "Ersetzen",
|
||||
"ocr_rerun_label": "OCR erneut ausführen…",
|
||||
"ocr_progress_heading": "OCR läuft",
|
||||
"ocr_progress_page": "Seite {current} von {total}",
|
||||
"ocr_error_heading": "OCR fehlgeschlagen",
|
||||
"ocr_error_retry": "Erneut versuchen",
|
||||
"ocr_batch_running": "OCR läuft · {processed} von {total} Dokumenten abgeschlossen",
|
||||
"ocr_batch_done": "OCR abgeschlossen · {processed} erfolgreich · {errors} fehlgeschlagen",
|
||||
"ocr_status_preparing": "Dokument wird vorbereitet…",
|
||||
"ocr_status_loading": "Lade Modell und Dokument…",
|
||||
"ocr_status_analyzing": "OCR-Analyse läuft — dies kann einige Minuten dauern…",
|
||||
"ocr_status_creating_blocks": "{count} Textblöcke erkannt — erstelle Transkription…",
|
||||
"ocr_status_done_blocks": "{count} Blöcke erstellt",
|
||||
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
||||
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
||||
"ocr_status_error": "OCR fehlgeschlagen",
|
||||
"transcription_block_review": "Als geprüft markieren",
|
||||
"transcription_block_unreview": "Markierung aufheben",
|
||||
"transcription_reviewed_count": "{reviewed} von {total} geprüft"
|
||||
}
|
||||
|
||||
@@ -500,5 +500,37 @@
|
||||
"person_alias_delete_title": "Remove alias?",
|
||||
"person_alias_delete_body": "This name will be removed from search results.",
|
||||
"person_alias_btn_delete": "Remove",
|
||||
"error_alias_not_found": "The name alias was not found."
|
||||
"error_alias_not_found": "The name alias was not found.",
|
||||
"error_ocr_service_unavailable": "The OCR service is not available.",
|
||||
"error_ocr_job_not_found": "The OCR job was not found.",
|
||||
"error_ocr_document_not_uploaded": "The document has no file — OCR is not possible.",
|
||||
"error_ocr_processing_failed": "OCR processing failed.",
|
||||
"ocr_script_type_typewriter": "Typewriter",
|
||||
"ocr_script_type_handwriting_latin": "Handwriting (Latin)",
|
||||
"ocr_script_type_handwriting_kurrent": "Handwriting (Kurrent/Sütterlin)",
|
||||
"ocr_trigger_label": "Script type",
|
||||
"ocr_trigger_select_placeholder": "Select script type…",
|
||||
"ocr_trigger_btn": "Start OCR",
|
||||
"ocr_trigger_btn_disabled": "Please select a script type",
|
||||
"ocr_confirm_title": "Replace existing transcription?",
|
||||
"ocr_confirm_body": "All {count} existing blocks will be deleted and replaced with OCR results. This action cannot be undone.",
|
||||
"ocr_confirm_btn": "Replace",
|
||||
"ocr_rerun_label": "Re-run OCR…",
|
||||
"ocr_progress_heading": "OCR running",
|
||||
"ocr_progress_page": "Page {current} of {total}",
|
||||
"ocr_error_heading": "OCR failed",
|
||||
"ocr_error_retry": "Try again",
|
||||
"ocr_batch_running": "OCR running · {processed} of {total} documents complete",
|
||||
"ocr_batch_done": "OCR complete · {processed} successful · {errors} failed",
|
||||
"ocr_status_preparing": "Preparing document…",
|
||||
"ocr_status_loading": "Loading model and document…",
|
||||
"ocr_status_analyzing": "OCR analysis running — this may take a few minutes…",
|
||||
"ocr_status_creating_blocks": "{count} text blocks detected — creating transcription…",
|
||||
"ocr_status_done_blocks": "{count} blocks created",
|
||||
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
||||
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
||||
"ocr_status_error": "OCR failed",
|
||||
"transcription_block_review": "Mark as reviewed",
|
||||
"transcription_block_unreview": "Unmark as reviewed",
|
||||
"transcription_reviewed_count": "{reviewed} of {total} reviewed"
|
||||
}
|
||||
|
||||
@@ -500,5 +500,37 @@
|
||||
"person_alias_delete_title": "Eliminar alias?",
|
||||
"person_alias_delete_body": "Este nombre se eliminara de los resultados de busqueda.",
|
||||
"person_alias_btn_delete": "Eliminar",
|
||||
"error_alias_not_found": "No se encontro el alias de nombre."
|
||||
"error_alias_not_found": "No se encontro el alias de nombre.",
|
||||
"error_ocr_service_unavailable": "El servicio OCR no está disponible.",
|
||||
"error_ocr_job_not_found": "No se encontró el trabajo OCR.",
|
||||
"error_ocr_document_not_uploaded": "El documento no tiene archivo — OCR no es posible.",
|
||||
"error_ocr_processing_failed": "El procesamiento OCR ha fallado.",
|
||||
"ocr_script_type_typewriter": "Máquina de escribir",
|
||||
"ocr_script_type_handwriting_latin": "Escritura manuscrita (latina)",
|
||||
"ocr_script_type_handwriting_kurrent": "Escritura manuscrita (Kurrent/Sütterlin)",
|
||||
"ocr_trigger_label": "Tipo de escritura",
|
||||
"ocr_trigger_select_placeholder": "Seleccionar tipo de escritura…",
|
||||
"ocr_trigger_btn": "Iniciar OCR",
|
||||
"ocr_trigger_btn_disabled": "Por favor seleccione un tipo de escritura",
|
||||
"ocr_confirm_title": "¿Reemplazar transcripción existente?",
|
||||
"ocr_confirm_body": "Los {count} bloques existentes serán eliminados y reemplazados con los resultados del OCR. Esta acción no se puede deshacer.",
|
||||
"ocr_confirm_btn": "Reemplazar",
|
||||
"ocr_rerun_label": "Ejecutar OCR de nuevo…",
|
||||
"ocr_progress_heading": "OCR en curso",
|
||||
"ocr_progress_page": "Página {current} de {total}",
|
||||
"ocr_error_heading": "OCR fallido",
|
||||
"ocr_error_retry": "Intentar de nuevo",
|
||||
"ocr_batch_running": "OCR en curso · {processed} de {total} documentos completados",
|
||||
"ocr_batch_done": "OCR completado · {processed} exitosos · {errors} fallidos",
|
||||
"ocr_status_preparing": "Preparando documento…",
|
||||
"ocr_status_loading": "Cargando modelo y documento…",
|
||||
"ocr_status_analyzing": "Análisis OCR en curso — esto puede tardar unos minutos…",
|
||||
"ocr_status_creating_blocks": "{count} bloques de texto detectados — creando transcripción…",
|
||||
"ocr_status_done_blocks": "{count} bloques creados",
|
||||
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
||||
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
||||
"ocr_status_error": "OCR fallido",
|
||||
"transcription_block_review": "Marcar como revisado",
|
||||
"transcription_block_unreview": "Desmarcar como revisado",
|
||||
"transcription_reviewed_count": "{reviewed} de {total} revisados"
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import type { Annotation } from '$lib/types';
|
||||
import AnnotationShape from './AnnotationShape.svelte';
|
||||
|
||||
type DrawRect = {
|
||||
x: number;
|
||||
@@ -33,13 +34,6 @@ let {
|
||||
let drawStart = $state<{ x: number; y: number } | null>(null);
|
||||
let drawRect = $state<DrawRect | null>(null);
|
||||
|
||||
function hexToRgba(hex: string, alpha: number): string {
|
||||
const r = parseInt(hex.slice(1, 3), 16);
|
||||
const g = parseInt(hex.slice(3, 5), 16);
|
||||
const b = parseInt(hex.slice(5, 7), 16);
|
||||
return `rgba(${r}, ${g}, ${b}, ${alpha})`;
|
||||
}
|
||||
|
||||
function getNormalizedCoords(event: PointerEvent, element: HTMLElement): { x: number; y: number } {
|
||||
const rect = element.getBoundingClientRect();
|
||||
return {
|
||||
@@ -109,58 +103,18 @@ const containerStyle = $derived(
|
||||
onpointerup={handlePointerUp}
|
||||
>
|
||||
{#each annotations as annotation (annotation.id)}
|
||||
<div
|
||||
data-testid="annotation-{annotation.id}"
|
||||
data-annotation
|
||||
class:annotation-flash={flashAnnotationId === annotation.id}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Block anzeigen"
|
||||
<AnnotationShape
|
||||
annotation={annotation}
|
||||
isHovered={hoveredId === annotation.id}
|
||||
isActive={annotation.id === activeAnnotationId}
|
||||
faded={!dimmed && !!activeAnnotationId && annotation.id !== activeAnnotationId}
|
||||
dimmed={dimmed}
|
||||
blockNumber={blockNumbers[annotation.id]}
|
||||
isFlashing={flashAnnotationId === annotation.id}
|
||||
onclick={() => onAnnotationClick?.(annotation.id)}
|
||||
onkeydown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') onAnnotationClick?.(annotation.id);
|
||||
}}
|
||||
onpointerenter={() => (hoveredId = annotation.id)}
|
||||
onpointerleave={() => (hoveredId = null)}
|
||||
style="
|
||||
position: absolute;
|
||||
left: {annotation.x * 100}%;
|
||||
top: {annotation.y * 100}%;
|
||||
width: {annotation.width * 100}%;
|
||||
height: {annotation.height * 100}%;
|
||||
background-color: {hexToRgba(annotation.color, dimmed ? 0.3 : (hoveredId === annotation.id || annotation.id === activeAnnotationId ? 0.5 : 0.3))};
|
||||
box-shadow: {dimmed ? 'none' : (annotation.id === activeAnnotationId ? `inset 0 0 0 2px ${hexToRgba(annotation.color, 0.8)}` : hoveredId === annotation.id ? `inset 0 0 0 2px ${hexToRgba(annotation.color, 0.8)}` : 'none')};
|
||||
opacity: {dimmed ? 1 : (activeAnnotationId && annotation.id !== activeAnnotationId ? 0.3 : 1)};
|
||||
pointer-events: auto;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.15s ease, box-shadow 0.15s ease, opacity 0.3s ease;
|
||||
"
|
||||
>
|
||||
{#if !dimmed && blockNumbers[annotation.id]}
|
||||
<div
|
||||
style="
|
||||
position: absolute;
|
||||
top: -8px;
|
||||
left: -8px;
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border-radius: 50%;
|
||||
background-color: {annotation.color};
|
||||
color: white;
|
||||
font-size: 11px;
|
||||
font-family: sans-serif;
|
||||
font-weight: 700;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
pointer-events: none;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.3);
|
||||
"
|
||||
>
|
||||
{blockNumbers[annotation.id]}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
/>
|
||||
{/each}
|
||||
|
||||
{#if drawRect && drawRect.width > 0}
|
||||
@@ -178,27 +132,3 @@ const containerStyle = $derived(
|
||||
></div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
@keyframes annotation-flash-anim {
|
||||
0% {
|
||||
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
|
||||
outline-offset: 0px;
|
||||
}
|
||||
100% {
|
||||
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 0%, transparent);
|
||||
outline-offset: 2px;
|
||||
}
|
||||
}
|
||||
|
||||
.annotation-flash {
|
||||
animation: annotation-flash-anim 1.5s ease-out;
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.annotation-flash {
|
||||
animation: none;
|
||||
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
136
frontend/src/lib/components/AnnotationShape.svelte
Normal file
136
frontend/src/lib/components/AnnotationShape.svelte
Normal file
@@ -0,0 +1,136 @@
|
||||
<script lang="ts">
|
||||
import type { Annotation } from '$lib/types';
|
||||
|
||||
let {
|
||||
annotation,
|
||||
isHovered,
|
||||
isActive,
|
||||
faded = false,
|
||||
dimmed = false,
|
||||
blockNumber = undefined,
|
||||
isFlashing = false,
|
||||
onclick,
|
||||
onpointerenter,
|
||||
onpointerleave
|
||||
}: {
|
||||
annotation: Annotation;
|
||||
isHovered: boolean;
|
||||
isActive: boolean;
|
||||
faded?: boolean;
|
||||
dimmed?: boolean;
|
||||
blockNumber?: number | undefined;
|
||||
isFlashing?: boolean;
|
||||
onclick: () => void;
|
||||
onpointerenter: () => void;
|
||||
onpointerleave: () => void;
|
||||
} = $props();
|
||||
|
||||
function hexToRgba(hex: string, alpha: number): string {
|
||||
const r = parseInt(hex.slice(1, 3), 16);
|
||||
const g = parseInt(hex.slice(3, 5), 16);
|
||||
const b = parseInt(hex.slice(5, 7), 16);
|
||||
return `rgba(${r}, ${g}, ${b}, ${alpha})`;
|
||||
}
|
||||
|
||||
let clipPath = $derived.by(() => {
|
||||
if (!annotation.polygon || annotation.polygon.length !== 4) return 'none';
|
||||
const points = annotation.polygon
|
||||
.map(([px, py]) => {
|
||||
const cx = ((px - annotation.x) / annotation.width) * 100;
|
||||
const cy = ((py - annotation.y) / annotation.height) * 100;
|
||||
return `${cx}% ${cy}%`;
|
||||
})
|
||||
.join(', ');
|
||||
return `polygon(${points})`;
|
||||
});
|
||||
|
||||
let bgAlpha = $derived(dimmed ? 0.3 : isHovered || isActive ? 0.5 : 0.3);
|
||||
|
||||
let boxShadow = $derived.by(() => {
|
||||
if (dimmed) return 'none';
|
||||
if (isActive || isHovered) return `inset 0 0 0 2px ${hexToRgba(annotation.color, 0.8)}`;
|
||||
return 'none';
|
||||
});
|
||||
|
||||
let opacity = $derived(dimmed ? 1 : faded ? 0.3 : 1);
|
||||
|
||||
let shapeStyle = $derived(
|
||||
`position: absolute;` +
|
||||
` left: ${annotation.x * 100}%;` +
|
||||
` top: ${annotation.y * 100}%;` +
|
||||
` width: ${annotation.width * 100}%;` +
|
||||
` height: ${annotation.height * 100}%;` +
|
||||
` background-color: ${hexToRgba(annotation.color, bgAlpha)};` +
|
||||
` box-shadow: ${boxShadow};` +
|
||||
` opacity: ${opacity};` +
|
||||
` pointer-events: auto;` +
|
||||
` cursor: pointer;` +
|
||||
` transition: background-color 0.15s ease, box-shadow 0.15s ease, opacity 0.3s ease;` +
|
||||
(clipPath !== 'none' ? ` clip-path: ${clipPath};` : '')
|
||||
);
|
||||
</script>
|
||||
|
||||
<div
|
||||
data-testid="annotation-{annotation.id}"
|
||||
data-annotation
|
||||
class:annotation-flash={isFlashing}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Block anzeigen"
|
||||
onclick={onclick}
|
||||
onkeydown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') onclick();
|
||||
}}
|
||||
onpointerenter={onpointerenter}
|
||||
onpointerleave={onpointerleave}
|
||||
style={shapeStyle}
|
||||
>
|
||||
{#if !dimmed && blockNumber}
|
||||
<div
|
||||
style="
|
||||
position: absolute;
|
||||
top: -8px;
|
||||
left: -8px;
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
border-radius: 50%;
|
||||
background-color: {annotation.color};
|
||||
color: white;
|
||||
font-size: 12px;
|
||||
font-family: sans-serif;
|
||||
font-weight: 700;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
pointer-events: none;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.3);
|
||||
"
|
||||
>
|
||||
{blockNumber}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
@keyframes annotation-flash-anim {
|
||||
0% {
|
||||
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
|
||||
outline-offset: 0px;
|
||||
}
|
||||
100% {
|
||||
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 0%, transparent);
|
||||
outline-offset: 2px;
|
||||
}
|
||||
}
|
||||
|
||||
.annotation-flash {
|
||||
animation: annotation-flash-anim 1.5s ease-out;
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.annotation-flash {
|
||||
animation: none;
|
||||
outline: 3px solid color-mix(in srgb, var(--color-turquoise) 80%, transparent);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
91
frontend/src/lib/components/OcrProgress.svelte
Normal file
91
frontend/src/lib/components/OcrProgress.svelte
Normal file
@@ -0,0 +1,91 @@
|
||||
<script lang="ts">
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
interface Props {
|
||||
jobId: string;
|
||||
onDone: () => void;
|
||||
}
|
||||
|
||||
let { jobId, onDone }: Props = $props();
|
||||
|
||||
let status: 'running' | 'done' | 'error' = $state('running');
|
||||
let processed: number = $state(0);
|
||||
let total: number = $state(0);
|
||||
let currentPage: number = $state(0);
|
||||
let totalPages: number = $state(0);
|
||||
let retryCount: number = $state(0);
|
||||
|
||||
let progressPercent = $derived(total > 0 ? Math.round((processed / total) * 100) : 0);
|
||||
|
||||
$effect(() => {
|
||||
void retryCount; // track dependency to re-create EventSource on retry
|
||||
const source = new EventSource(`/api/ocr/jobs/${jobId}/progress`);
|
||||
|
||||
source.addEventListener('document', (e) => {
|
||||
const data = JSON.parse(e.data);
|
||||
processed = data.processed;
|
||||
total = data.total;
|
||||
});
|
||||
|
||||
source.addEventListener('page', (e) => {
|
||||
const data = JSON.parse(e.data);
|
||||
currentPage = data.page;
|
||||
totalPages = data.totalPages;
|
||||
});
|
||||
|
||||
source.addEventListener('done', () => {
|
||||
status = 'done';
|
||||
source.close();
|
||||
onDone();
|
||||
});
|
||||
|
||||
source.addEventListener('error', () => {
|
||||
status = 'error';
|
||||
source.close();
|
||||
});
|
||||
|
||||
source.onerror = () => {
|
||||
status = 'error';
|
||||
source.close();
|
||||
};
|
||||
|
||||
return () => {
|
||||
source.close();
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if status === 'running'}
|
||||
<div class="border-brand-sand rounded-sm border bg-white p-4">
|
||||
<h3 class="mb-3 text-xs font-bold tracking-widest text-gray-400 uppercase">
|
||||
{m.ocr_progress_heading()}
|
||||
</h3>
|
||||
<div class="bg-brand-sand h-2 w-full overflow-hidden rounded-full">
|
||||
<div
|
||||
class="h-full bg-brand-mint transition-all duration-300"
|
||||
style="width: {progressPercent}%"
|
||||
role="progressbar"
|
||||
aria-label={m.ocr_progress_heading()}
|
||||
aria-valuenow={progressPercent}
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
></div>
|
||||
</div>
|
||||
<p class="mt-2 text-right text-sm text-gray-500">
|
||||
{m.ocr_progress_page({ current: String(currentPage), total: String(totalPages) })}
|
||||
</p>
|
||||
</div>
|
||||
{:else if status === 'error'}
|
||||
<div class="border-brand-sand rounded-sm border border-l-4 border-l-red-500 bg-white p-4">
|
||||
<h3 class="mb-2 text-sm font-semibold text-red-700">
|
||||
{m.ocr_error_heading()}
|
||||
</h3>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => { retryCount++; status = 'running'; }}
|
||||
class="text-sm font-medium text-brand-navy transition-colors hover:text-brand-navy/80"
|
||||
>
|
||||
{m.ocr_error_retry()}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
49
frontend/src/lib/components/OcrTrigger.svelte
Normal file
49
frontend/src/lib/components/OcrTrigger.svelte
Normal file
@@ -0,0 +1,49 @@
|
||||
<script lang="ts">
|
||||
import { untrack } from 'svelte';
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
import { getConfirmService } from '$lib/services/confirm.svelte';
|
||||
import ScriptTypeSelect from './ScriptTypeSelect.svelte';
|
||||
|
||||
interface Props {
|
||||
existingBlockCount: number;
|
||||
storedScriptType: string;
|
||||
onTrigger: (scriptType: string) => void;
|
||||
}
|
||||
|
||||
let { existingBlockCount, storedScriptType, onTrigger }: Props = $props();
|
||||
|
||||
const { confirm } = getConfirmService();
|
||||
|
||||
let selectedScriptType: string = $state(
|
||||
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
|
||||
);
|
||||
|
||||
async function handleClick() {
|
||||
if (!selectedScriptType) return;
|
||||
|
||||
if (existingBlockCount > 0) {
|
||||
const confirmed = await confirm({
|
||||
title: m.ocr_confirm_title(),
|
||||
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
|
||||
confirmLabel: m.ocr_confirm_btn(),
|
||||
destructive: true
|
||||
});
|
||||
if (!confirmed) return;
|
||||
}
|
||||
|
||||
onTrigger(selectedScriptType);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col gap-3">
|
||||
<ScriptTypeSelect bind:value={selectedScriptType} />
|
||||
<button
|
||||
type="button"
|
||||
disabled={!selectedScriptType}
|
||||
title={!selectedScriptType ? m.ocr_trigger_btn_disabled() : undefined}
|
||||
onclick={handleClick}
|
||||
class="min-h-[44px] w-full rounded-sm bg-brand-navy font-sans text-sm font-medium text-white transition-colors hover:bg-brand-navy/90 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
{m.ocr_trigger_btn()}
|
||||
</button>
|
||||
</div>
|
||||
@@ -145,6 +145,7 @@ async function renderPage(doc: PDFDocumentProxy, pageNum: number) {
|
||||
|
||||
// Text layer
|
||||
const textDiv = textLayerEl;
|
||||
if (!textDiv) return;
|
||||
textDiv.innerHTML = '';
|
||||
textDiv.style.width = `${viewport.width / dpr}px`;
|
||||
textDiv.style.height = `${viewport.height / dpr}px`;
|
||||
|
||||
27
frontend/src/lib/components/ScriptTypeSelect.svelte
Normal file
27
frontend/src/lib/components/ScriptTypeSelect.svelte
Normal file
@@ -0,0 +1,27 @@
|
||||
<script lang="ts">
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
interface Props {
|
||||
value: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
let { value = $bindable(), disabled = false }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div>
|
||||
<label for="script-type-select" class="text-sm font-bold tracking-widest text-gray-400 uppercase">
|
||||
{m.ocr_trigger_label()}
|
||||
</label>
|
||||
<select
|
||||
id="script-type-select"
|
||||
bind:value={value}
|
||||
disabled={disabled}
|
||||
class="border-brand-sand min-h-[44px] w-full rounded-sm border bg-white px-3 py-2 font-serif text-sm text-brand-navy focus:ring-2 focus:ring-brand-mint focus:outline-none"
|
||||
>
|
||||
<option value="" disabled>{m.ocr_trigger_select_placeholder()}</option>
|
||||
<option value="TYPEWRITER">{m.ocr_script_type_typewriter()}</option>
|
||||
<option value="HANDWRITING_LATIN">{m.ocr_script_type_handwriting_latin()}</option>
|
||||
<option value="HANDWRITING_KURRENT">{m.ocr_script_type_handwriting_kurrent()}</option>
|
||||
</select>
|
||||
</div>
|
||||
@@ -14,6 +14,7 @@ type Props = {
|
||||
text: string;
|
||||
label: string | null;
|
||||
active: boolean;
|
||||
reviewed: boolean;
|
||||
saveState: SaveState;
|
||||
canComment: boolean;
|
||||
currentUserId: string | null;
|
||||
@@ -21,6 +22,7 @@ type Props = {
|
||||
onFocus: () => void;
|
||||
onDeleteClick: () => void;
|
||||
onRetry: () => void;
|
||||
onReviewToggle: () => void;
|
||||
onMoveUp?: () => void;
|
||||
onMoveDown?: () => void;
|
||||
isFirst?: boolean;
|
||||
@@ -34,6 +36,7 @@ let {
|
||||
text,
|
||||
label = null,
|
||||
active,
|
||||
reviewed,
|
||||
saveState,
|
||||
canComment,
|
||||
currentUserId,
|
||||
@@ -41,6 +44,7 @@ let {
|
||||
onFocus,
|
||||
onDeleteClick,
|
||||
onRetry,
|
||||
onReviewToggle,
|
||||
onMoveUp,
|
||||
onMoveDown,
|
||||
isFirst = false,
|
||||
@@ -239,6 +243,29 @@ function handleTextareaMouseUp() {
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Review toggle -->
|
||||
<button
|
||||
type="button"
|
||||
class="cursor-pointer transition-colors {reviewed ? 'text-turquoise hover:text-turquoise/70' : 'text-ink-3 hover:text-turquoise'}"
|
||||
aria-label={reviewed ? m.transcription_block_unreview() : m.transcription_block_review()}
|
||||
title={reviewed ? m.transcription_block_unreview() : m.transcription_block_review()}
|
||||
onclick={onReviewToggle}
|
||||
>
|
||||
<svg
|
||||
class="h-4 w-4"
|
||||
fill={reviewed ? 'currentColor' : 'none'}
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="1.5"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
type="button"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
import { SvelteMap } from 'svelte/reactivity';
|
||||
import TranscriptionBlock from './TranscriptionBlock.svelte';
|
||||
import OcrTrigger from './OcrTrigger.svelte';
|
||||
import type { TranscriptionBlockData } from '$lib/types';
|
||||
|
||||
type SaveState = 'idle' | 'saving' | 'saved' | 'fading' | 'error';
|
||||
@@ -12,9 +13,13 @@ type Props = {
|
||||
canComment: boolean;
|
||||
currentUserId: string | null;
|
||||
activeAnnotationId?: string | null;
|
||||
storedScriptType?: string;
|
||||
canRunOcr?: boolean;
|
||||
onBlockFocus: (blockId: string) => void;
|
||||
onSaveBlock: (blockId: string, text: string) => Promise<void>;
|
||||
onDeleteBlock: (blockId: string) => Promise<void>;
|
||||
onReviewToggle: (blockId: string) => Promise<void>;
|
||||
onTriggerOcr?: (scriptType: string) => void;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -23,9 +28,13 @@ let {
|
||||
canComment,
|
||||
currentUserId,
|
||||
activeAnnotationId = null,
|
||||
storedScriptType = '',
|
||||
canRunOcr = false,
|
||||
onBlockFocus,
|
||||
onSaveBlock,
|
||||
onDeleteBlock
|
||||
onDeleteBlock,
|
||||
onReviewToggle,
|
||||
onTriggerOcr
|
||||
}: Props = $props();
|
||||
|
||||
let activeBlockId: string | null = $state(null);
|
||||
@@ -282,6 +291,7 @@ $effect(() => {
|
||||
text={block.text}
|
||||
label={block.label}
|
||||
active={activeBlockId === block.id}
|
||||
reviewed={block.reviewed ?? false}
|
||||
saveState={getSaveState(block.id)}
|
||||
canComment={canComment}
|
||||
currentUserId={currentUserId}
|
||||
@@ -289,6 +299,7 @@ $effect(() => {
|
||||
onFocus={() => handleFocus(block.id)}
|
||||
onDeleteClick={() => handleDelete(block.id)}
|
||||
onRetry={() => handleRetry(block.id)}
|
||||
onReviewToggle={() => onReviewToggle(block.id)}
|
||||
onMoveUp={() => handleMoveUp(block.id)}
|
||||
onMoveDown={() => handleMoveDown(block.id)}
|
||||
isFirst={i === 0}
|
||||
@@ -307,6 +318,23 @@ $effect(() => {
|
||||
>
|
||||
{m.transcription_next_block_cta({ number: sortedBlocks.length + 1 })}
|
||||
</div>
|
||||
|
||||
{#if canRunOcr && onTriggerOcr}
|
||||
<details class="mt-6">
|
||||
<summary
|
||||
class="cursor-pointer font-sans text-xs font-medium text-ink-3 transition-colors hover:text-brand-navy"
|
||||
>
|
||||
{m.ocr_rerun_label()}
|
||||
</summary>
|
||||
<div class="mt-3 max-w-xs">
|
||||
<OcrTrigger
|
||||
existingBlockCount={blocks.length}
|
||||
storedScriptType={storedScriptType}
|
||||
onTrigger={onTriggerOcr}
|
||||
/>
|
||||
</div>
|
||||
</details>
|
||||
{/if}
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex flex-1 flex-col items-center justify-center px-6 py-12 text-center">
|
||||
@@ -323,9 +351,26 @@ $effect(() => {
|
||||
d="M19.5 14.25v-2.625a3.375 3.375 0 00-3.375-3.375h-1.5A1.125 1.125 0 0113.5 7.125v-1.5a3.375 3.375 0 00-3.375-3.375H8.25m0 12.75h7.5m-7.5 3H12M10.5 2.25H5.625c-.621 0-1.125.504-1.125 1.125v17.25c0 .621.504 1.125 1.125 1.125h12.75c.621 0 1.125-.504 1.125-1.125V11.25a9 9 0 00-9-9z"
|
||||
/>
|
||||
</svg>
|
||||
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
|
||||
{m.transcription_empty_cta()}
|
||||
</p>
|
||||
|
||||
{#if canRunOcr && onTriggerOcr}
|
||||
<p class="mb-6 max-w-xs text-sm leading-relaxed text-ink-3">
|
||||
{m.transcription_empty_title()}
|
||||
</p>
|
||||
<div class="w-full max-w-xs">
|
||||
<OcrTrigger
|
||||
existingBlockCount={0}
|
||||
storedScriptType={storedScriptType}
|
||||
onTrigger={onTriggerOcr}
|
||||
/>
|
||||
</div>
|
||||
<p class="mt-4 text-xs text-ink-3">
|
||||
{m.transcription_empty_desc()}
|
||||
</p>
|
||||
{:else}
|
||||
<p class="max-w-xs text-sm leading-relaxed text-ink-3">
|
||||
{m.transcription_empty_cta()}
|
||||
</p>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -22,6 +22,10 @@ export type ErrorCode =
|
||||
| 'TRANSCRIPTION_BLOCK_NOT_FOUND'
|
||||
| 'TRANSCRIPTION_BLOCK_CONFLICT'
|
||||
| 'COMMENT_NOT_FOUND'
|
||||
| 'OCR_SERVICE_UNAVAILABLE'
|
||||
| 'OCR_JOB_NOT_FOUND'
|
||||
| 'OCR_DOCUMENT_NOT_UPLOADED'
|
||||
| 'OCR_PROCESSING_FAILED'
|
||||
| 'UNAUTHORIZED'
|
||||
| 'FORBIDDEN'
|
||||
| 'VALIDATION_ERROR'
|
||||
@@ -85,6 +89,14 @@ export function getErrorMessage(code: ErrorCode | string | undefined): string {
|
||||
return m.error_transcription_block_conflict();
|
||||
case 'COMMENT_NOT_FOUND':
|
||||
return m.error_comment_not_found();
|
||||
case 'OCR_SERVICE_UNAVAILABLE':
|
||||
return m.error_ocr_service_unavailable();
|
||||
case 'OCR_JOB_NOT_FOUND':
|
||||
return m.error_ocr_job_not_found();
|
||||
case 'OCR_DOCUMENT_NOT_UPLOADED':
|
||||
return m.error_ocr_document_not_uploaded();
|
||||
case 'OCR_PROCESSING_FAILED':
|
||||
return m.error_ocr_processing_failed();
|
||||
case 'UNAUTHORIZED':
|
||||
return m.error_unauthorized();
|
||||
case 'FORBIDDEN':
|
||||
|
||||
78
frontend/src/lib/ocr/translateOcrProgress.spec.ts
Normal file
78
frontend/src/lib/ocr/translateOcrProgress.spec.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
|
||||
vi.mock('$lib/paraglide/messages.js', () => ({
|
||||
m: {
|
||||
ocr_progress_heading: () => 'OCR-Analyse',
|
||||
ocr_status_preparing: () => 'Dokument wird vorbereitet…',
|
||||
ocr_status_loading: () => 'Lade Modell und Dokument…',
|
||||
ocr_status_analyzing: () => 'OCR-Analyse läuft…',
|
||||
ocr_status_creating_blocks: ({ count }: { count: string }) => `${count} Textblöcke erkannt`,
|
||||
ocr_status_done_blocks: ({ count }: { count: string }) => `${count} Blöcke erstellt`,
|
||||
ocr_status_done_skipped: ({ count, skipped }: { count: string; skipped: string }) =>
|
||||
`${count} Blöcke erstellt, ${skipped} Seite(n) übersprungen`,
|
||||
ocr_status_analyzing_page: ({ current, total }: { current: string; total: string }) =>
|
||||
`Seite ${current} von ${total} wird analysiert…`,
|
||||
ocr_status_error: () => 'OCR fehlgeschlagen'
|
||||
}
|
||||
}));
|
||||
|
||||
import { translateOcrProgress } from './translateOcrProgress';
|
||||
|
||||
describe('translateOcrProgress', () => {
|
||||
it('returns heading for empty code', () => {
|
||||
const result = translateOcrProgress('');
|
||||
expect(result.message).toBe('OCR-Analyse');
|
||||
});
|
||||
|
||||
it('translates PREPARING', () => {
|
||||
const result = translateOcrProgress('PREPARING');
|
||||
expect(result.message).toBe('Dokument wird vorbereitet…');
|
||||
expect(result.currentPage).toBeUndefined();
|
||||
expect(result.totalPages).toBeUndefined();
|
||||
});
|
||||
|
||||
it('translates LOADING', () => {
|
||||
expect(translateOcrProgress('LOADING').message).toBe('Lade Modell und Dokument…');
|
||||
});
|
||||
|
||||
it('translates ANALYZING', () => {
|
||||
expect(translateOcrProgress('ANALYZING').message).toBe('OCR-Analyse läuft…');
|
||||
});
|
||||
|
||||
it('translates CREATING_BLOCKS with count', () => {
|
||||
expect(translateOcrProgress('CREATING_BLOCKS:42').message).toBe('42 Textblöcke erkannt');
|
||||
});
|
||||
|
||||
it('translates DONE without skipped pages', () => {
|
||||
const result = translateOcrProgress('DONE:15');
|
||||
expect(result.message).toBe('15 Blöcke erstellt');
|
||||
expect(result.skippedPages).toBeUndefined();
|
||||
});
|
||||
|
||||
it('translates DONE with zero skipped', () => {
|
||||
const result = translateOcrProgress('DONE:15:0');
|
||||
expect(result.message).toBe('15 Blöcke erstellt');
|
||||
expect(result.skippedPages).toBeUndefined();
|
||||
});
|
||||
|
||||
it('translates DONE with skipped pages', () => {
|
||||
const result = translateOcrProgress('DONE:12:2');
|
||||
expect(result.message).toBe('12 Blöcke erstellt, 2 Seite(n) übersprungen');
|
||||
expect(result.skippedPages).toBe(2);
|
||||
});
|
||||
|
||||
it('translates ANALYZING_PAGE with current, total, and blocks', () => {
|
||||
const result = translateOcrProgress('ANALYZING_PAGE:2:5:10');
|
||||
expect(result.message).toBe('Seite 2 von 5 wird analysiert…');
|
||||
expect(result.currentPage).toBe(2);
|
||||
expect(result.totalPages).toBe(5);
|
||||
});
|
||||
|
||||
it('translates ERROR', () => {
|
||||
expect(translateOcrProgress('ERROR').message).toBe('OCR fehlgeschlagen');
|
||||
});
|
||||
|
||||
it('returns raw code for unknown codes', () => {
|
||||
expect(translateOcrProgress('UNKNOWN_CODE').message).toBe('UNKNOWN_CODE');
|
||||
});
|
||||
});
|
||||
56
frontend/src/lib/ocr/translateOcrProgress.ts
Normal file
56
frontend/src/lib/ocr/translateOcrProgress.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
export interface OcrProgressResult {
|
||||
message: string;
|
||||
currentPage?: number;
|
||||
totalPages?: number;
|
||||
skippedPages?: number;
|
||||
}
|
||||
|
||||
export function translateOcrProgress(code: string): OcrProgressResult {
|
||||
if (!code) return { message: m.ocr_progress_heading() };
|
||||
|
||||
const parts = code.split(':');
|
||||
const key = parts[0];
|
||||
|
||||
switch (key) {
|
||||
case 'PREPARING':
|
||||
return { message: m.ocr_status_preparing() };
|
||||
case 'LOADING':
|
||||
return { message: m.ocr_status_loading() };
|
||||
case 'ANALYZING':
|
||||
return { message: m.ocr_status_analyzing() };
|
||||
case 'CREATING_BLOCKS':
|
||||
return { message: m.ocr_status_creating_blocks({ count: parts[1] ?? '0' }) };
|
||||
case 'DONE': {
|
||||
const count = parts[1] ?? '0';
|
||||
const skipped = parts[2] ? parseInt(parts[2], 10) : 0;
|
||||
if (skipped > 0) {
|
||||
return {
|
||||
message: m.ocr_status_done_skipped({
|
||||
count,
|
||||
skipped: String(skipped)
|
||||
}),
|
||||
skippedPages: skipped
|
||||
};
|
||||
}
|
||||
return { message: m.ocr_status_done_blocks({ count }) };
|
||||
}
|
||||
case 'ANALYZING_PAGE': {
|
||||
const current = parseInt(parts[1] ?? '0', 10);
|
||||
const total = parseInt(parts[2] ?? '0', 10);
|
||||
return {
|
||||
message: m.ocr_status_analyzing_page({
|
||||
current: String(current),
|
||||
total: String(total)
|
||||
}),
|
||||
currentPage: current,
|
||||
totalPages: total
|
||||
};
|
||||
}
|
||||
case 'ERROR':
|
||||
return { message: m.ocr_status_error() };
|
||||
default:
|
||||
return { message: code };
|
||||
}
|
||||
}
|
||||
@@ -35,6 +35,8 @@ export type TranscriptionBlockData = {
|
||||
label: string | null;
|
||||
sortOrder: number;
|
||||
version: number;
|
||||
source: 'MANUAL' | 'OCR';
|
||||
reviewed: boolean;
|
||||
updatedAt?: string | null;
|
||||
};
|
||||
|
||||
@@ -49,4 +51,5 @@ export type Annotation = {
|
||||
color: string;
|
||||
createdAt: string;
|
||||
fileHash?: string | null;
|
||||
polygon?: [number, number][] | null;
|
||||
};
|
||||
|
||||
@@ -7,6 +7,8 @@ import TranscriptionEditView from '$lib/components/TranscriptionEditView.svelte'
|
||||
import TranscriptionReadView from '$lib/components/TranscriptionReadView.svelte';
|
||||
import TranscriptionPanelHeader from '$lib/components/TranscriptionPanelHeader.svelte';
|
||||
import type { TranscriptionBlockData } from '$lib/types';
|
||||
import { getErrorMessage } from '$lib/errors';
|
||||
import { translateOcrProgress } from '$lib/ocr/translateOcrProgress';
|
||||
|
||||
let { data } = $props();
|
||||
|
||||
@@ -118,6 +120,78 @@ async function deleteBlock(blockId: string) {
|
||||
annotationReloadKey++;
|
||||
}
|
||||
|
||||
async function reviewToggle(blockId: string) {
|
||||
const res = await fetch(`/api/documents/${doc.id}/transcription-blocks/${blockId}/review`, {
|
||||
method: 'PUT'
|
||||
});
|
||||
if (!res.ok) return;
|
||||
const updated = await res.json();
|
||||
transcriptionBlocks = transcriptionBlocks.map((b) => (b.id === blockId ? updated : b));
|
||||
}
|
||||
|
||||
let ocrRunning = $state(false);
|
||||
let ocrProgressMessage = $state('');
|
||||
let ocrErrorMessage = $state('');
|
||||
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
|
||||
let ocrSkippedPages = $state(0);
|
||||
|
||||
async function triggerOcr(scriptType: string) {
|
||||
ocrRunning = true;
|
||||
ocrErrorMessage = '';
|
||||
try {
|
||||
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ scriptType })
|
||||
});
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
pollOcrJob(data.jobId);
|
||||
} else {
|
||||
ocrRunning = false;
|
||||
const body = await res.json().catch(() => null);
|
||||
const code = (body as { code?: string } | null)?.code;
|
||||
ocrErrorMessage = code ? getErrorMessage(code) : m.ocr_status_error();
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to trigger OCR:', e);
|
||||
ocrRunning = false;
|
||||
ocrErrorMessage = m.ocr_status_error();
|
||||
}
|
||||
}
|
||||
|
||||
// Poll the OCR job every 2s until it reaches a terminal state (DONE/FAILED).
// Any previous poll timer is cancelled first, so at most one job is polled.
function pollOcrJob(jobId: string) {
	if (ocrPollTimer) clearInterval(ocrPollTimer);
	ocrPollTimer = setInterval(async () => {
		try {
			const res = await fetch(`/api/ocr/jobs/${jobId}`);
			if (!res.ok) return; // transient failure — retry on next tick
			const job = await res.json();
			// Backend sends a progress code; translation happens client-side.
			const rawCode = job.progressMessage ?? '';
			const progress = translateOcrProgress(rawCode);
			ocrProgressMessage = progress.message;
			if (progress.skippedPages !== undefined) ocrSkippedPages = progress.skippedPages;
			if (job.status === 'DONE' || job.status === 'FAILED') {
				if (ocrPollTimer) clearInterval(ocrPollTimer);
				ocrPollTimer = null;
				// Keep the final progress state visible for 1s before hiding the UI.
				setTimeout(() => {
					ocrRunning = false;
					ocrProgressMessage = '';
					ocrSkippedPages = 0;
				}, 1000);
				if (job.status === 'FAILED') {
					ocrErrorMessage = m.ocr_status_error();
				}
				// Refresh the blocks/annotations the job produced.
				await loadTranscriptionBlocks();
				annotationReloadKey++;
				panelMode = transcriptionBlocks.length > 0 ? 'read' : 'edit';
			}
		} catch {
			// polling is best-effort
		}
	}, 2000);
}
|
||||
|
||||
async function createBlockFromDraw(rect: {
|
||||
x: number;
|
||||
y: number;
|
||||
@@ -198,12 +272,28 @@ function handleParagraphClick(annotationId: string) {
|
||||
);
|
||||
}
|
||||
|
||||
// Load blocks when transcribe mode is entered and set default panel mode
|
||||
async function checkOcrStatus() {
|
||||
if (!doc?.id) return;
|
||||
try {
|
||||
const res = await fetch(`/api/documents/${doc.id}/ocr-status`);
|
||||
if (!res.ok) return;
|
||||
const status = await res.json();
|
||||
if ((status.status === 'PENDING' || status.status === 'RUNNING') && status.jobId) {
|
||||
ocrRunning = true;
|
||||
pollOcrJob(status.jobId);
|
||||
}
|
||||
} catch {
|
||||
// best-effort
|
||||
}
|
||||
}
|
||||
|
||||
// Load blocks and check OCR status when transcribe mode is entered
|
||||
// Load blocks and check OCR status when transcribe mode is entered.
// The panel defaults to 'read' when blocks already exist, else 'edit'.
$effect(() => {
	if (transcribeMode) {
		loadTranscriptionBlocks().then(() => {
			panelMode = transcriptionBlocks.length > 0 ? 'read' : 'edit';
		});
		// Fire-and-forget: resumes polling if a job is already running server-side.
		checkOcrStatus();
	}
});
|
||||
|
||||
@@ -227,7 +317,10 @@ onMount(() => {
|
||||
}
|
||||
}
|
||||
document.addEventListener('keydown', onKeyDown);
|
||||
return () => document.removeEventListener('keydown', onKeyDown);
|
||||
return () => {
|
||||
document.removeEventListener('keydown', onKeyDown);
|
||||
if (ocrPollTimer) clearInterval(ocrPollTimer);
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -258,7 +351,7 @@ onMount(() => {
|
||||
fileUrl={fileUrl}
|
||||
isLoading={isLoading}
|
||||
error={fileError}
|
||||
transcribeMode={transcribeMode}
|
||||
transcribeMode={transcribeMode && !ocrRunning}
|
||||
blockNumbers={blockNumbers}
|
||||
annotationReloadKey={annotationReloadKey}
|
||||
annotationsDimmed={transcribeMode && panelMode === 'read'}
|
||||
@@ -303,7 +396,45 @@ onMount(() => {
|
||||
onClose={() => (transcribeMode = false)}
|
||||
/>
|
||||
<div class="flex-1 overflow-y-auto">
|
||||
{#if panelMode === 'read'}
|
||||
{#if ocrErrorMessage}
|
||||
<div class="mx-4 mt-4 rounded-sm border border-red-200 bg-red-50 px-4 py-3">
|
||||
<p class="text-sm text-red-700">{ocrErrorMessage}</p>
|
||||
</div>
|
||||
{/if}
|
||||
{#if ocrRunning}
|
||||
<div class="flex flex-1 flex-col items-center justify-center px-6 py-12 text-center">
|
||||
<svg
|
||||
class="mb-4 h-8 w-8 animate-spin text-brand-mint"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
>
|
||||
<circle
|
||||
class="opacity-25"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke="currentColor"
|
||||
stroke-width="4"
|
||||
></circle>
|
||||
<path
|
||||
class="opacity-75"
|
||||
fill="currentColor"
|
||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4z"
|
||||
></path>
|
||||
</svg>
|
||||
<p class="text-xs font-bold tracking-widest text-gray-400 uppercase">
|
||||
{m.ocr_progress_heading()}
|
||||
</p>
|
||||
<p class="mt-2 text-sm text-ink-2">
|
||||
{ocrProgressMessage}
|
||||
</p>
|
||||
{#if ocrSkippedPages > 0}
|
||||
<p class="mt-1 text-xs text-amber-600">
|
||||
{ocrSkippedPages} Seiten übersprungen
|
||||
</p>
|
||||
{/if}
|
||||
</div>
|
||||
{:else if panelMode === 'read'}
|
||||
<TranscriptionReadView
|
||||
blocks={transcriptionBlocks}
|
||||
highlightBlockId={highlightBlockId}
|
||||
@@ -316,9 +447,13 @@ onMount(() => {
|
||||
canComment={canWrite}
|
||||
currentUserId={currentUserId}
|
||||
activeAnnotationId={activeAnnotationId}
|
||||
storedScriptType={doc.scriptType ?? ''}
|
||||
canRunOcr={canWrite && !!doc.filePath}
|
||||
onBlockFocus={handleBlockFocus}
|
||||
onSaveBlock={saveBlock}
|
||||
onDeleteBlock={deleteBlock}
|
||||
onReviewToggle={reviewToggle}
|
||||
onTriggerOcr={triggerOcr}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
6
ocr-service/.dockerignore
Normal file
6
ocr-service/.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
.venv
|
||||
models/
|
||||
.git
|
||||
.pytest_cache
|
||||
26
ocr-service/Dockerfile
Normal file
26
ocr-service/Dockerfile
Normal file
@@ -0,0 +1,26 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# curl for healthcheck; libgomp1 for PyTorch CPU threading; libvips for kraken PDF support
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
libgomp1 \
|
||||
libvips42 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# PyTorch CPU-only — separate layer; the whl/cpu index strips all CUDA variants (~2 GB saved)
|
||||
# torchvision must also come from the CPU index to match torch's operator registrations
|
||||
RUN pip install --no-cache-dir \
|
||||
torch==2.7.1 \
|
||||
torchvision==0.22.1 \
|
||||
--index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
90
ocr-service/confidence.py
Normal file
90
ocr-service/confidence.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Confidence-based [unleserlich] marker insertion for OCR output."""
|
||||
|
||||
import os
|
||||
|
||||
# Confidence cut-offs, overridable via environment. Kurrent (HTR) output gets
# a stricter threshold.
THRESHOLD_DEFAULT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD", "0.3"))
THRESHOLD_KURRENT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD_KURRENT", "0.5"))

# Marker substituted for low-confidence words (German: "illegible").
ILLEGIBLE_MARKER = "[unleserlich]"


def get_threshold(script_type: str) -> float:
    """Return the confidence threshold appropriate for the given script type."""
    normalized = script_type.upper() if script_type else ""
    return THRESHOLD_KURRENT if normalized == "HANDWRITING_KURRENT" else THRESHOLD_DEFAULT
|
||||
|
||||
|
||||
def apply_confidence_markers(words: list[dict], threshold: float | None = None) -> str:
    """Replace low-confidence words with [unleserlich], collapsing adjacent markers.

    Args:
        words: list of {"text": str, "confidence": float} dicts
        threshold: confidence threshold (uses THRESHOLD_DEFAULT if None)

    Returns:
        Reconstructed text string with [unleserlich] substitutions.
    """
    if not words:
        return ""

    cutoff = THRESHOLD_DEFAULT if threshold is None else threshold

    pieces: list[str] = []
    in_illegible_run = False
    for word in words:
        if word["confidence"] < cutoff:
            # Consecutive illegible words collapse into one marker.
            if not in_illegible_run:
                pieces.append(ILLEGIBLE_MARKER)
            in_illegible_run = True
        else:
            pieces.append(word["text"])
            in_illegible_run = False

    return " ".join(pieces)
|
||||
|
||||
|
||||
def words_from_characters(prediction: str, confidences: list[float]) -> list[dict]:
    """Rebuild word-level confidences from Kraken's per-character output.

    The prediction is split on single spaces; each word's confidence is the
    arithmetic mean of its characters' confidences.

    Args:
        prediction: full line text from Kraken
        confidences: per-character confidence list (same length as prediction)

    Returns:
        list of {"text": str, "confidence": float} dicts. Empty for blank
        input; a single whole-line word with confidence 1.0 when the
        confidence list does not line up with the text.
    """
    if not prediction or not prediction.strip():
        return []

    # Defensive fallback: without a 1:1 char-to-confidence mapping we cannot
    # attribute confidences, so treat the whole line as one trusted word.
    if len(confidences) != len(prediction):
        return [{"text": prediction, "confidence": 1.0}]

    words: list[dict] = []
    start = 0
    for end in range(len(prediction) + 1):
        at_break = end == len(prediction) or prediction[end] == " "
        if at_break:
            if end > start:
                segment_confs = confidences[start:end]
                words.append({
                    "text": prediction[start:end],
                    "confidence": sum(segment_confs) / len(segment_confs),
                })
            start = end + 1

    return words
|
||||
0
ocr-service/engines/__init__.py
Normal file
0
ocr-service/engines/__init__.py
Normal file
200
ocr-service/engines/kraken.py
Normal file
200
ocr-service/engines/kraken.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Recognition model handle; stays None until load_models() succeeds.
_model = None
# Path to the Kurrent/Suetterlin HTR model file; overridable via env.
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")


def load_models():
    """Load the Kraken model at startup. Skips if model file is not present."""
    global _model

    # A missing model is tolerated: the service still runs, but Kurrent OCR
    # stays disabled (see is_available()).
    if not os.path.exists(_model_path):
        logger.warning("Kraken model not found at %s — Kurrent OCR will not be available", _model_path)
        return

    logger.info("Loading Kraken model from %s...", _model_path)

    # Imported here so the heavy kraken dependency is only pulled in when a
    # model file is actually present.
    from kraken.lib import models as kraken_models
    _model = kraken_models.load_any(_model_path)

    logger.info("Kraken model loaded successfully")
|
||||
|
||||
|
||||
def is_available() -> bool:
    """True once load_models() has successfully loaded the Kraken model."""
    return _model is not None
|
||||
|
||||
|
||||
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
    """Run Kraken segmentation + recognition on a single PIL image.

    Returns block dicts for that page. Coordinates are normalized to [0, 1].

    Args:
        image: PIL image of one page.
        page_idx: 1-based page number stamped into each block.
        language: not forwarded to Kraken here; kept for interface parity with
            the Surya engine.

    Raises:
        RuntimeError: if load_models() has not loaded a model.
    """
    from kraken import blla, rpred
    from confidence import words_from_characters

    if _model is None:
        raise RuntimeError("Kraken model is not loaded")

    page_w, page_h = image.size
    blocks = []

    # Baseline segmentation first, then line-by-line recognition.
    baseline_seg = blla.segment(image)
    pred_it = rpred.rpred(_model, image, baseline_seg)

    for record in pred_it:
        # NOTE(review): assumes record.cuts is a sequence of (x, y) points
        # outlining the recognized line — confirm against the kraken version
        # pinned in requirements.txt.
        polygon_pts = record.cuts if hasattr(record, "cuts") else []

        if polygon_pts:
            xs = [p[0] for p in polygon_pts]
            ys = [p[1] for p in polygon_pts]
            x1, y1 = min(xs), min(ys)
            x2, y2 = max(xs), max(ys)
        else:
            # Fall back to the baseline itself, padded 5px vertically so the
            # box has some height.
            xs = [p[0] for p in record.line]
            ys = [p[1] for p in record.line]
            x1, y1 = min(xs), min(ys) - 5
            x2, y2 = max(xs), max(ys) + 5

        # 4-point quadrilateral approximation of the line outline (may be None).
        quad = _approximate_to_quad(polygon_pts, page_w, page_h) if polygon_pts else None

        # Per-character confidences -> word-level entries consumed by the
        # [unleserlich] post-processing step.
        char_confidences = getattr(record, "confidences", [])
        words = words_from_characters(record.prediction, char_confidences)

        blocks.append({
            "pageNumber": page_idx,
            "x": x1 / page_w,
            "y": y1 / page_h,
            "width": (x2 - x1) / page_w,
            "height": (y2 - y1) / page_h,
            "polygon": quad,
            "text": record.prediction,
            "words": words,
        })

    return blocks
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
    """Run Kraken segmentation + recognition on a list of PIL images.

    Delegates to extract_page_blocks() page by page (pages numbered from 1)
    and concatenates the results. Block dicts carry pageNumber, x, y, width,
    height, polygon, text; coordinates are normalized to [0, 1].
    """
    collected: list[dict] = []
    for page_number, page_image in enumerate(images, start=1):
        collected += extract_page_blocks(page_image, page_number, language)
    return collected
|
||||
|
||||
|
||||
def _approximate_to_quad(points: list[tuple], page_w: float, page_h: float) -> list[list[float]] | None:
    """Approximate a polygon to a 4-point quadrilateral using the minimum bounding rectangle.

    Uses gift-wrapping (Jarvis march) for convex hull, then rotating calipers
    for the minimum area bounding rectangle. Pure Python, no scipy/numpy.

    Returns the 4 corners normalized to [0, 1], or None for degenerate input
    or on any geometry failure (best-effort: callers fall back to the
    axis-aligned bbox).
    """
    if len(points) < 3:
        return None

    try:
        hull = _convex_hull(points)
        if len(hull) < 3:
            return None

        rect = _min_bounding_rect(hull)

        # Normalize to [0, 1]
        return [[p[0] / page_w, p[1] / page_h] for p in rect]
    except Exception:
        # Deliberate best-effort: a geometry edge case must not fail the page.
        logger.debug("Failed to approximate polygon to quad, returning None")
        return None
|
||||
|
||||
|
||||
def _convex_hull(points: list[tuple]) -> list[tuple]:
    """Jarvis march (gift wrapping) algorithm for 2D convex hull.

    Duplicate input points are removed first; fewer than 3 distinct points
    are returned as-is. Collinear points keep only the extremes.
    """
    pts = list(set(points))
    if len(pts) < 3:
        return pts

    # Start from leftmost point
    start = min(pts, key=lambda p: (p[0], p[1]))
    hull = []
    current = start

    while True:
        hull.append(current)
        candidate = pts[0]
        for p in pts[1:]:
            if candidate == current:
                candidate = p
                continue
            cross = _cross(current, candidate, p)
            if cross < 0:
                # p lies on the wrapping side of the current candidate edge
                candidate = p
            elif cross == 0:
                # Collinear — pick the farther point
                if _dist_sq(current, p) > _dist_sq(current, candidate):
                    candidate = p

        current = candidate
        if current == start:
            break

    return hull
|
||||
|
||||
|
||||
def _min_bounding_rect(hull: list[tuple]) -> list[tuple]:
|
||||
"""Find the minimum area bounding rectangle of a convex hull using rotating calipers."""
|
||||
n = len(hull)
|
||||
if n < 2:
|
||||
return hull
|
||||
|
||||
min_area = float("inf")
|
||||
best_rect = None
|
||||
|
||||
for i in range(n):
|
||||
# Edge vector
|
||||
edge_x = hull[(i + 1) % n][0] - hull[i][0]
|
||||
edge_y = hull[(i + 1) % n][1] - hull[i][1]
|
||||
edge_len = (edge_x ** 2 + edge_y ** 2) ** 0.5
|
||||
if edge_len == 0:
|
||||
continue
|
||||
|
||||
# Unit vectors along and perpendicular to the edge
|
||||
ux, uy = edge_x / edge_len, edge_y / edge_len
|
||||
vx, vy = -uy, ux
|
||||
|
||||
# Project all hull points onto the edge coordinate system
|
||||
projs_u = [p[0] * ux + p[1] * uy for p in hull]
|
||||
projs_v = [p[0] * vx + p[1] * vy for p in hull]
|
||||
|
||||
min_u, max_u = min(projs_u), max(projs_u)
|
||||
min_v, max_v = min(projs_v), max(projs_v)
|
||||
|
||||
area = (max_u - min_u) * (max_v - min_v)
|
||||
if area < min_area:
|
||||
min_area = area
|
||||
# Reconstruct 4 corners in original coordinates
|
||||
best_rect = [
|
||||
(min_u * ux + min_v * vx, min_u * uy + min_v * vy),
|
||||
(max_u * ux + min_v * vx, max_u * uy + min_v * vy),
|
||||
(max_u * ux + max_v * vx, max_u * uy + max_v * vy),
|
||||
(min_u * ux + max_v * vx, min_u * uy + max_v * vy),
|
||||
]
|
||||
|
||||
return best_rect if best_rect else hull[:4]
|
||||
|
||||
|
||||
def _cross(o: tuple, a: tuple, b: tuple) -> float:
|
||||
return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])
|
||||
|
||||
|
||||
def _dist_sq(a: tuple, b: tuple) -> float:
|
||||
return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2
|
||||
97
ocr-service/engines/surya.py
Normal file
97
ocr-service/engines/surya.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Surya OCR engine wrapper — transformer-based, handles typewritten and modern Latin handwriting."""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Predictors are populated lazily by load_models(); None until first use.
_recognition_predictor = None
_detection_predictor = None
# Guard flag so the expensive load happens at most once.
_loaded = False


def load_models():
    """Lazy-load Surya models on first use to save RAM at idle.

    Called automatically by extract_blocks(). Can also be called explicitly
    to pre-warm if desired. Idempotent: subsequent calls return immediately.
    """
    global _recognition_predictor, _detection_predictor, _loaded
    if _loaded:
        return

    logger.info("Loading Surya models (lazy, first OCR request)...")

    # Heavy imports deferred so merely importing this module stays cheap.
    from surya.foundation import FoundationPredictor
    from surya.recognition import RecognitionPredictor
    from surya.detection import DetectionPredictor

    foundation_predictor = FoundationPredictor()
    _recognition_predictor = RecognitionPredictor(foundation_predictor)
    _detection_predictor = DetectionPredictor()
    _loaded = True

    logger.info("Surya models loaded successfully")
|
||||
|
||||
|
||||
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
    """Run Surya OCR on a single PIL image and return block dicts for that page.

    Coordinates are normalized to [0, 1].

    Args:
        image: PIL image of one page.
        page_idx: 1-based page number stamped into each block.
        language: language hint; NOTE(review): currently not forwarded to the
            Surya predictors — confirm whether surya-ocr 0.17 accepts one.
    """
    load_models()

    page_w, page_h = image.size
    blocks = []

    predictions = _recognition_predictor([image], det_predictor=_detection_predictor)
    page_pred = predictions[0]

    for line in page_pred.text_lines:
        bbox = line.bbox
        x1, y1, x2, y2 = bbox

        # Keep the 4-point outline when Surya provides one; otherwise None and
        # consumers fall back to the axis-aligned bbox.
        polygon = None
        if hasattr(line, "polygon") and line.polygon and len(line.polygon) == 4:
            polygon = [
                [p[0] / page_w, p[1] / page_h]
                for p in line.polygon
            ]

        # Word-level confidences when available; otherwise treat the whole
        # line as one word (confidence defaults to 1.0 if absent).
        words = []
        if hasattr(line, "words") and line.words:
            for word in line.words:
                words.append({
                    "text": word.text,
                    "confidence": word.confidence,
                })
        else:
            words = [{"text": line.text, "confidence": getattr(line, "confidence", 1.0)}]

        blocks.append({
            "pageNumber": page_idx,
            "x": x1 / page_w,
            "y": y1 / page_h,
            "width": (x2 - x1) / page_w,
            "height": (y2 - y1) / page_h,
            "polygon": polygon,
            "text": line.text,
            "words": words,
        })

    return blocks
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
    """Run Surya OCR on a list of PIL images (one per page).

    Processes one page at a time (pages numbered from 1) and returns a flat
    list of block dicts with pageNumber, x, y, width, height, polygon, text,
    words. Coordinates are normalized to [0, 1].

    Note: the caller's `images` list keeps every page alive for the duration
    of this call; true per-page memory release requires the caller to feed
    pages individually through extract_page_blocks().
    """
    all_blocks = []

    for page_idx, image in enumerate(images, start=1):
        # (A previous `del image` here was a no-op: it only unbound the loop
        # variable while `images` still referenced the page object.)
        all_blocks.extend(extract_page_blocks(image, page_idx, language))

    return all_blocks
|
||||
192
ocr-service/main.py
Normal file
192
ocr-service/main.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""OCR microservice — FastAPI app with Surya and Kraken engine support."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import pypdfium2 as pdfium
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from PIL import Image
|
||||
|
||||
from confidence import apply_confidence_markers, get_threshold
|
||||
from engines import kraken as kraken_engine
|
||||
from engines import surya as surya_engine
|
||||
from models import OcrBlock, OcrRequest
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Flipped to True by the lifespan hook once startup model loading completes.
_models_ready = False

# SSRF allow-list: hosts we will download PDFs from. Comma-separated env
# override; defaults cover docker-compose ("minio") and local dev.
ALLOWED_PDF_HOSTS = set(
    h.strip() for h in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
)


def _validate_url(url: str) -> None:
    """Validate that the PDF URL points to an allowed host (SSRF protection).

    Rejects non-HTTP(S) schemes (file://, ftp://, ...) as well as any
    hostname not present in ALLOWED_PDF_HOSTS.

    Raises:
        HTTPException: 400 when the scheme or host is not allowed.
    """
    parsed = urlparse(url)
    # Scheme check closes the file:// / ftp:// hole a host-only check leaves open.
    if parsed.scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail=f"PDF URL scheme not allowed: {parsed.scheme}")
    hostname = parsed.hostname or ""
    if hostname not in ALLOWED_PDF_HOSTS:
        raise HTTPException(status_code=400, detail=f"PDF host not allowed: {hostname}")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load lightweight models at startup. Surya loads lazily on first request."""
    global _models_ready

    logger.info("Loading Kraken model at startup (Surya loads lazily on first OCR request)...")
    kraken_engine.load_models()
    # /health reports 503 until this flips to True.
    _models_ready = True
    logger.info("Startup complete — ready to accept requests")

    yield

    logger.info("Shutting down OCR service")
|
||||
|
||||
|
||||
# Single FastAPI app; the lifespan hook above handles model loading/teardown.
app = FastAPI(title="Familienarchiv OCR Service", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
def health():
    """Health endpoint — returns 200 only after models are loaded.

    "surya" is unconditionally True because Surya loads lazily on first use;
    "kraken" reflects whether the Kurrent model file was found at startup.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")
    return {"status": "ok", "surya": True, "kraken": kraken_engine.is_available()}
|
||||
|
||||
|
||||
@app.post("/ocr", response_model=list[OcrBlock])
async def run_ocr(request: OcrRequest):
    """Run OCR on a PDF document.

    Downloads the PDF from the provided URL, converts pages to images,
    and runs the appropriate OCR engine based on scriptType.
    OCR engines run in a thread pool so the event loop stays free for /health.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    images = await _download_and_convert_pdf(request.pdfUrl)

    script_type = request.scriptType.upper()

    if script_type == "HANDWRITING_KURRENT":
        # Kurrent requires the optional Kraken HTR model; fail fast if absent.
        if not kraken_engine.is_available():
            raise HTTPException(
                status_code=400,
                detail="Kraken model not available — cannot process Kurrent script",
            )
        blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
    else:
        # TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
        blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)

    # Post-process: low-confidence words become [unleserlich]; the word-level
    # data is internal and stripped before building the response model.
    threshold = get_threshold(script_type)
    for block in blocks:
        if block.get("words"):
            block["text"] = apply_confidence_markers(block["words"], threshold)
        block.pop("words", None)

    return [OcrBlock(**b) for b in blocks]
|
||||
|
||||
|
||||
@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
    """Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.

    Event lines: a "start" line with totalPages, then one "page" or "error"
    line per page, then a final "done" line with totals. A failed page is
    skipped (counted in skippedPages) instead of aborting the document.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    images = await _download_and_convert_pdf(request.pdfUrl)
    script_type = request.scriptType.upper()
    threshold = get_threshold(script_type)

    use_kraken = script_type == "HANDWRITING_KURRENT"
    if use_kraken and not kraken_engine.is_available():
        raise HTTPException(
            status_code=400,
            detail="Kraken model not available — cannot process Kurrent script",
        )

    async def generate():
        total_pages = len(images)
        yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"

        total_blocks = 0
        skipped_pages = 0

        for page_idx, image in enumerate(images, start=1):
            try:
                engine = kraken_engine if use_kraken else surya_engine
                # Blocking OCR work runs off the event loop so /health stays
                # responsive during long documents.
                blocks = await asyncio.to_thread(
                    engine.extract_page_blocks, image, page_idx, request.language
                )

                # Same [unleserlich] post-processing as the non-streaming route.
                for block in blocks:
                    if block.get("words"):
                        block["text"] = apply_confidence_markers(block["words"], threshold)
                    block.pop("words", None)

                total_blocks += len(blocks)
                yield json.dumps({
                    "type": "page",
                    "pageNumber": page_idx,
                    "blocks": blocks,
                }) + "\n"

            except Exception:
                # Best-effort per page: report the failure and continue.
                logger.exception("OCR failed on page %d", page_idx)
                skipped_pages += 1
                yield json.dumps({
                    "type": "error",
                    "pageNumber": page_idx,
                    "message": f"OCR processing failed on page {page_idx}",
                }) + "\n"

            finally:
                # NOTE(review): this only unbinds the loop variable; `images`
                # still references the page, so no memory is actually freed here.
                del image

        yield json.dumps({
            "type": "done",
            "totalBlocks": total_blocks,
            "skippedPages": skipped_pages,
        }) + "\n"

    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={
            # Disable proxy buffering so page events reach the client as produced.
            "X-Accel-Buffering": "no",
            "Cache-Control": "no-cache",
        },
    )
|
||||
|
||||
|
||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
    """Download a PDF from a presigned URL and convert each page to a PIL Image.

    Args:
        url: presigned PDF URL; must pass _validate_url's allow-list.

    Returns:
        One PIL image per page, rendered at 200 DPI.

    Raises:
        HTTPException: when the URL fails validation.
        httpx.HTTPStatusError: when the download returns an error status.
    """
    _validate_url(url)
    # follow_redirects=False: a redirect could bounce the request to a host
    # outside the allow-list, defeating the SSRF check above.
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(300.0), follow_redirects=False
    ) as client:
        response = await client.get(url)
        response.raise_for_status()

    pdf = pdfium.PdfDocument(io.BytesIO(response.content))
    try:
        images = []
        for page_idx in range(len(pdf)):
            page = pdf[page_idx]
            # Render at 200 DPI — balances OCR quality vs memory usage
            # (Surya 0.17 models use ~5GB idle; 300 DPI causes OOM on multi-page docs)
            bitmap = page.render(scale=200 / 72)
            images.append(bitmap.to_pil())
        return images
    finally:
        # PdfDocument wraps native PDFium resources; close deterministically
        # instead of relying on GC finalizers. The PIL images are independent
        # copies and remain valid after close.
        pdf.close()
|
||||
21
ocr-service/models.py
Normal file
21
ocr-service/models.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class OcrRequest(BaseModel):
    """Request payload for the /ocr and /ocr/stream endpoints."""

    # Field names are camelCase to match the calling service's JSON;
    # populate_by_name allows construction by field name as well.
    model_config = ConfigDict(populate_by_name=True)

    # Presigned URL of the PDF to process.
    pdfUrl: str
    # Script hint selecting the engine; "HANDWRITING_KURRENT" routes to Kraken,
    # everything else to Surya.
    scriptType: str = "UNKNOWN"
    # Language hint for the OCR engines.
    language: str = "de"
|
||||
|
||||
|
||||
class OcrBlock(BaseModel):
    """One recognized text line/block in the OCR response."""

    model_config = ConfigDict(populate_by_name=True)

    # 1-based page number the block belongs to.
    pageNumber: int
    # Axis-aligned bounding box, normalized to [0, 1] of the page dimensions.
    x: float
    y: float
    width: float
    height: float
    # Optional 4-point quadrilateral outline, normalized like the bbox.
    polygon: list[list[float]] | None = None
    # Recognized line text.
    text: str
||||
10
ocr-service/requirements.txt
Normal file
10
ocr-service/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
fastapi[standard]==0.115.6
|
||||
surya-ocr==0.17.1
|
||||
kraken==7.0
|
||||
torch==2.7.1
|
||||
torchvision==0.22.1
|
||||
transformers>=4.56.1,<5.0.0
|
||||
pillow>=10.2.0,<11.0.0
|
||||
pypdfium2==4.30.0
|
||||
pyvips>=2.2.0
|
||||
httpx==0.28.1
|
||||
160
ocr-service/test_confidence.py
Normal file
160
ocr-service/test_confidence.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Tests for confidence-based [unleserlich] marker insertion."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from confidence import apply_confidence_markers, words_from_characters, get_threshold
|
||||
|
||||
|
||||
# ─── apply_confidence_markers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
# NOTE(review): the threshold-value tests below assume the OCR_CONFIDENCE_*
# environment variables are unset, since the module constants are read at
# import time.


def test_all_words_above_threshold_passes_through():
    words = [
        {"text": "Lieber", "confidence": 0.95},
        {"text": "Freund", "confidence": 0.88},
    ]
    assert apply_confidence_markers(words) == "Lieber Freund"


def test_single_low_confidence_word_replaced():
    words = [
        {"text": "Lieber", "confidence": 0.95},
        {"text": "xkqz", "confidence": 0.1},
        {"text": "Freund", "confidence": 0.88},
    ]
    assert apply_confidence_markers(words) == "Lieber [unleserlich] Freund"


def test_adjacent_low_confidence_words_collapsed():
    # Three consecutive low-confidence words must yield a single marker.
    words = [
        {"text": "Lieber", "confidence": 0.95},
        {"text": "xkqz", "confidence": 0.1},
        {"text": "abc", "confidence": 0.05},
        {"text": "yyy", "confidence": 0.2},
        {"text": "Freund", "confidence": 0.88},
    ]
    assert apply_confidence_markers(words) == "Lieber [unleserlich] Freund"


def test_mixed_high_low_each_group_gets_marker():
    # Non-adjacent low-confidence runs each get their own marker.
    words = [
        {"text": "Lieber", "confidence": 0.95},
        {"text": "xkqz", "confidence": 0.1},
        {"text": "wie", "confidence": 0.9},
        {"text": "abc", "confidence": 0.05},
        {"text": "dir", "confidence": 0.88},
    ]
    assert apply_confidence_markers(words) == "Lieber [unleserlich] wie [unleserlich] dir"


def test_all_below_threshold_returns_single_marker():
    words = [
        {"text": "xkq", "confidence": 0.1},
        {"text": "zzz", "confidence": 0.05},
    ]
    assert apply_confidence_markers(words) == "[unleserlich]"


def test_empty_list_returns_empty_string():
    assert apply_confidence_markers([]) == ""


def test_single_word_above_threshold():
    words = [{"text": "Hallo", "confidence": 0.9}]
    assert apply_confidence_markers(words) == "Hallo"


def test_exact_threshold_passes_through():
    """Confidence exactly at threshold should NOT be replaced (strict <)."""
    words = [{"text": "Wort", "confidence": 0.3}]
    assert apply_confidence_markers(words) == "Wort"


def test_just_below_threshold_replaced():
    words = [{"text": "Wort", "confidence": 0.29}]
    assert apply_confidence_markers(words) == "[unleserlich]"


def test_custom_threshold_via_parameter():
    # An explicit threshold overrides the module default.
    words = [
        {"text": "Lieber", "confidence": 0.95},
        {"text": "Freund", "confidence": 0.5},
    ]
    assert apply_confidence_markers(words, threshold=0.8) == "Lieber [unleserlich]"
    assert apply_confidence_markers(words, threshold=0.3) == "Lieber Freund"


def test_kurrent_threshold_is_higher_than_default():
    default = get_threshold("TYPEWRITER")
    kurrent = get_threshold("HANDWRITING_KURRENT")
    assert kurrent > default


def test_get_threshold_kurrent():
    assert get_threshold("HANDWRITING_KURRENT") == 0.5


def test_get_threshold_default():
    # Everything except Kurrent shares the default threshold.
    assert get_threshold("TYPEWRITER") == 0.3
    assert get_threshold("HANDWRITING_LATIN") == 0.3
    assert get_threshold("UNKNOWN") == 0.3


def test_low_confidence_at_start():
    words = [
        {"text": "xkq", "confidence": 0.1},
        {"text": "Freund", "confidence": 0.88},
    ]
    assert apply_confidence_markers(words) == "[unleserlich] Freund"


def test_low_confidence_at_end():
    words = [
        {"text": "Lieber", "confidence": 0.95},
        {"text": "xkq", "confidence": 0.1},
    ]
    assert apply_confidence_markers(words) == "Lieber [unleserlich]"
|
||||
|
||||
|
||||
# ─── words_from_characters ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_single_word_matching_confidences():
|
||||
words = words_from_characters("Hallo", [0.9, 0.8, 0.85, 0.7, 0.95])
|
||||
assert len(words) == 1
|
||||
assert words[0]["text"] == "Hallo"
|
||||
assert abs(words[0]["confidence"] - 0.84) < 0.01
|
||||
|
||||
|
||||
def test_multi_word_with_spaces():
|
||||
prediction = "Sehr geehrter"
|
||||
confidences = [0.9, 0.8, 0.7, 0.6, 0.5, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]
|
||||
words = words_from_characters(prediction, confidences)
|
||||
assert len(words) == 2
|
||||
assert words[0]["text"] == "Sehr"
|
||||
assert words[1]["text"] == "geehrter"
|
||||
|
||||
|
||||
def test_length_mismatch_falls_back_safely():
|
||||
words = words_from_characters("Hallo Welt", [0.9, 0.8])
|
||||
assert len(words) == 1
|
||||
assert words[0]["text"] == "Hallo Welt"
|
||||
assert words[0]["confidence"] == 1.0
|
||||
|
||||
|
||||
def test_empty_prediction_returns_empty():
|
||||
assert words_from_characters("", []) == []
|
||||
|
||||
|
||||
def test_single_character_word():
    """Single-character words carry exactly their own character confidence."""
    words = words_from_characters("A B", [0.9, 0.5, 0.3])

    assert len(words) == 2
    assert [w["text"] for w in words] == ["A", "B"]
    # The space's 0.5 belongs to neither word.
    assert [w["confidence"] for w in words] == [0.9, 0.3]
|
||||
|
||||
|
||||
def test_whitespace_only_prediction():
    """A prediction of only spaces produces no words."""
    assert words_from_characters("   ", [0.5, 0.5, 0.5]) == []
|
||||
178
ocr-service/test_engines.py
Normal file
178
ocr-service/test_engines.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Tests for per-page block extraction in OCR engines."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from PIL import Image
|
||||
|
||||
from engines import surya, kraken
|
||||
|
||||
|
||||
# ─── Surya extract_page_blocks ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_surya_line(text, bbox, polygon=None, words=None):
|
||||
"""Create a mock Surya text line with the expected attributes."""
|
||||
line = MagicMock()
|
||||
line.text = text
|
||||
line.bbox = bbox
|
||||
line.polygon = polygon
|
||||
line.words = words or []
|
||||
if not words:
|
||||
line.confidence = 0.95
|
||||
# hasattr check needs words to be falsy
|
||||
del line.words
|
||||
return line
|
||||
|
||||
|
||||
def test_surya_extract_page_blocks_returns_blocks_for_single_image():
    """Block coordinates are normalized against the image dimensions."""
    image = Image.new("RGB", (100, 200))
    pred = MagicMock()
    pred.text_lines = [_make_surya_line("Hallo Welt", [10, 20, 90, 40])]

    with patch.object(surya, "_recognition_predictor") as rec, \
            patch.object(surya, "_loaded", True):
        rec.return_value = [pred]
        blocks = surya.extract_page_blocks(image, page_idx=2, language="de")

    assert len(blocks) == 1
    block = blocks[0]
    assert block["pageNumber"] == 2
    assert block["text"] == "Hallo Welt"
    assert (block["x"], block["y"]) == (10 / 100, 20 / 200)
    assert (block["width"], block["height"]) == (80 / 100, 20 / 200)
|
||||
|
||||
|
||||
def test_surya_extract_page_blocks_extracts_polygon_when_present():
    """A line polygon is normalized point-by-point into the block."""
    image = Image.new("RGB", (100, 200))
    line = _make_surya_line("Text", [10, 20, 90, 40])
    line.polygon = [(10, 20), (90, 20), (90, 40), (10, 40)]
    pred = MagicMock()
    pred.text_lines = [line]

    with patch.object(surya, "_recognition_predictor") as rec, \
            patch.object(surya, "_loaded", True):
        rec.return_value = [pred]
        blocks = surya.extract_page_blocks(image, page_idx=1, language="de")

    polygon = blocks[0]["polygon"]
    assert polygon is not None
    assert len(polygon) == 4
    assert polygon[0] == [10 / 100, 20 / 200]
|
||||
|
||||
|
||||
def test_surya_extract_page_blocks_extracts_word_level_confidence():
    """Per-word confidences on a line are carried into the block."""
    image = Image.new("RGB", (100, 200))

    mock_words = []
    for text, confidence in (("Hallo", 0.95), ("Welt", 0.3)):
        word = MagicMock()
        word.text = text
        word.confidence = confidence
        mock_words.append(word)

    line = _make_surya_line("Hallo Welt", [10, 20, 90, 40], words=mock_words)
    pred = MagicMock()
    pred.text_lines = [line]

    with patch.object(surya, "_recognition_predictor") as rec, \
            patch.object(surya, "_loaded", True):
        rec.return_value = [pred]
        blocks = surya.extract_page_blocks(image, page_idx=1, language="de")

    extracted = blocks[0]["words"]
    assert len(extracted) == 2
    assert extracted[0]["text"] == "Hallo"
    assert extracted[0]["confidence"] == 0.95
|
||||
|
||||
|
||||
def test_surya_extract_blocks_delegates_to_extract_page_blocks():
    """After refactoring, extract_blocks over N images numbers pages 1..N."""
    images = [Image.new("RGB", (100, 200)) for _ in range(2)]
    pred = MagicMock()
    pred.text_lines = [_make_surya_line("Test", [10, 20, 90, 40])]

    with patch.object(surya, "_recognition_predictor") as rec, \
            patch.object(surya, "_loaded", True):
        rec.return_value = [pred]
        blocks = surya.extract_blocks(images)

    assert len(blocks) == 2
    assert [b["pageNumber"] for b in blocks] == [1, 2]
|
||||
|
||||
|
||||
# ─── Kraken extract_page_blocks ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_kraken_record(text, cuts, confidences=None):
|
||||
record = MagicMock()
|
||||
record.prediction = text
|
||||
record.cuts = cuts
|
||||
record.line = cuts
|
||||
record.confidences = confidences or [0.9] * len(text)
|
||||
return record
|
||||
|
||||
|
||||
def _run_kraken_with_mocks(fn, *args):
    """Run a kraken engine function with blla/rpred mocked via sys.modules.

    Installs mock kraken modules, runs `fn(*args)` with a mock model patched
    in, and restores the previous sys.modules entries afterwards.
    """
    import sys

    mock_blla = MagicMock()
    mock_blla.segment.return_value = MagicMock()
    mock_rpred = MagicMock()
    mock_rpred.rpred.return_value = [
        _make_kraken_record("Kurrent", [(10, 20), (90, 20), (90, 40), (10, 40)])
    ]

    replacements = {
        "kraken": MagicMock(blla=mock_blla, rpred=mock_rpred),
        "kraken.blla": mock_blla,
        "kraken.rpred": mock_rpred,
    }
    # Snapshot whatever was registered before so we can restore it exactly.
    saved = {name: sys.modules.get(name) for name in replacements}
    sys.modules.update(replacements)

    try:
        with patch.object(kraken, "_model", MagicMock()):
            return fn(*args)
    finally:
        for name, module in saved.items():
            if module is None:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = module
|
||||
|
||||
|
||||
def test_kraken_extract_page_blocks_returns_blocks_for_single_image():
    """extract_page_blocks tags its block with the requested page number."""
    image = Image.new("RGB", (100, 200))
    blocks = _run_kraken_with_mocks(kraken.extract_page_blocks, image, 3, "de")

    assert len(blocks) == 1
    first = blocks[0]
    assert first["pageNumber"] == 3
    assert first["text"] == "Kurrent"
|
||||
|
||||
|
||||
def test_kraken_extract_blocks_delegates_to_extract_page_blocks():
    """extract_blocks numbers pages 1..N across multiple images."""
    images = [Image.new("RGB", (100, 200)), Image.new("RGB", (100, 200))]
    blocks = _run_kraken_with_mocks(kraken.extract_blocks, images)

    assert len(blocks) == 2
    assert [b["pageNumber"] for b in blocks] == [1, 2]
|
||||
303
ocr-service/test_stream.py
Normal file
303
ocr-service/test_stream.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Tests for the NDJSON streaming OCR endpoint POST /ocr/stream."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
def mock_images():
    """Provide three blank 100x200 PIL images standing in for PDF pages."""
    from PIL import Image
    return [Image.new("RGB", (100, 200)) for _ in range(3)]
|
||||
|
||||
|
||||
def _make_block(page_idx, text="Test"):
|
||||
return {
|
||||
"pageNumber": page_idx,
|
||||
"x": 0.1, "y": 0.2, "width": 0.8, "height": 0.1,
|
||||
"polygon": None, "text": text,
|
||||
"words": [{"text": text, "confidence": 0.95}],
|
||||
}
|
||||
|
||||
|
||||
# ─── P3: start event with total pages ────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_emits_start_event_with_total_pages(mock_images):
    """The first NDJSON event is `start` and carries the total page count."""
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as engine:
        engine.extract_page_blocks.return_value = [_make_block(0)]

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
                "language": "de",
            })

    events = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    assert events[0] == {"type": "start", "totalPages": 3}
|
||||
|
||||
|
||||
# ─── P4: page events per completed page ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_emits_page_event_per_page_with_blocks(mock_images):
    """Each processed page produces one `page` event carrying its blocks."""
    def page_blocks(image, page_idx, language="de"):
        return [_make_block(page_idx, f"Page {page_idx}")]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as engine:
        engine.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    events = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    pages = [e for e in events if e["type"] == "page"]
    assert len(pages) == 3
    assert [p["pageNumber"] for p in pages] == [1, 2, 3]
    assert len(pages[0]["blocks"]) == 1
|
||||
|
||||
|
||||
# ─── P5: done event with total blocks and skipped ────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_emits_done_with_total_blocks(mock_images):
    """The terminal `done` event sums blocks across pages and counts skips."""
    def page_blocks(image, page_idx, language="de"):
        # Two blocks per page; 3 pages -> 6 blocks total.
        return [_make_block(page_idx)] * 2

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as engine:
        engine.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    events = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    done = next(e for e in events if e["type"] == "done")
    assert done["totalBlocks"] == 6
    assert done["skippedPages"] == 0
|
||||
|
||||
|
||||
# ─── P6: error event on page failure, continues ──────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_emits_error_event_on_page_failure_and_continues(mock_images):
    """A failing page yields an `error` event; remaining pages still process."""
    def page_blocks(image, page_idx, language="de"):
        if page_idx == 2:
            raise RuntimeError("Engine crashed on page 2")
        return [_make_block(page_idx)]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as engine:
        engine.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    events = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    errors = [e for e in events if e["type"] == "error"]
    assert errors
    assert errors[0]["pageNumber"] == 2
    # The client-visible message must stay generic — no raw exception text.
    assert "Engine crashed" not in errors[0]["message"]

    # Processing continued: pages 1 and 3 still produced page events.
    pages = [e for e in events if e["type"] == "page"]
    assert len(pages) == 2

    done = next(e for e in events if e["type"] == "done")
    assert done["skippedPages"] == 1
|
||||
|
||||
|
||||
# ─── P7: old /ocr endpoint still works ───────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_old_ocr_endpoint_still_returns_flat_list(mock_images):
    """The legacy POST /ocr endpoint keeps returning a flat JSON block list."""
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as engine:
        engine.extract_blocks.return_value = [_make_block(1), _make_block(2)]

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    assert response.status_code == 200
    payload = response.json()
    assert isinstance(payload, list)
    assert len(payload) == 2
    assert payload[0]["pageNumber"] == 1
||||
|
||||
|
||||
# ─── Health and error handling ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_returns_ok_when_models_ready():
    """/health reports ok plus per-engine availability once models load."""
    with patch("main._models_ready", True), \
            patch("main.kraken_engine") as engine:
        engine.is_available.return_value = True

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.get("/health")

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "ok"
    assert payload["surya"] is True
    assert payload["kraken"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_returns_503_when_models_not_ready():
    """/health signals unavailability while models are still loading."""
    with patch("main._models_ready", False):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.get("/health")

    assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_returns_503_when_models_not_ready():
    """/ocr/stream refuses work with 503 while models are still loading."""
    with patch("main._models_ready", False):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_uses_kraken_for_kurrent_script(mock_images):
    """HANDWRITING_KURRENT routes page extraction to the Kraken engine."""
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.kraken_engine") as engine:
        engine.is_available.return_value = True
        engine.extract_page_blocks.return_value = [_make_block(1, "Kurrent text")]

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "HANDWRITING_KURRENT",
            })

    events = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    assert len([e for e in events if e["type"] == "page"]) == 3
    engine.extract_page_blocks.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_returns_400_when_kraken_unavailable_for_kurrent(mock_images):
    """A Kurrent request fails fast with 400 when the Kraken model is missing."""
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.kraken_engine") as engine:
        engine.is_available.return_value = False

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "HANDWRITING_KURRENT",
            })

    assert response.status_code == 400
|
||||
|
||||
|
||||
# ─── SSRF protection ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_rejects_disallowed_host():
    """URLs pointing at hosts outside the allowlist are refused (SSRF guard)."""
    with patch("main._models_ready", True):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://evil.example.com/malicious.pdf",
                "scriptType": "TYPEWRITER",
            })

    assert response.status_code == 400
    assert "not allowed" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_allows_minio_host(mock_images):
    """The internal minio host is on the allowlist and streams successfully."""
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as engine:
        engine.extract_page_blocks.return_value = [_make_block(0)]

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_stream_applies_confidence_markers(mock_images):
    """Low-confidence words become [unleserlich]; raw word data is stripped."""
    def page_blocks(image, page_idx, language="de"):
        return [{
            "pageNumber": page_idx,
            "x": 0.1, "y": 0.2, "width": 0.8, "height": 0.1,
            "polygon": None,
            "text": "original text",
            "words": [
                {"text": "Lieber", "confidence": 0.95},
                {"text": "xkqz", "confidence": 0.1},
            ],
        }]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images[:1]), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as engine:
        engine.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    events = [json.loads(line) for line in response.text.splitlines() if line.strip()]
    block = next(e for e in events if e["type"] == "page")["blocks"][0]
    # The block text is rebuilt from words, not the engine's original text.
    assert block["text"] == "Lieber [unleserlich]"
    assert "words" not in block
|
||||
203
scripts/download-kraken-models.sh
Executable file
203
scripts/download-kraken-models.sh
Executable file
@@ -0,0 +1,203 @@
|
||||
#!/bin/bash
set -euo pipefail

# Downloads Kraken HTR models, runs OCR evaluation on sample PDFs, and activates the winner.
#
# Usage:
#   ./scripts/download-kraken-models.sh                 # download models + run evaluation
#   ./scripts/download-kraken-models.sh --activate 1    # activate model 1
#   ./scripts/download-kraken-models.sh --activate 2    # activate model 2
#   ./scripts/download-kraken-models.sh --eval-only     # re-run evaluation (models already downloaded)

# Docker Compose service that has kraken installed and the models volume mounted.
COMPOSE_SERVICE="ocr-service"
# Model directory *inside* the container (backed by the ocr_models volume).
MODEL_DIR="/app/models"
# Path the OCR service loads at startup; --activate copies the winner here.
ACTIVE_MODEL="$MODEL_DIR/german_kurrent.mlmodel"

# Kraken 7 uses DOIs to identify models from HTR-United / Zenodo
MODEL_1_DOI="10.5281/zenodo.7933463"
MODEL_1_NAME="german_handwriting"
MODEL_1_DESC="HTR model for German manuscripts (handwritten text recognition)"
MODEL_1_PATH="$MODEL_DIR/$MODEL_1_NAME.mlmodel"

MODEL_2_DOI="10.5281/zenodo.13788177"
MODEL_2_NAME="mccatmus"
MODEL_2_DESC="McCATMuS — generic model for handwritten, printed & typewritten (16th c. onward)"
MODEL_2_PATH="$MODEL_DIR/$MODEL_2_NAME.mlmodel"

# Host-side sample documents used to compare the two models.
EVAL_PDFS=("Eu-0693.pdf" "Eu-0692.pdf" "W-0150.pdf" "W-0575.pdf")
IMPORT_DIR="./import"
OUTPUT_BASE="./ocr-model-evaluation"
|
||||
|
||||
# ─── Functions ────────────────────────────────────────────────────────────────
|
||||
|
||||
download_model() {
    # Download one model by DOI into the models volume.
    #   $1 = Zenodo DOI, $2 = destination path (in-container), $3 = display name
    local doi="$1"
    local dest="$2"
    local name="$3"

    echo " Downloading $name ($doi)..."

    # kraken get downloads to /root/.local/share/htrmopo/<uuid>/<name>.mlmodel
    # Parse the "Model dir: <path>" line from kraken output to locate the file
    # NOTE(review): the '"$doi"' / '"$dest"' sequences close the single-quoted
    # script, splice in the outer-shell variable, and reopen the quote — the
    # container shell only ever sees the literal values.
    docker compose run --rm "$COMPOSE_SERVICE" sh -c '
        OUTPUT=$(kraken get '"$doi"' 2>&1)
        echo "$OUTPUT"
        MODEL_DIR=$(echo "$OUTPUT" | grep -oP "Model dir: \K[^ ]+")
        if [ -n "$MODEL_DIR" ] && [ -d "$MODEL_DIR" ]; then
            FOUND=$(find "$MODEL_DIR" -name "*.mlmodel" | head -1)
            if [ -n "$FOUND" ]; then
                cp "$FOUND" '"$dest"'
                echo "Saved to '"$dest"'"
            else
                echo "ERROR: No .mlmodel file in $MODEL_DIR"
                ls -la "$MODEL_DIR"
                exit 1
            fi
        else
            echo "ERROR: Could not parse model directory from kraken output"
            exit 1
        fi
    '
}
|
||||
|
||||
download_models() {
    # Fetch both candidate models into the shared ocr_models volume.
    echo "Downloading Kraken HTR models into the ocr_models volume..."
    echo ""

    local n name_var desc_var doi_var path_var
    for n in 1 2; do
        # Resolve MODEL_<n>_* via bash variable indirection.
        name_var="MODEL_${n}_NAME"
        desc_var="MODEL_${n}_DESC"
        doi_var="MODEL_${n}_DOI"
        path_var="MODEL_${n}_PATH"

        echo "Model $n: ${!name_var}"
        echo " ${!desc_var}"
        download_model "${!doi_var}" "${!path_var}" "${!name_var}"
        echo ""
    done

    echo "Both models downloaded."
}
|
||||
|
||||
run_evaluation() {
    # Render each sample PDF to page images and OCR every page with both
    # models, writing one concatenated text file per (model, document) pair.
    echo ""
    echo "═══════════════════════════════════════════════════════"
    echo " Running OCR evaluation on ${#EVAL_PDFS[@]} documents"
    echo "═══════════════════════════════════════════════════════"
    echo ""

    local dir_1="$OUTPUT_BASE/$MODEL_1_NAME"
    local dir_2="$OUTPUT_BASE/$MODEL_2_NAME"
    mkdir -p "$dir_1" "$dir_2"

    for pdf in "${EVAL_PDFS[@]}"; do
        local src="$IMPORT_DIR/$pdf"
        # Output file stem, e.g. "Eu-0693" for "Eu-0693.pdf".
        local basename="${pdf%.pdf}"

        if [[ ! -f "$src" ]]; then
            echo "SKIP: $src not found"
            continue
        fi

        echo "──── $pdf ────"

        # Run both models inside a single container run:
        #   1. Extract PDF pages as PNGs (pypdfium2 is already installed)
        #   2. Run kraken on each page image for both models
        #   3. Concatenate per-page output into one file per model
        # NOTE(review): the python3 -c payload is double-quoted, so $pdf,
        # $MODEL_*_PATH and ${basename} are expanded by the *outer* shell
        # before Python ever runs.
        docker compose run --rm \
            -v "$(cd "$IMPORT_DIR" && pwd):/eval-input:ro" \
            -v "$(cd "$dir_1" && pwd):/eval-out-1" \
            -v "$(cd "$dir_2" && pwd):/eval-out-2" \
            "$COMPOSE_SERVICE" \
            python3 -c "
import pypdfium2 as pdfium, subprocess, sys, os

pdf = pdfium.PdfDocument('/eval-input/$pdf')
pages = []
for i in range(len(pdf)):
    bmp = pdf[i].render(scale=300/72)
    path = f'/tmp/page_{i:04d}.png'
    bmp.to_pil().save(path)
    pages.append(path)
print(f'Extracted {len(pages)} pages')

for label, model, outdir in [
    ('Model 1', '$MODEL_1_PATH', '/eval-out-1'),
    ('Model 2', '$MODEL_2_PATH', '/eval-out-2'),
]:
    print(f' {label}...')
    combined = ''
    for p in pages:
        args = ['kraken', '-i', p, '/dev/stdout', 'segment', '-bl', 'ocr', '-m', model]
        r = subprocess.run(args, capture_output=True, text=True)
        combined += r.stdout
        if r.returncode != 0:
            print(f' ⚠ failed on {os.path.basename(p)}: {r.stderr[:200]}', file=sys.stderr)
    with open(f'{outdir}/${basename}.txt', 'w') as f:
        f.write(combined)
    lines = combined.count(chr(10))
    print(f' → {lines} lines')
" || echo " ⚠ Failed on $pdf"

        echo ""
    done

    echo "═══════════════════════════════════════════════════════"
    echo " Evaluation complete. Results:"
    echo ""
    echo " Model 1 — $MODEL_1_NAME ($MODEL_1_DOI):"
    for f in "$dir_1"/*.txt; do
        [[ -f "$f" ]] && echo "   $(basename "$f"): $(wc -l < "$f") lines, $(wc -c < "$f") bytes"
    done
    echo ""
    echo " Model 2 — $MODEL_2_NAME ($MODEL_2_DOI):"
    for f in "$dir_2"/*.txt; do
        [[ -f "$f" ]] && echo "   $(basename "$f"): $(wc -l < "$f") lines, $(wc -c < "$f") bytes"
    done
    echo ""
    echo " Compare outputs:"
    echo "   diff $dir_1/Eu-0693.txt $dir_2/Eu-0693.txt"
    echo "   # or view individually:"
    echo "   cat $dir_1/Eu-0693.txt"
    echo ""
    echo " Activate the better model:"
    echo "   ./scripts/download-kraken-models.sh --activate 1   # $MODEL_1_NAME"
    echo "   ./scripts/download-kraken-models.sh --activate 2   # $MODEL_2_NAME"
    echo "═══════════════════════════════════════════════════════"
}
|
||||
|
||||
activate_model() {
    # Copy the chosen candidate over the path the OCR service loads.
    #   $1 = "1" or "2"
    local choice="$1"
    local src name

    case "$choice" in
        1) src="$MODEL_1_PATH"; name="$MODEL_1_NAME" ;;
        2) src="$MODEL_2_PATH"; name="$MODEL_2_NAME" ;;
        *)
            echo "Error: --activate expects 1 or 2"
            exit 1
            ;;
    esac

    echo "Activating model $choice: $name"
    docker compose run --rm "$COMPOSE_SERVICE" cp "$src" "$ACTIVE_MODEL"

    echo "Active model is now: $ACTIVE_MODEL"
    echo "Restart the OCR service to load the new model:"
    echo "  docker compose restart ocr-service"
}
|
||||
|
||||
# ─── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
# ─── Main ─────────────────────────────────────────────────────────────────────

# Dispatch on the first CLI argument; no argument runs the full pipeline.
case "${1:-}" in
    --activate)
        activate_model "${2:-}"
        ;;
    --eval-only)
        run_evaluation
        ;;
    *)
        download_models
        run_evaluation
        ;;
esac
|
||||
Reference in New Issue
Block a user