diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java index 4b8f9cd3..4c0d1d4a 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -12,11 +12,15 @@ import org.raddatz.familienarchiv.security.RequirePermission; import org.raddatz.familienarchiv.service.OcrBatchService; import org.raddatz.familienarchiv.service.OcrProgressService; import org.raddatz.familienarchiv.service.OcrService; +import org.raddatz.familienarchiv.service.TrainingDataExportService; import org.raddatz.familienarchiv.service.UserService; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; import org.springframework.security.core.Authentication; import org.springframework.web.bind.annotation.*; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import jakarta.validation.Valid; @@ -32,6 +36,7 @@ public class OcrController { private final OcrBatchService ocrBatchService; private final OcrProgressService ocrProgressService; private final UserService userService; + private final TrainingDataExportService trainingDataExportService; @PostMapping("/api/documents/{documentId}/ocr") @ResponseStatus(HttpStatus.ACCEPTED) @@ -75,6 +80,19 @@ public class OcrController { return ocrService.getDocumentOcrStatus(documentId); } + @GetMapping("/api/ocr/training-data/export") + @RequirePermission(Permission.ADMIN) + public ResponseEntity exportTrainingData() { + if (trainingDataExportService.queryEligibleBlocks().isEmpty()) { + return ResponseEntity.noContent().build(); + } + StreamingResponseBody body = trainingDataExportService.exportToZip(); + return ResponseEntity.ok() + .contentType(MediaType.parseMediaType("application/zip")) + .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"training-data.zip\"") + .body(body); + } + private UUID resolveUserId(Authentication authentication) { if (authentication == null || !authentication.isAuthenticated()) return null; try { diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java new file mode 100644 index 00000000..06a23946 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/TrainingDataExportService.java @@ -0,0 +1,141 @@ +package org.raddatz.familienarchiv.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.pdfbox.Loader; +import org.apache.pdfbox.pdmodel.PDDocument; +import org.apache.pdfbox.rendering.PDFRenderer; +import org.raddatz.familienarchiv.model.Document; +import org.raddatz.familienarchiv.model.DocumentAnnotation; +import org.raddatz.familienarchiv.model.TranscriptionBlock; +import org.raddatz.familienarchiv.repository.AnnotationRepository; +import org.raddatz.familienarchiv.repository.DocumentRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +@Service +@RequiredArgsConstructor +@Slf4j +public class TrainingDataExportService { + + private final TranscriptionBlockRepository blockRepository; + private final AnnotationRepository annotationRepository; + private final DocumentRepository documentRepository; + private final FileService fileService; + + public List queryEligibleBlocks() { + return blockRepository.findEligibleKurrentBlocks(); + } + + public StreamingResponseBody exportToZip() { + // Collect all data before entering the lambda — no open DB txn during streaming + List blocks = queryEligibleBlocks(); + if (blocks.isEmpty()) { + return out -> {}; // caller checks isEmpty() for 204 response + } + + // Group blocks by documentId so we only download each PDF once + Map> byDoc = new LinkedHashMap<>(); + for (TranscriptionBlock b : blocks) { + byDoc.computeIfAbsent(b.getDocumentId(), k -> new ArrayList<>()).add(b); + } + + // Pre-fetch annotations keyed by id + Map annotations = new HashMap<>(); + for (TranscriptionBlock b : blocks) { + annotationRepository.findById(b.getAnnotationId()) + .ifPresent(a -> annotations.put(a.getId(), a)); + } + + // Pre-fetch documents keyed by id + Map documents = new HashMap<>(); + for (UUID docId : byDoc.keySet()) { + documentRepository.findById(docId).ifPresent(d -> documents.put(d.getId(), d)); + } + + return out -> { + try (ZipOutputStream zip = new ZipOutputStream(out)) { + for (Map.Entry> entry : byDoc.entrySet()) { + UUID docId = entry.getKey(); + Document doc = documents.get(docId); + if (doc == null || doc.getFilePath() == null) { + log.warn("Skipping document {} — no file path", docId); + continue; + } + + byte[] pdfBytes; + try { + pdfBytes = fileService.downloadFileBytes(doc.getFilePath()); + } catch (FileService.StorageFileNotFoundException | IOException e) { + log.warn("Skipping document {} — S3 download failed: {}", docId, e.getMessage()); + continue; + } + + try (PDDocument pdf = Loader.loadPDF(pdfBytes)) { + PDFRenderer renderer = new PDFRenderer(pdf); + for (TranscriptionBlock block : entry.getValue()) { + DocumentAnnotation ann = annotations.get(block.getAnnotationId()); + if (ann == null) continue; + + int pageIdx = ann.getPageNumber() - 1; // pageNumber is 1-based + if (pageIdx < 0 || pageIdx >= pdf.getNumberOfPages()) continue; + + BufferedImage pageImage = renderPageImage(renderer, pageIdx); + BufferedImage cropped = cropBlockImage(pageImage, ann); + + writeTrainingPair(zip, block.getId(), cropped, block.getText()); + } + } catch (Exception e) { + log.warn("Skipping document {} — rendering failed: {}", docId, e.getMessage()); + } + } + } + }; + } + + BufferedImage renderPageImage(PDFRenderer renderer, int pageIdx) throws IOException { + return renderer.renderImageWithDPI(pageIdx, 300); + } + + BufferedImage cropBlockImage(BufferedImage page, DocumentAnnotation ann) { + int imgW = page.getWidth(); + int imgH = page.getHeight(); + + int x = (int) (ann.getX() * imgW); + int y = (int) (ann.getY() * imgH); + int w = (int) (ann.getWidth() * imgW); + int h = (int) (ann.getHeight() * imgH); + + // Clamp to image bounds + x = Math.max(0, Math.min(x, imgW - 1)); + y = Math.max(0, Math.min(y, imgH - 1)); + w = Math.max(1, Math.min(w, imgW - x)); + h = Math.max(1, Math.min(h, imgH - y)); + + return page.getSubimage(x, y, w, h); + } + + void writeTrainingPair(ZipOutputStream zip, UUID blockId, BufferedImage image, String text) throws IOException { + String base = blockId.toString(); + + // Write PNG + zip.putNextEntry(new ZipEntry(base + ".png")); + ImageIO.write(image, "PNG", zip); + zip.closeEntry(); + + // Write ground-truth text + zip.putNextEntry(new ZipEntry(base + ".gt.txt")); + zip.write((text != null ? text : "").getBytes(StandardCharsets.UTF_8)); + zip.closeEntry(); + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java index a7d6d5cf..bdb3a346 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -43,6 +43,7 @@ class OcrControllerTest { @MockitoBean OcrProgressService ocrProgressService; @MockitoBean UserService userService; @MockitoBean CustomUserDetailsService customUserDetailsService; + @MockitoBean TrainingDataExportService trainingDataExportService; @Test @WithMockUser(authorities = "WRITE_ALL") @@ -121,6 +122,47 @@ class OcrControllerTest { .andExpect(jsonPath("$.jobId").value(jobId.toString())); } + // ─── GET /api/ocr/training-data/export ─────────────────────────────────── + + @Test + void exportTrainingData_returns401_whenUnauthenticated() throws Exception { + mockMvc.perform(get("/api/ocr/training-data/export")) + .andExpect(status().isUnauthorized()); + } + + @Test + @WithMockUser(authorities = "READ_ALL") + void exportTrainingData_returns403_whenNotAdmin() throws Exception { + mockMvc.perform(get("/api/ocr/training-data/export")) + .andExpect(status().isForbidden()); + } + + @Test + @WithMockUser(authorities = "ADMIN") + void exportTrainingData_returns204_whenNoEligibleBlocks() throws Exception { + when(trainingDataExportService.queryEligibleBlocks()).thenReturn(List.of()); + + mockMvc.perform(get("/api/ocr/training-data/export")) + .andExpect(status().isNoContent()); + } + + @Test + @WithMockUser(authorities = "ADMIN") + void exportTrainingData_returns200_withZipContentType_whenBlocksExist() throws Exception { + org.raddatz.familienarchiv.model.TranscriptionBlock block = + org.raddatz.familienarchiv.model.TranscriptionBlock.builder() + .id(UUID.randomUUID()).documentId(UUID.randomUUID()) + .annotationId(UUID.randomUUID()).text("x").sortOrder(0).build(); + when(trainingDataExportService.queryEligibleBlocks()).thenReturn(List.of(block)); + when(trainingDataExportService.exportToZip()).thenReturn(out -> {}); + + mockMvc.perform(get("/api/ocr/training-data/export")) + .andExpect(status().isOk()) + .andExpect(result -> + org.assertj.core.api.Assertions.assertThat( + result.getResponse().getContentType()).contains("application/zip")); + } + @Test @WithMockUser(authorities = "READ_ALL") void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/TrainingDataExportServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/TrainingDataExportServiceTest.java new file mode 100644 index 00000000..f214f989 --- /dev/null +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/TrainingDataExportServiceTest.java @@ -0,0 +1,258 @@ +package org.raddatz.familienarchiv.service; + +import org.apache.pdfbox.Loader; +import org.apache.pdfbox.pdmodel.PDDocument; +import org.apache.pdfbox.pdmodel.PDPage; +import org.apache.pdfbox.pdmodel.common.PDRectangle; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.raddatz.familienarchiv.PostgresContainerConfig; +import org.raddatz.familienarchiv.config.FlywayConfig; +import org.raddatz.familienarchiv.model.*; +import org.raddatz.familienarchiv.repository.AnnotationRepository; +import org.raddatz.familienarchiv.repository.DocumentRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.jdbc.test.autoconfigure.AutoConfigureTestDatabase; +import org.springframework.boot.data.jpa.test.autoconfigure.DataJpaTest; +import org.springframework.context.annotation.Import; +import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.HashSet; +import java.util.Set; +import java.util.UUID; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +@DataJpaTest +@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE) +@Import({PostgresContainerConfig.class, FlywayConfig.class}) +class TrainingDataExportServiceTest { + + @Autowired TranscriptionBlockRepository blockRepository; + @Autowired DocumentRepository documentRepository; + @Autowired AnnotationRepository annotationRepository; + + static byte[] minimalPdfBytes; + + @BeforeAll + static void createMinimalPdf() throws Exception { + try (PDDocument doc = new PDDocument()) { + doc.addPage(new PDPage(PDRectangle.A4)); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + doc.save(out); + minimalPdfBytes = out.toByteArray(); + } + } + + // ─── Query: enrollment filter ───────────────────────────────────────────── + + @Test + void export_includesManualBlockFromEnrolledDocument() throws Exception { + UUID docId = enrolledDoc("enrolled.pdf"); + UUID annotId = annotation(docId); + blockRepository.save(manualBlock(docId, annotId, "Liebe Mutter")); + + FileService fileService = mockFileService(); + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + StreamingResponseBody body = service.exportToZip(); + byte[] zipBytes = stream(body); + assertThat(zipEntryNames(zipBytes)).isNotEmpty(); + } + + @Test + void export_excludesManualBlockFromNonEnrolledDocument() throws Exception { + UUID docId = nonEnrolledDoc("notenrolled.pdf"); + UUID annotId = annotation(docId); + blockRepository.save(manualBlock(docId, annotId, "Liebe Tante")); + + FileService fileService = mockFileService(); + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + StreamingResponseBody body = service.exportToZip(); + byte[] zipBytes = stream(body); + assertThat(zipEntryNames(zipBytes)).isEmpty(); + } + + @Test + void export_includesReviewedOcrBlockFromEnrolledDocument() throws Exception { + UUID docId = enrolledDoc("ocr-reviewed.pdf"); + UUID annotId = annotation(docId); + TranscriptionBlock block = TranscriptionBlock.builder() + .annotationId(annotId).documentId(docId) + .text("OCR text").sortOrder(0) + .source(BlockSource.OCR).reviewed(true).build(); + blockRepository.save(block); + + FileService fileService = mockFileService(); + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + StreamingResponseBody body = service.exportToZip(); + assertThat(zipEntryNames(stream(body))).isNotEmpty(); + } + + @Test + void export_excludesUnreviewedOcrBlockFromEnrolledDocument() throws Exception { + UUID docId = enrolledDoc("ocr-unreviewed.pdf"); + UUID annotId = annotation(docId); + TranscriptionBlock block = TranscriptionBlock.builder() + .annotationId(annotId).documentId(docId) + .text("Raw OCR").sortOrder(0) + .source(BlockSource.OCR).reviewed(false).build(); + blockRepository.save(block); + + FileService fileService = mockFileService(); + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + StreamingResponseBody body = service.exportToZip(); + assertThat(zipEntryNames(stream(body))).isEmpty(); + } + + // ─── ZIP structure ──────────────────────────────────────────────────────── + + @Test + void export_producesExactly2EntriesPerBlock_pngAndTxt() throws Exception { + UUID docId = enrolledDoc("zip-struct.pdf"); + UUID annotId = annotation(docId); + blockRepository.save(manualBlock(docId, annotId, "Erste Zeile")); + blockRepository.save(manualBlock(docId, annotId, "Zweite Zeile")); + + FileService fileService = mockFileService(); + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + byte[] zipBytes = stream(service.exportToZip()); + var names = zipEntryNames(zipBytes); + assertThat(names).hasSize(4); // 2 blocks × 2 entries each + assertThat(names.stream().filter(n -> n.endsWith(".png")).count()).isEqualTo(2); + assertThat(names.stream().filter(n -> n.endsWith(".gt.txt")).count()).isEqualTo(2); + } + + @Test + void export_gtTxtContainsBlockText() throws Exception { + UUID docId = enrolledDoc("txt-content.pdf"); + UUID annotId = annotation(docId); + String expectedText = "Sehr geehrte Frau"; + blockRepository.save(manualBlock(docId, annotId, expectedText)); + + FileService fileService = mockFileService(); + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + byte[] zipBytes = stream(service.exportToZip()); + String txtContent = readZipEntry(zipBytes, ".gt.txt"); + assertThat(txtContent).isEqualTo(expectedText); + } + + // ─── S3 failure resilience ──────────────────────────────────────────────── + + @Test + void export_skipsDocumentWhenS3DownloadFails_andStillIncludesOtherDocuments() throws Exception { + UUID failDocId = enrolledDoc("fail.pdf"); + UUID okDocId = enrolledDoc("ok.pdf"); + UUID failAnnotId = annotation(failDocId); + UUID okAnnotId = annotation(okDocId); + blockRepository.save(manualBlock(failDocId, failAnnotId, "Will fail")); + blockRepository.save(manualBlock(okDocId, okAnnotId, "Will succeed")); + + FileService fileService = mock(FileService.class); + when(fileService.downloadFileBytes("fail.pdf")).thenThrow(new FileService.StorageFileNotFoundException("missing")); + when(fileService.downloadFileBytes("ok.pdf")).thenReturn(minimalPdfBytes); + + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + byte[] zipBytes = stream(service.exportToZip()); + var names = zipEntryNames(zipBytes); + // ok.pdf block produces 2 entries; fail.pdf block is skipped + assertThat(names).hasSize(2); + } + + // ─── Empty export ───────────────────────────────────────────────────────── + + @Test + void queryEligibleBlocks_returnsEmpty_whenNoEnrolledDocuments() { + FileService fileService = mockFileService(); + TrainingDataExportService service = new TrainingDataExportService(blockRepository, annotationRepository, documentRepository, fileService); + + assertThat(service.queryEligibleBlocks()).isEmpty(); + } + + // ─── helpers ───────────────────────────────────────────────────────────── + + private UUID enrolledDoc(String filename) { + Document doc = documentRepository.save(Document.builder() + .title(filename).originalFilename(filename).filePath(filename) + .status(DocumentStatus.UPLOADED) + .trainingLabels(new HashSet<>(Set.of(TrainingLabel.KURRENT_RECOGNITION))) + .build()); + return doc.getId(); + } + + private UUID nonEnrolledDoc(String filename) { + Document doc = documentRepository.save(Document.builder() + .title(filename).originalFilename(filename).filePath(filename) + .status(DocumentStatus.UPLOADED) + .build()); + return doc.getId(); + } + + private UUID annotation(UUID docId) { + return annotationRepository.save(DocumentAnnotation.builder() + .documentId(docId).pageNumber(1) + .x(0.1).y(0.1).width(0.8).height(0.1).color("#00C7B1") + .build()).getId(); + } + + private TranscriptionBlock manualBlock(UUID docId, UUID annotId, String text) { + return TranscriptionBlock.builder() + .annotationId(annotId).documentId(docId) + .text(text).sortOrder(0) + .source(BlockSource.MANUAL).reviewed(false).build(); + } + + private FileService mockFileService() { + FileService fs = mock(FileService.class); + try { + when(fs.downloadFileBytes(anyString())).thenReturn(minimalPdfBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + return fs; + } + + private static byte[] stream(StreamingResponseBody body) throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + body.writeTo(out); + return out.toByteArray(); + } + + private static java.util.List zipEntryNames(byte[] zipBytes) throws Exception { + var names = new java.util.ArrayList(); + try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(zipBytes))) { + ZipEntry entry; + while ((entry = zis.getNextEntry()) != null) { + names.add(entry.getName()); + zis.closeEntry(); + } + } + return names; + } + + private static String readZipEntry(byte[] zipBytes, String suffix) throws Exception { + try (ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(zipBytes))) { + ZipEntry entry; + while ((entry = zis.getNextEntry()) != null) { + if (entry.getName().endsWith(suffix)) { + return new String(zis.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8); + } + zis.closeEntry(); + } + } + return null; + } +}