diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java new file mode 100644 index 00000000..bd1e41f9 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -0,0 +1,114 @@ +package org.raddatz.familienarchiv.controller; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.raddatz.familienarchiv.dto.BatchOcrDTO; +import org.raddatz.familienarchiv.dto.OcrStatusDTO; +import org.raddatz.familienarchiv.dto.TriggerOcrDTO; +import org.raddatz.familienarchiv.exception.DomainException; +import org.raddatz.familienarchiv.exception.ErrorCode; +import org.raddatz.familienarchiv.model.*; +import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository; +import org.raddatz.familienarchiv.repository.OcrJobRepository; +import org.raddatz.familienarchiv.security.Permission; +import org.raddatz.familienarchiv.security.RequirePermission; +import org.raddatz.familienarchiv.service.OcrBatchService; +import org.raddatz.familienarchiv.service.OcrProgressService; +import org.raddatz.familienarchiv.service.OcrService; +import org.raddatz.familienarchiv.service.UserService; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.security.core.Authentication; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import jakarta.validation.Valid; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +@RestController +@RequiredArgsConstructor +@Slf4j +public class OcrController { + + private final OcrService ocrService; + private final OcrBatchService ocrBatchService; + private final OcrProgressService ocrProgressService; + private final OcrJobRepository ocrJobRepository; + private final OcrJobDocumentRepository ocrJobDocumentRepository; + private final UserService userService; + + @PostMapping("/api/documents/{documentId}/ocr") + @ResponseStatus(HttpStatus.ACCEPTED) + @RequirePermission(Permission.WRITE_ALL) + public Map triggerOcr( + @PathVariable UUID documentId, + @RequestBody TriggerOcrDTO dto, + Authentication authentication) { + UUID userId = resolveUserId(authentication); + UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId); + return Map.of("jobId", jobId); + } + + @PostMapping("/api/ocr/batch") + @ResponseStatus(HttpStatus.ACCEPTED) + @RequirePermission(Permission.ADMIN) + public Map triggerBatch( + @RequestBody @Valid BatchOcrDTO dto, + Authentication authentication) { + UUID userId = resolveUserId(authentication); + UUID jobId = ocrBatchService.startBatch(dto.getDocumentIds(), userId); + return Map.of("jobId", jobId); + } + + @GetMapping("/api/ocr/jobs/{jobId}") + @RequirePermission(Permission.READ_ALL) + public OcrJob getJobStatus(@PathVariable UUID jobId) { + return ocrJobRepository.findById(jobId) + .orElseThrow(() -> DomainException.notFound( + ErrorCode.OCR_JOB_NOT_FOUND, "OCR job not found: " + jobId)); + } + + @GetMapping(value = "/api/ocr/jobs/{jobId}/progress", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + @RequirePermission(Permission.READ_ALL) + public SseEmitter streamProgress(@PathVariable UUID jobId) { + ocrJobRepository.findById(jobId) + .orElseThrow(() -> DomainException.notFound( + ErrorCode.OCR_JOB_NOT_FOUND, "OCR job not found: " + jobId)); + return ocrProgressService.register(jobId); + } + + @GetMapping("/api/documents/{documentId}/ocr-status") + @RequirePermission(Permission.READ_ALL) + public OcrStatusDTO getDocumentOcrStatus(@PathVariable UUID documentId) { + List activeStatuses = List.of( + OcrDocumentStatus.PENDING, OcrDocumentStatus.RUNNING); + + Optional activeJobDoc = ocrJobDocumentRepository + .findFirstByDocumentIdAndStatusIn(documentId, activeStatuses); + + if (activeJobDoc.isEmpty()) { + return OcrStatusDTO.builder().status("NONE").build(); + } + + OcrJobDocument jobDoc = activeJobDoc.get(); + return OcrStatusDTO.builder() + .status(jobDoc.getStatus().name()) + .jobId(jobDoc.getJobId()) + .currentPage(jobDoc.getCurrentPage()) + .totalPages(jobDoc.getTotalPages()) + .build(); + } + + private UUID resolveUserId(Authentication authentication) { + if (authentication == null || !authentication.isAuthenticated()) return null; + try { + AppUser user = userService.findByUsername(authentication.getName()); + return user != null ? user.getId() : null; + } catch (Exception e) { + return null; + } + } +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrBatchService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrBatchService.java new file mode 100644 index 00000000..52639c36 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrBatchService.java @@ -0,0 +1,114 @@ +package org.raddatz.familienarchiv.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.raddatz.familienarchiv.exception.DomainException; +import org.raddatz.familienarchiv.exception.ErrorCode; +import org.raddatz.familienarchiv.model.*; +import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository; +import org.raddatz.familienarchiv.repository.OcrJobRepository; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Map; +import java.util.UUID; + +@Service +@RequiredArgsConstructor +@Slf4j +public class OcrBatchService { + + private final OcrService ocrService; + private final OcrHealthClient ocrHealthClient; + private final DocumentService documentService; + private final OcrJobRepository ocrJobRepository; + private final OcrJobDocumentRepository ocrJobDocumentRepository; + private final OcrProgressService ocrProgressService; + + public UUID startBatch(List documentIds, UUID userId) { + if (!ocrHealthClient.isHealthy()) { + throw DomainException.internal(ErrorCode.OCR_SERVICE_UNAVAILABLE, + "OCR service is not available"); + } + + OcrJob job = OcrJob.builder() + .totalDocuments(documentIds.size()) + .createdBy(userId) + .status(OcrJobStatus.PENDING) + .build(); + job = ocrJobRepository.save(job); + + for (UUID docId : documentIds) { + OcrJobDocument jobDoc = OcrJobDocument.builder() + .jobId(job.getId()) + .documentId(docId) + .status(OcrDocumentStatus.PENDING) + .build(); + ocrJobDocumentRepository.save(jobDoc); + } + + processBatchAsync(job.getId(), userId); + return job.getId(); + } + + @Async + void processBatchAsync(UUID jobId, UUID userId) { + OcrJob job = ocrJobRepository.findById(jobId).orElse(null); + if (job == null) return; + + job.setStatus(OcrJobStatus.RUNNING); + ocrJobRepository.save(job); + + List jobDocs = ocrJobDocumentRepository.findByJobIdOrderByCreatedAtAsc(jobId); + + for (OcrJobDocument jobDoc : jobDocs) { + Document doc = documentService.getDocumentById(jobDoc.getDocumentId()); + + if (doc.getStatus() == DocumentStatus.PLACEHOLDER) { + jobDoc.setStatus(OcrDocumentStatus.SKIPPED); + ocrJobDocumentRepository.save(jobDoc); + job.setSkippedCount(job.getSkippedCount() + 1); + ocrJobRepository.save(job); + ocrProgressService.emit(jobId, "document", Map.of( + "documentId", jobDoc.getDocumentId(), + "status", "SKIPPED", + "processed", job.getProcessedDocuments(), + "total", job.getTotalDocuments())); + continue; + } + + jobDoc.setStatus(OcrDocumentStatus.RUNNING); + ocrJobDocumentRepository.save(jobDoc); + + try { + ocrService.processDocument(jobDoc.getDocumentId(), doc, userId); + jobDoc.setStatus(OcrDocumentStatus.DONE); + job.setProcessedDocuments(job.getProcessedDocuments() + 1); + } catch (Exception e) { + log.error("OCR batch: failed document {}", jobDoc.getDocumentId(), e); + jobDoc.setStatus(OcrDocumentStatus.FAILED); + jobDoc.setErrorMessage(e.getMessage()); + job.setErrorCount(job.getErrorCount() + 1); + } + + ocrJobDocumentRepository.save(jobDoc); + ocrJobRepository.save(job); + + ocrProgressService.emit(jobId, "document", Map.of( + "documentId", jobDoc.getDocumentId(), + "status", jobDoc.getStatus().name(), + "processed", job.getProcessedDocuments(), + "total", job.getTotalDocuments())); + } + + job.setStatus(OcrJobStatus.DONE); + ocrJobRepository.save(job); + + ocrProgressService.emit(jobId, "done", Map.of( + "processed", job.getProcessedDocuments(), + "errors", job.getErrorCount(), + "skipped", job.getSkippedCount())); + ocrProgressService.complete(jobId); + } +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrProgressService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrProgressService.java new file mode 100644 index 00000000..8b3bc798 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrProgressService.java @@ -0,0 +1,69 @@ +package org.raddatz.familienarchiv.service; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; + +@Service +@Slf4j +public class OcrProgressService { + + private static final long SSE_TIMEOUT = 5 * 60 * 1000L; + + private final ConcurrentHashMap> emitters = new ConcurrentHashMap<>(); + + public SseEmitter register(UUID jobId) { + SseEmitter emitter = new SseEmitter(SSE_TIMEOUT); + emitters.computeIfAbsent(jobId, k -> new CopyOnWriteArrayList<>()).add(emitter); + + emitter.onCompletion(() -> removeEmitter(jobId, emitter)); + emitter.onTimeout(() -> removeEmitter(jobId, emitter)); + emitter.onError(e -> removeEmitter(jobId, emitter)); + + return emitter; + } + + public void emit(UUID jobId, String eventType, Object data) { + List jobEmitters = emitters.get(jobId); + if (jobEmitters == null) return; + + for (SseEmitter emitter : jobEmitters) { + try { + emitter.send(SseEmitter.event().name(eventType).data(data)); + } catch (IOException e) { + log.debug("SSE send failed for job {} — removing emitter", jobId); + removeEmitter(jobId, emitter); + } + } + } + + public void complete(UUID jobId) { + List jobEmitters = emitters.remove(jobId); + if (jobEmitters == null) return; + + for (SseEmitter emitter : jobEmitters) { + try { + emitter.complete(); + } catch (Exception e) { + log.debug("SSE complete failed for job {}", jobId); + } + } + } + + private void removeEmitter(UUID jobId, SseEmitter emitter) { + List jobEmitters = emitters.get(jobId); + if (jobEmitters != null) { + jobEmitters.remove(emitter); + if (jobEmitters.isEmpty()) { + emitters.remove(jobId); + } + } + } +} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrService.java new file mode 100644 index 00000000..5ec7a2f1 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrService.java @@ -0,0 +1,120 @@ +package org.raddatz.familienarchiv.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.raddatz.familienarchiv.dto.CreateAnnotationDTO; +import org.raddatz.familienarchiv.exception.DomainException; +import org.raddatz.familienarchiv.exception.ErrorCode; +import org.raddatz.familienarchiv.model.*; +import org.raddatz.familienarchiv.repository.OcrJobRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.util.List; +import java.util.UUID; + +@Service +@RequiredArgsConstructor +@Slf4j +public class OcrService { + + private static final String OCR_ANNOTATION_COLOR = "#00C7B1"; + + private final OcrClient ocrClient; + private final OcrHealthClient ocrHealthClient; + private final DocumentService documentService; + private final TranscriptionService transcriptionService; + private final AnnotationService annotationService; + private final TranscriptionBlockRepository blockRepository; + private final OcrJobRepository ocrJobRepository; + + @Value("${app.s3.internal-url:http://minio:9000}") + private String s3InternalUrl; + + @Value("${app.s3.bucket}") + private String bucketName; + + @Transactional + public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) { + Document doc = documentService.getDocumentById(documentId); + + if (doc.getStatus() == DocumentStatus.PLACEHOLDER) { + throw DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, + "Document has no file attached: " + documentId); + } + + if (!ocrHealthClient.isHealthy()) { + throw DomainException.internal(ErrorCode.OCR_SERVICE_UNAVAILABLE, + "OCR service is not available"); + } + + if (scriptTypeOverride != null) { + doc.setScriptType(scriptTypeOverride); + } + + OcrJob job = OcrJob.builder() + .totalDocuments(1) + .createdBy(userId) + .status(OcrJobStatus.RUNNING) + .build(); + job = ocrJobRepository.save(job); + + try { + processDocument(documentId, doc, userId); + job.setStatus(OcrJobStatus.DONE); + job.setProcessedDocuments(1); + } catch (Exception e) { + log.error("OCR processing failed for document {}", documentId, e); + job.setStatus(OcrJobStatus.FAILED); + job.setErrorCount(1); + } + + ocrJobRepository.save(job); + return job.getId(); + } + + void processDocument(UUID documentId, Document doc, UUID userId) { + clearExistingBlocks(documentId); + + String pdfUrl = buildInternalUrl(doc.getFilePath()); + List blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType()); + createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash()); + } + + private void clearExistingBlocks(UUID documentId) { + List existing = transcriptionService.listBlocks(documentId); + for (TranscriptionBlock block : existing) { + transcriptionService.deleteBlock(documentId, block.getId()); + } + } + + private void createTranscriptionBlocks(UUID documentId, List blocks, + UUID userId, String fileHash) { + for (int i = 0; i < blocks.size(); i++) { + OcrBlockResult block = blocks.get(i); + + CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO( + block.pageNumber(), block.x(), block.y(), + block.width(), block.height(), OCR_ANNOTATION_COLOR); + + DocumentAnnotation annotation = annotationService.createOcrAnnotation( + documentId, annotationDTO, userId, fileHash, block.polygon()); + + TranscriptionBlock transcriptionBlock = TranscriptionBlock.builder() + .annotationId(annotation.getId()) + .documentId(documentId) + .text(block.text() != null ? block.text() : "") + .sortOrder(i) + .createdBy(userId) + .updatedBy(userId) + .build(); + blockRepository.save(transcriptionBlock); + } + } + + String buildInternalUrl(String filePath) { + return s3InternalUrl + "/" + bucketName + "/" + filePath; + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java new file mode 100644 index 00000000..aef427ef --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -0,0 +1,138 @@ +package org.raddatz.familienarchiv.controller; + +import tools.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.raddatz.familienarchiv.config.SecurityConfig; +import org.raddatz.familienarchiv.dto.BatchOcrDTO; +import org.raddatz.familienarchiv.dto.TriggerOcrDTO; +import org.raddatz.familienarchiv.exception.DomainException; +import org.raddatz.familienarchiv.exception.ErrorCode; +import org.raddatz.familienarchiv.model.*; +import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository; +import org.raddatz.familienarchiv.repository.OcrJobRepository; +import org.raddatz.familienarchiv.security.PermissionAspect; +import org.raddatz.familienarchiv.service.*; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.aop.AopAutoConfiguration; +import org.springframework.boot.webmvc.test.autoconfigure.WebMvcTest; +import org.springframework.context.annotation.Import; +import org.springframework.http.MediaType; +import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.test.context.bean.override.mockito.MockitoBean; +import org.springframework.test.web.servlet.MockMvc; + +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +@WebMvcTest(OcrController.class) +@Import({SecurityConfig.class, PermissionAspect.class, AopAutoConfiguration.class}) +class OcrControllerTest { + + @Autowired MockMvc mockMvc; + private final ObjectMapper objectMapper = new ObjectMapper(); + + @MockitoBean OcrService ocrService; + @MockitoBean OcrBatchService ocrBatchService; + @MockitoBean OcrProgressService ocrProgressService; + @MockitoBean OcrJobRepository ocrJobRepository; + @MockitoBean OcrJobDocumentRepository ocrJobDocumentRepository; + @MockitoBean UserService userService; + @MockitoBean CustomUserDetailsService customUserDetailsService; + + @Test + @WithMockUser(authorities = "WRITE_ALL") + void triggerOcr_returns202_withJobId() throws Exception { + UUID docId = UUID.randomUUID(); + UUID jobId = UUID.randomUUID(); + TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER); + + when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId); + + mockMvc.perform(post("/api/documents/{id}/ocr", docId) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(dto))) + .andExpect(status().isAccepted()) + .andExpect(jsonPath("$.jobId").value(jobId.toString())); + } + + @Test + @WithMockUser(authorities = "WRITE_ALL") + void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception { + UUID docId = UUID.randomUUID(); + when(ocrService.startOcr(eq(docId), any(), any())) + .thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded")); + + mockMvc.perform(post("/api/documents/{id}/ocr", docId) + .contentType(MediaType.APPLICATION_JSON) + .content("{}")) + .andExpect(status().isBadRequest()); + } + + @Test + @WithMockUser(authorities = "READ_ALL") + void getJobStatus_returns404_whenJobNotFound() throws Exception { + UUID jobId = UUID.randomUUID(); + when(ocrJobRepository.findById(jobId)).thenReturn(Optional.empty()); + + mockMvc.perform(get("/api/ocr/jobs/{jobId}", jobId)) + .andExpect(status().isNotFound()); + } + + @Test + @WithMockUser(authorities = "READ_ALL") + void getJobStatus_returnsJobInfo_whenFound() throws Exception { + UUID jobId = UUID.randomUUID(); + OcrJob job = OcrJob.builder() + .id(jobId) + .status(OcrJobStatus.RUNNING) + .totalDocuments(10) + .processedDocuments(3) + .errorCount(1) + .skippedCount(0) + .build(); + when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job)); + + mockMvc.perform(get("/api/ocr/jobs/{jobId}", jobId)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.status").value("RUNNING")) + .andExpect(jsonPath("$.totalDocuments").value(10)) + .andExpect(jsonPath("$.processedDocuments").value(3)); + } + + @Test + @WithMockUser(authorities = "ADMIN") + void triggerBatch_returns202_withJobId() throws Exception { + UUID jobId = UUID.randomUUID(); + List docIds = List.of(UUID.randomUUID(), UUID.randomUUID()); + BatchOcrDTO dto = new BatchOcrDTO(docIds); + + when(ocrBatchService.startBatch(eq(docIds), any())).thenReturn(jobId); + + mockMvc.perform(post("/api/ocr/batch") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(dto))) + .andExpect(status().isAccepted()) + .andExpect(jsonPath("$.jobId").value(jobId.toString())); + } + + @Test + @WithMockUser(authorities = "READ_ALL") + void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { + UUID docId = UUID.randomUUID(); + when(ocrJobDocumentRepository.findFirstByDocumentIdAndStatusIn(eq(docId), any())) + .thenReturn(Optional.empty()); + + mockMvc.perform(get("/api/documents/{id}/ocr-status", docId)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.status").value("NONE")); + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrBatchServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrBatchServiceTest.java new file mode 100644 index 00000000..9640c3b0 --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrBatchServiceTest.java @@ -0,0 +1,142 @@ +package org.raddatz.familienarchiv.service; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.raddatz.familienarchiv.exception.DomainException; +import org.raddatz.familienarchiv.exception.ErrorCode; +import org.raddatz.familienarchiv.model.*; +import org.raddatz.familienarchiv.repository.OcrJobDocumentRepository; +import org.raddatz.familienarchiv.repository.OcrJobRepository; + +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class OcrBatchServiceTest { + + @Mock OcrService ocrService; + @Mock OcrHealthClient ocrHealthClient; + @Mock DocumentService documentService; + @Mock OcrJobRepository ocrJobRepository; + @Mock OcrJobDocumentRepository ocrJobDocumentRepository; + @Mock OcrProgressService ocrProgressService; + + @InjectMocks OcrBatchService ocrBatchService; + + @Test + void startBatch_throwsServiceUnavailable_whenOcrServiceIsDown() { + when(ocrHealthClient.isHealthy()).thenReturn(false); + + assertThatThrownBy(() -> ocrBatchService.startBatch(List.of(UUID.randomUUID()), UUID.randomUUID())) + .isInstanceOf(DomainException.class) + .satisfies(e -> assertThat(((DomainException) e).getCode()) + .isEqualTo(ErrorCode.OCR_SERVICE_UNAVAILABLE)); + } + + @Test + void startBatch_createsJobAndReturnsJobId() { + UUID docId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + UUID jobId = UUID.randomUUID(); + + when(ocrHealthClient.isHealthy()).thenReturn(true); + when(ocrJobRepository.save(any())).thenAnswer(inv -> { + OcrJob job = inv.getArgument(0); + job.setId(jobId); + return job; + }); + when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of( + OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build())); + when(ocrJobDocumentRepository.findByJobIdOrderByCreatedAtAsc(jobId)).thenReturn(List.of( + OcrJobDocument.builder().jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build())); + + Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED) + .filePath("test.pdf").fileHash("hash").scriptType(ScriptType.TYPEWRITER).build(); + when(documentService.getDocumentById(docId)).thenReturn(doc); + + UUID resultJobId = ocrBatchService.startBatch(List.of(docId), userId); + + assertThat(resultJobId).isEqualTo(jobId); + verify(ocrService).processDocument(eq(docId), eq(doc), eq(userId)); + } + + @Test + void processBatchAsync_skipsPlaceholderDocuments() { + UUID jobId = UUID.randomUUID(); + UUID uploadedId = UUID.randomUUID(); + UUID placeholderId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + + OcrJob job = OcrJob.builder().id(jobId).totalDocuments(2).status(OcrJobStatus.PENDING).build(); + when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job)); + when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + + OcrJobDocument uploadedJobDoc = OcrJobDocument.builder() + .jobId(jobId).documentId(uploadedId).status(OcrDocumentStatus.PENDING).build(); + OcrJobDocument placeholderJobDoc = OcrJobDocument.builder() + .jobId(jobId).documentId(placeholderId).status(OcrDocumentStatus.PENDING).build(); + when(ocrJobDocumentRepository.findByJobIdOrderByCreatedAtAsc(jobId)) + .thenReturn(List.of(uploadedJobDoc, placeholderJobDoc)); + + Document uploaded = Document.builder().id(uploadedId).status(DocumentStatus.UPLOADED) + .filePath("test.pdf").fileHash("hash").scriptType(ScriptType.TYPEWRITER).build(); + Document placeholder = Document.builder().id(placeholderId).status(DocumentStatus.PLACEHOLDER).build(); + when(documentService.getDocumentById(uploadedId)).thenReturn(uploaded); + when(documentService.getDocumentById(placeholderId)).thenReturn(placeholder); + + ocrBatchService.processBatchAsync(jobId, userId); + + verify(ocrService).processDocument(eq(uploadedId), eq(uploaded), eq(userId)); + verify(ocrService, never()).processDocument(eq(placeholderId), any(), any()); + assertThat(placeholderJobDoc.getStatus()).isEqualTo(OcrDocumentStatus.SKIPPED); + } + + @Test + void processBatchAsync_continuesAfterSingleDocumentFailure() { + UUID jobId = UUID.randomUUID(); + UUID failDocId = UUID.randomUUID(); + UUID successDocId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + + OcrJob job = OcrJob.builder().id(jobId).totalDocuments(2).status(OcrJobStatus.PENDING).build(); + when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job)); + when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + + OcrJobDocument failJobDoc = OcrJobDocument.builder() + .jobId(jobId).documentId(failDocId).status(OcrDocumentStatus.PENDING).build(); + OcrJobDocument successJobDoc = OcrJobDocument.builder() + .jobId(jobId).documentId(successDocId).status(OcrDocumentStatus.PENDING).build(); + when(ocrJobDocumentRepository.findByJobIdOrderByCreatedAtAsc(jobId)) + .thenReturn(List.of(failJobDoc, successJobDoc)); + + Document failDoc = Document.builder().id(failDocId).status(DocumentStatus.UPLOADED) + .filePath("fail.pdf").fileHash("hash1").scriptType(ScriptType.TYPEWRITER).build(); + Document successDoc = Document.builder().id(successDocId).status(DocumentStatus.UPLOADED) + .filePath("success.pdf").fileHash("hash2").scriptType(ScriptType.TYPEWRITER).build(); + when(documentService.getDocumentById(failDocId)).thenReturn(failDoc); + when(documentService.getDocumentById(successDocId)).thenReturn(successDoc); + + doThrow(new RuntimeException("OCR failed")).when(ocrService) + .processDocument(eq(failDocId), any(), any()); + + ocrBatchService.processBatchAsync(jobId, userId); + + verify(ocrService).processDocument(eq(successDocId), eq(successDoc), eq(userId)); + assertThat(failJobDoc.getStatus()).isEqualTo(OcrDocumentStatus.FAILED); + assertThat(successJobDoc.getStatus()).isEqualTo(OcrDocumentStatus.DONE); + assertThat(job.getStatus()).isEqualTo(OcrJobStatus.DONE); + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrProgressServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrProgressServiceTest.java new file mode 100644 index 00000000..44ed276f --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrProgressServiceTest.java @@ -0,0 +1,33 @@ +package org.raddatz.familienarchiv.service; + +import org.junit.jupiter.api.Test; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + +class OcrProgressServiceTest { + + private final OcrProgressService progressService = new OcrProgressService(); + + @Test + void register_returnsNonNullEmitter() { + UUID jobId = UUID.randomUUID(); + SseEmitter emitter = progressService.register(jobId); + assertThat(emitter).isNotNull(); + } + + @Test + void emit_doesNotThrow_whenNoEmittersRegistered() { + assertThatCode(() -> progressService.emit(UUID.randomUUID(), "test", "data")) + .doesNotThrowAnyException(); + } + + @Test + void complete_doesNotThrow_whenNoEmittersRegistered() { + assertThatCode(() -> progressService.complete(UUID.randomUUID())) + .doesNotThrowAnyException(); + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrServiceTest.java new file mode 100644 index 00000000..44c598e0 --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrServiceTest.java @@ -0,0 +1,176 @@ +package org.raddatz.familienarchiv.service; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.raddatz.familienarchiv.dto.CreateAnnotationDTO; +import org.raddatz.familienarchiv.exception.DomainException; +import org.raddatz.familienarchiv.exception.ErrorCode; +import org.raddatz.familienarchiv.model.*; +import org.raddatz.familienarchiv.repository.OcrJobRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; + +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.springframework.http.HttpStatus.*; + +@ExtendWith(MockitoExtension.class) +class OcrServiceTest { + + @Mock OcrClient ocrClient; + @Mock OcrHealthClient ocrHealthClient; + @Mock DocumentService documentService; + @Mock TranscriptionService transcriptionService; + @Mock AnnotationService annotationService; + @Mock TranscriptionBlockRepository blockRepository; + @Mock OcrJobRepository ocrJobRepository; + + @InjectMocks OcrService ocrService; + + @Test + void startOcr_throwsBadRequest_whenDocumentIsPlaceholder() { + UUID docId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + Document doc = Document.builder().id(docId).status(DocumentStatus.PLACEHOLDER).build(); + when(documentService.getDocumentById(docId)).thenReturn(doc); + + assertThatThrownBy(() -> ocrService.startOcr(docId, null, userId)) + .isInstanceOf(DomainException.class) + .satisfies(e -> { + DomainException de = (DomainException) e; + assertThat(de.getStatus()).isEqualTo(BAD_REQUEST); + assertThat(de.getCode()).isEqualTo(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED); + }); + } + + @Test + void startOcr_throwsServiceUnavailable_whenOcrServiceIsDown() { + UUID docId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED) + .filePath("documents/test.pdf").fileHash("hash123").build(); + when(documentService.getDocumentById(docId)).thenReturn(doc); + when(ocrHealthClient.isHealthy()).thenReturn(false); + + assertThatThrownBy(() -> ocrService.startOcr(docId, null, userId)) + .isInstanceOf(DomainException.class) + .satisfies(e -> { + DomainException de = (DomainException) e; + assertThat(de.getCode()).isEqualTo(ErrorCode.OCR_SERVICE_UNAVAILABLE); + }); + } + + @Test + void startOcr_createsJobAndReturnsJobId() { + UUID docId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + UUID jobId = UUID.randomUUID(); + Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED) + .filePath("documents/test.pdf").fileHash("hash123") + .scriptType(ScriptType.TYPEWRITER).build(); + when(documentService.getDocumentById(docId)).thenReturn(doc); + when(ocrHealthClient.isHealthy()).thenReturn(true); + // ocrService constructs the internal MinIO URL from S3 key + when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of()); + when(ocrJobRepository.save(any())).thenAnswer(inv -> { + OcrJob job = inv.getArgument(0); + job.setId(jobId); + return job; + }); + + UUID resultJobId = ocrService.startOcr(docId, ScriptType.TYPEWRITER, userId); + + assertThat(resultJobId).isEqualTo(jobId); + verify(ocrJobRepository, atLeastOnce()).save(any()); + } + + @Test + void startOcr_setsScriptTypeOnDocument_whenProvided() { + UUID docId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED) + .filePath("documents/test.pdf").fileHash("hash123") + .scriptType(ScriptType.UNKNOWN).build(); + when(documentService.getDocumentById(docId)).thenReturn(doc); + when(ocrHealthClient.isHealthy()).thenReturn(true); + // ocrService constructs the internal MinIO URL from S3 key + when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of()); + when(ocrJobRepository.save(any())).thenAnswer(inv -> { + OcrJob job = inv.getArgument(0); + job.setId(UUID.randomUUID()); + return job; + }); + + ocrService.startOcr(docId, ScriptType.HANDWRITING_LATIN, userId); + + assertThat(doc.getScriptType()).isEqualTo(ScriptType.HANDWRITING_LATIN); + } + + @Test + void startOcr_clearsExistingBlocks_beforeCreatingNew() { + UUID docId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED) + .filePath("documents/test.pdf").fileHash("hash123") + .scriptType(ScriptType.TYPEWRITER).build(); + TranscriptionBlock existingBlock = TranscriptionBlock.builder() + .id(UUID.randomUUID()).documentId(docId).build(); + + when(documentService.getDocumentById(docId)).thenReturn(doc); + when(ocrHealthClient.isHealthy()).thenReturn(true); + // ocrService constructs the internal MinIO URL from S3 key + when(transcriptionService.listBlocks(docId)).thenReturn(List.of(existingBlock)); + when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of( + new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Hello"))); + when(ocrJobRepository.save(any())).thenAnswer(inv -> { + OcrJob job = inv.getArgument(0); + job.setId(UUID.randomUUID()); + return job; + }); + DocumentAnnotation ann = DocumentAnnotation.builder().id(UUID.randomUUID()).build(); + when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); + + ocrService.startOcr(docId, null, userId); + + verify(transcriptionService).deleteBlock(docId, existingBlock.getId()); + } + + @Test + void startOcr_createsAnnotationAndBlock_forEachOcrResult() { + UUID docId = UUID.randomUUID(); + UUID userId = UUID.randomUUID(); + Document doc = Document.builder().id(docId).status(DocumentStatus.UPLOADED) + .filePath("documents/test.pdf").fileHash("hash123") + .scriptType(ScriptType.TYPEWRITER).build(); + + OcrBlockResult block1 = new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"); + OcrBlockResult block2 = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"); + + when(documentService.getDocumentById(docId)).thenReturn(doc); + when(ocrHealthClient.isHealthy()).thenReturn(true); + // ocrService constructs the internal MinIO URL from S3 key + when(transcriptionService.listBlocks(docId)).thenReturn(List.of()); + when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(block1, block2)); + when(ocrJobRepository.save(any())).thenAnswer(inv -> { + OcrJob job = inv.getArgument(0); + job.setId(UUID.randomUUID()); + return job; + }); + DocumentAnnotation ann = DocumentAnnotation.builder().id(UUID.randomUUID()).build(); + when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); + + ocrService.startOcr(docId, null, userId); + + verify(annotationService, times(2)).createOcrAnnotation( + eq(docId), any(CreateAnnotationDTO.class), eq(userId), eq("hash123"), any()); + } +}