feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232
@@ -51,7 +51,8 @@ public class OcrController {
|
||||
@RequestBody TriggerOcrDTO dto,
|
||||
Authentication authentication) {
|
||||
UUID userId = resolveUserId(authentication);
|
||||
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
|
||||
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId,
|
||||
Boolean.TRUE.equals(dto.getUseExistingAnnotations()));
|
||||
return Map.of("jobId", jobId);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,4 +10,5 @@ import org.raddatz.familienarchiv.model.ScriptType;
|
||||
@AllArgsConstructor
|
||||
public class TriggerOcrDTO {
|
||||
private ScriptType scriptType;
|
||||
private Boolean useExistingAnnotations = false;
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
|
||||
|
||||
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
|
||||
|
||||
Optional<TranscriptionBlock> findByAnnotationId(UUID annotationId);
|
||||
|
||||
int countByDocumentId(UUID documentId);
|
||||
|
||||
@Query("""
|
||||
|
||||
@@ -32,6 +32,11 @@ public class OcrAsyncRunner {
|
||||
|
||||
@Async
|
||||
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
|
||||
runSingleDocument(jobId, documentId, userId, false);
|
||||
}
|
||||
|
||||
@Async
|
||||
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId, boolean useExistingAnnotations) {
|
||||
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
|
||||
if (job == null) return;
|
||||
|
||||
@@ -49,7 +54,18 @@ public class OcrAsyncRunner {
|
||||
|
||||
try {
|
||||
updateProgress(job, "LOADING");
|
||||
clearExistingBlocks(documentId);
|
||||
|
||||
List<OcrClient.OcrRegion> regions = null;
|
||||
if (useExistingAnnotations) {
|
||||
regions = annotationService.listAnnotations(documentId).stream()
|
||||
.map(a -> new OcrClient.OcrRegion(
|
||||
a.getId().toString(), a.getPageNumber(),
|
||||
a.getX(), a.getY(), a.getWidth(), a.getHeight()))
|
||||
.toList();
|
||||
} else {
|
||||
clearExistingBlocks(documentId);
|
||||
}
|
||||
|
||||
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
||||
|
||||
AtomicInteger blockCounter = new AtomicInteger(0);
|
||||
@@ -57,7 +73,7 @@ public class OcrAsyncRunner {
|
||||
AtomicInteger skippedPages = new AtomicInteger(0);
|
||||
AtomicInteger totalPages = new AtomicInteger(0);
|
||||
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
|
||||
switch (event) {
|
||||
case OcrStreamEvent.Start start -> {
|
||||
totalPages.set(start.totalPages());
|
||||
@@ -204,14 +220,21 @@ public class OcrAsyncRunner {
|
||||
|
||||
void createSingleBlock(UUID documentId, OcrBlockResult block,
|
||||
UUID userId, String fileHash, int sortOrder) {
|
||||
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
|
||||
block.pageNumber(), block.x(), block.y(),
|
||||
block.width(), block.height(), OCR_ANNOTATION_COLOR);
|
||||
if (block.annotationId() != null) {
|
||||
// Guided mode — annotation already exists; upsert the text block only
|
||||
transcriptionService.upsertGuidedBlock(
|
||||
documentId, UUID.fromString(block.annotationId()), block.text(), userId);
|
||||
} else {
|
||||
// Normal mode — create a new annotation and a new OCR block
|
||||
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
|
||||
block.pageNumber(), block.x(), block.y(),
|
||||
block.width(), block.height(), OCR_ANNOTATION_COLOR);
|
||||
|
||||
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
|
||||
documentId, annotationDTO, userId, fileHash, block.polygon());
|
||||
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
|
||||
documentId, annotationDTO, userId, fileHash, block.polygon());
|
||||
|
||||
transcriptionService.createOcrBlock(documentId, annotation.getId(),
|
||||
block.text(), sortOrder, userId);
|
||||
transcriptionService.createOcrBlock(documentId, annotation.getId(),
|
||||
block.text(), sortOrder, userId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,5 +12,6 @@ public record OcrBlockResult(
|
||||
double width,
|
||||
double height,
|
||||
List<List<Double>> polygon,
|
||||
String text
|
||||
String text,
|
||||
String annotationId // null in normal mode; set in guided mode to link back to existing annotation
|
||||
) {}
|
||||
|
||||
@@ -10,6 +10,14 @@ import java.util.function.Consumer;
|
||||
public interface OcrClient {
|
||||
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
|
||||
|
||||
/**
|
||||
* A pre-drawn annotation region to use as guidance for OCR.
|
||||
* When regions are provided, the OCR engine crops to each region and
|
||||
* runs recognition only within that area, skipping full-page layout detection.
|
||||
*/
|
||||
record OcrRegion(String annotationId, int pageNumber,
|
||||
double x, double y, double width, double height) {}
|
||||
|
||||
/**
|
||||
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
|
||||
*
|
||||
@@ -32,8 +40,12 @@ public interface OcrClient {
|
||||
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
||||
* this method. The default exists only for backward compatibility during migration
|
||||
* — it calls extractBlocks() and synthesizes events from the collected result.
|
||||
*
|
||||
* @param regions optional list of pre-drawn annotation regions; when non-null,
|
||||
* the OCR service runs in guided mode (crop + recognize per region)
|
||||
*/
|
||||
default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
|
||||
default void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
|
||||
|
||||
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
|
||||
|
||||
@@ -52,6 +52,11 @@ public class OcrService {
|
||||
}
|
||||
|
||||
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
|
||||
return startOcr(documentId, scriptTypeOverride, userId, false);
|
||||
}
|
||||
|
||||
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId,
|
||||
boolean useExistingAnnotations) {
|
||||
Document doc = documentService.getDocumentById(documentId);
|
||||
|
||||
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
|
||||
@@ -85,7 +90,7 @@ public class OcrService {
|
||||
.build();
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
|
||||
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId, useExistingAnnotations);
|
||||
return job.getId();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,13 +175,18 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
|
||||
public void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||
String body;
|
||||
try {
|
||||
body = NDJSON_MAPPER.writeValueAsString(Map.of(
|
||||
"pdfUrl", pdfUrl,
|
||||
"scriptType", scriptType.name(),
|
||||
"language", "de"));
|
||||
var requestMap = new java.util.LinkedHashMap<String, Object>();
|
||||
requestMap.put("pdfUrl", pdfUrl);
|
||||
requestMap.put("scriptType", scriptType.name());
|
||||
requestMap.put("language", "de");
|
||||
if (regions != null && !regions.isEmpty()) {
|
||||
requestMap.put("regions", regions);
|
||||
}
|
||||
body = NDJSON_MAPPER.writeValueAsString(requestMap);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to serialize OCR request", e);
|
||||
}
|
||||
@@ -199,7 +204,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
|
||||
if (response.statusCode() == 404) {
|
||||
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
|
||||
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
|
||||
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -259,10 +264,11 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
double width,
|
||||
double height,
|
||||
List<List<Double>> polygon,
|
||||
String text
|
||||
String text,
|
||||
String annotationId
|
||||
) {
|
||||
OcrBlockResult toResult() {
|
||||
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text);
|
||||
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text, annotationId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +94,27 @@ public class TranscriptionService {
|
||||
return saved;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upsert an OCR transcription block for a pre-existing annotation (guided OCR mode).
|
||||
* If the annotation already has a MANUAL block, it is left unchanged.
|
||||
* If it has an OCR block, the text is updated in-place.
|
||||
* If it has no block yet, a new OCR block is created.
|
||||
*/
|
||||
@Transactional
|
||||
public TranscriptionBlock upsertGuidedBlock(UUID documentId, UUID annotationId,
|
||||
String text, UUID userId) {
|
||||
return blockRepository.findByAnnotationId(annotationId).map(existing -> {
|
||||
if (existing.getSource() == BlockSource.MANUAL) {
|
||||
return existing; // never overwrite manual transcription
|
||||
}
|
||||
existing.setText(sanitizeText(text));
|
||||
existing.setUpdatedBy(userId);
|
||||
TranscriptionBlock saved = blockRepository.save(existing);
|
||||
saveVersion(saved, userId);
|
||||
return saved;
|
||||
}).orElseGet(() -> createOcrBlock(documentId, annotationId, text, 0, userId));
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
|
||||
UpdateTranscriptionBlockDTO dto, UUID userId) {
|
||||
|
||||
@@ -24,6 +24,7 @@ import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
@@ -52,9 +53,9 @@ class OcrControllerTest {
|
||||
void triggerOcr_returns202_withJobId() throws Exception {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID jobId = UUID.randomUUID();
|
||||
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
|
||||
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER, false);
|
||||
|
||||
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
|
||||
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any(), anyBoolean())).thenReturn(jobId);
|
||||
|
||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
@@ -67,7 +68,7 @@ class OcrControllerTest {
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
|
||||
UUID docId = UUID.randomUUID();
|
||||
when(ocrService.startOcr(eq(docId), any(), any()))
|
||||
when(ocrService.startOcr(eq(docId), any(), any(), anyBoolean()))
|
||||
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
|
||||
|
||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||
|
||||
@@ -60,8 +60,8 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -84,7 +84,7 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -112,12 +112,12 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(1));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -142,7 +142,7 @@ class OcrAsyncRunnerTest {
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -174,18 +174,18 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
List<String> progressMessages = new ArrayList<>();
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2", null))));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of(
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3", null))));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Done(3, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -215,14 +215,14 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Error(1, "failed"));
|
||||
handler.accept(new OcrStreamEvent.Page(2, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -251,13 +251,13 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(2));
|
||||
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
|
||||
@@ -15,16 +15,16 @@ class OcrClientDefaultStreamTest {
|
||||
OcrClient client = new OcrClient() {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
||||
return List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3", null));
|
||||
}
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||
|
||||
assertThat(events).hasSize(4);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
@@ -55,7 +55,7 @@ class OcrClientDefaultStreamTest {
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||
|
||||
assertThat(events).hasSize(2);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
|
||||
@@ -142,7 +142,7 @@ class OcrServiceTest {
|
||||
UUID result = ocrService.startOcr(docId, null, userId);
|
||||
|
||||
assertThat(result).isEqualTo(jobId);
|
||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
|
||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -17,7 +17,7 @@ class OcrStreamEventTest {
|
||||
|
||||
@Test
|
||||
void pageRecordHoldsBlocksAndPageNumber() {
|
||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
|
||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test", null);
|
||||
var page = new OcrStreamEvent.Page(0, List.of(block));
|
||||
assertThat(page.pageNumber()).isEqualTo(0);
|
||||
assertThat(page.blocks()).hasSize(1);
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.model.BlockSource;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockVersionRepository;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class TranscriptionServiceGuidedTest {
|
||||
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
TranscriptionBlockVersionRepository versionRepository;
|
||||
AnnotationRepository annotationRepository;
|
||||
AnnotationService annotationService;
|
||||
DocumentService documentService;
|
||||
TranscriptionService service;
|
||||
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID annId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
versionRepository = mock(TranscriptionBlockVersionRepository.class);
|
||||
annotationRepository = mock(AnnotationRepository.class);
|
||||
annotationService = mock(AnnotationService.class);
|
||||
documentService = mock(DocumentService.class);
|
||||
|
||||
service = new TranscriptionService(blockRepository, versionRepository,
|
||||
annotationRepository, annotationService, documentService);
|
||||
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_createsNewBlock_whenAnnotationHasNoBlock() {
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.empty());
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "Hallo", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("Hallo");
|
||||
assertThat(result.getSource()).isEqualTo(BlockSource.OCR);
|
||||
assertThat(result.getAnnotationId()).isEqualTo(annId);
|
||||
verify(blockRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_updatesExistingOcrBlock_whenAnnotationHasOcrBlock() {
|
||||
TranscriptionBlock existing = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.annotationId(annId)
|
||||
.documentId(docId)
|
||||
.text("old text")
|
||||
.source(BlockSource.OCR)
|
||||
.sortOrder(0)
|
||||
.build();
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(existing));
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "new text", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("new text");
|
||||
verify(blockRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_doesNotOverwriteManualBlock() {
|
||||
TranscriptionBlock manual = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.annotationId(annId)
|
||||
.documentId(docId)
|
||||
.text("manually written")
|
||||
.source(BlockSource.MANUAL)
|
||||
.sortOrder(0)
|
||||
.build();
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(manual));
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "ocr result", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("manually written");
|
||||
verify(blockRepository, never()).save(any());
|
||||
}
|
||||
}
|
||||
@@ -531,6 +531,8 @@
|
||||
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
||||
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
||||
"ocr_status_error": "OCR fehlgeschlagen",
|
||||
"ocr_use_existing_annotations": "Nur annotierte Bereiche",
|
||||
"ocr_use_existing_annotations_hint": "OCR wird nur innerhalb der bereits markierten Bereiche ausgeführt — keine neue Layout-Erkennung.",
|
||||
"transcription_block_review": "Als geprüft markieren",
|
||||
"transcription_block_unreview": "Markierung aufheben",
|
||||
"transcription_reviewed_count": "{reviewed} von {total} geprüft",
|
||||
|
||||
@@ -531,6 +531,8 @@
|
||||
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
||||
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
||||
"ocr_status_error": "OCR failed",
|
||||
"ocr_use_existing_annotations": "Annotated regions only",
|
||||
"ocr_use_existing_annotations_hint": "OCR runs only within the already marked regions — no new layout detection.",
|
||||
"transcription_block_review": "Mark as reviewed",
|
||||
"transcription_block_unreview": "Unmark as reviewed",
|
||||
"transcription_reviewed_count": "{reviewed} of {total} reviewed",
|
||||
|
||||
@@ -531,6 +531,8 @@
|
||||
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
||||
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
||||
"ocr_status_error": "OCR fallido",
|
||||
"ocr_use_existing_annotations": "Solo regiones anotadas",
|
||||
"ocr_use_existing_annotations_hint": "El OCR se ejecuta solo dentro de las regiones ya marcadas — sin nueva detección de diseño.",
|
||||
"transcription_block_review": "Marcar como revisado",
|
||||
"transcription_block_unreview": "Desmarcar como revisado",
|
||||
"transcription_reviewed_count": "{reviewed} de {total} revisados",
|
||||
|
||||
@@ -7,10 +7,11 @@ import ScriptTypeSelect from './ScriptTypeSelect.svelte';
|
||||
interface Props {
|
||||
existingBlockCount: number;
|
||||
storedScriptType: string;
|
||||
onTrigger: (scriptType: string) => void;
|
||||
annotationCount?: number;
|
||||
onTrigger: (scriptType: string, useExistingAnnotations: boolean) => void;
|
||||
}
|
||||
|
||||
let { existingBlockCount, storedScriptType, onTrigger }: Props = $props();
|
||||
let { existingBlockCount, storedScriptType, annotationCount = 0, onTrigger }: Props = $props();
|
||||
|
||||
const { confirm } = getConfirmService();
|
||||
|
||||
@@ -18,10 +19,12 @@ let selectedScriptType: string = $state(
|
||||
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
|
||||
);
|
||||
|
||||
let useExistingAnnotations: boolean = $state(false);
|
||||
|
||||
async function handleClick() {
|
||||
if (!selectedScriptType) return;
|
||||
|
||||
if (existingBlockCount > 0) {
|
||||
if (!useExistingAnnotations && existingBlockCount > 0) {
|
||||
const confirmed = await confirm({
|
||||
title: m.ocr_confirm_title(),
|
||||
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
|
||||
@@ -31,12 +34,27 @@ async function handleClick() {
|
||||
if (!confirmed) return;
|
||||
}
|
||||
|
||||
onTrigger(selectedScriptType);
|
||||
onTrigger(selectedScriptType, useExistingAnnotations);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col gap-3">
|
||||
<ScriptTypeSelect bind:value={selectedScriptType} />
|
||||
{#if annotationCount > 0}
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="flex cursor-pointer items-center gap-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
bind:checked={useExistingAnnotations}
|
||||
class="h-4 w-4 cursor-pointer rounded-sm border-brand-navy/30 accent-brand-navy"
|
||||
/>
|
||||
<span class="font-sans text-sm font-medium text-brand-navy">
|
||||
{m.ocr_use_existing_annotations()}
|
||||
</span>
|
||||
</label>
|
||||
<p class="pl-6 text-xs text-ink-3">{m.ocr_use_existing_annotations_hint()}</p>
|
||||
</div>
|
||||
{/if}
|
||||
<button
|
||||
type="button"
|
||||
disabled={!selectedScriptType}
|
||||
|
||||
@@ -19,7 +19,7 @@ type Props = {
|
||||
onSaveBlock: (blockId: string, text: string) => Promise<void>;
|
||||
onDeleteBlock: (blockId: string) => Promise<void>;
|
||||
onReviewToggle: (blockId: string) => Promise<void>;
|
||||
onTriggerOcr?: (scriptType: string) => void;
|
||||
onTriggerOcr?: (scriptType: string, useExistingAnnotations: boolean) => void;
|
||||
canWrite?: boolean;
|
||||
trainingLabels?: string[];
|
||||
onToggleTrainingLabel?: (label: string, enrolled: boolean) => Promise<void>;
|
||||
@@ -370,6 +370,7 @@ $effect(() => {
|
||||
<div class="mt-3 max-w-xs">
|
||||
<OcrTrigger
|
||||
existingBlockCount={blocks.length}
|
||||
annotationCount={blocks.length}
|
||||
storedScriptType={storedScriptType}
|
||||
onTrigger={onTriggerOcr}
|
||||
/>
|
||||
|
||||
@@ -144,14 +144,14 @@ let ocrErrorMessage = $state('');
|
||||
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
|
||||
let ocrSkippedPages = $state(0);
|
||||
|
||||
async function triggerOcr(scriptType: string) {
|
||||
async function triggerOcr(scriptType: string, useExistingAnnotations: boolean = false) {
|
||||
ocrRunning = true;
|
||||
ocrErrorMessage = '';
|
||||
try {
|
||||
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ scriptType })
|
||||
body: JSON.stringify({ scriptType, useExistingAnnotations })
|
||||
});
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
|
||||
@@ -79,6 +79,29 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
return blocks
|
||||
|
||||
|
||||
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
"""Crop image to a normalized region and run Kraken recognition on the crop.
|
||||
|
||||
Used for guided OCR — skips full-page layout detection and only processes
|
||||
the given bounding box. Coordinates are normalized to [0, 1].
|
||||
"""
|
||||
from kraken import blla, rpred
|
||||
|
||||
if _model is None:
|
||||
raise RuntimeError("Kraken model is not loaded")
|
||||
|
||||
pw, ph = image.size
|
||||
x1 = max(0, int(x * pw))
|
||||
y1 = max(0, int(y * ph))
|
||||
x2 = min(pw, int((x + w) * pw))
|
||||
y2 = min(ph, int((y + h) * ph))
|
||||
crop = image.crop((x1, y1, x2, y2))
|
||||
|
||||
baseline_seg = blla.segment(crop)
|
||||
pred_it = rpred.rpred(_model, crop, baseline_seg)
|
||||
return " ".join(r.prediction for r in pred_it)
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
"""Run Kraken segmentation + recognition on a list of PIL images.
|
||||
|
||||
|
||||
@@ -81,6 +81,25 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
return blocks
|
||||
|
||||
|
||||
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
"""Crop image to a normalized region and run Surya recognition on the crop.
|
||||
|
||||
Used for guided OCR — skips full-page layout detection and only processes
|
||||
the given bounding box. Coordinates are normalized to [0, 1].
|
||||
"""
|
||||
load_models()
|
||||
|
||||
pw, ph = image.size
|
||||
x1 = max(0, int(x * pw))
|
||||
y1 = max(0, int(y * ph))
|
||||
x2 = min(pw, int((x + w) * pw))
|
||||
y2 = min(ph, int((y + h) * ph))
|
||||
crop = image.crop((x1, y1, x2, y2))
|
||||
|
||||
predictions = _recognition_predictor([crop], det_predictor=_detection_predictor)
|
||||
return " ".join(line.text for line in predictions[0].text_lines)
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
"""Run Surya OCR on a list of PIL images (one per page).
|
||||
|
||||
|
||||
@@ -108,7 +108,12 @@ async def run_ocr(request: OcrRequest):
|
||||
|
||||
@app.post("/ocr/stream")
|
||||
async def run_ocr_stream(request: OcrRequest):
|
||||
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page."""
|
||||
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.
|
||||
|
||||
When request.regions is provided, runs in guided mode: each region is cropped
|
||||
and recognized individually, skipping full-page layout detection. The response
|
||||
blocks include the annotationId from the region.
|
||||
"""
|
||||
if not _models_ready:
|
||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||
|
||||
@@ -123,6 +128,81 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
detail="Kraken model not available — cannot process Kurrent script",
|
||||
)
|
||||
|
||||
engine = kraken_engine if use_kraken else surya_engine
|
||||
|
||||
if request.regions:
|
||||
# Guided mode: recognize only the user-drawn annotation regions
|
||||
regions_by_page: dict[int, list] = {}
|
||||
for region in request.regions:
|
||||
regions_by_page.setdefault(region.pageNumber, []).append(region)
|
||||
|
||||
async def generate_guided():
|
||||
total_pages = len(images)
|
||||
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
||||
|
||||
total_blocks = 0
|
||||
skipped_pages = 0
|
||||
|
||||
for page_idx, image in enumerate(images, start=1):
|
||||
page_regions = regions_by_page.get(page_idx, [])
|
||||
if not page_regions:
|
||||
yield json.dumps({
|
||||
"type": "page",
|
||||
"pageNumber": page_idx,
|
||||
"blocks": [],
|
||||
}) + "\n"
|
||||
del image
|
||||
continue
|
||||
|
||||
try:
|
||||
blocks = []
|
||||
for region in page_regions:
|
||||
text = await asyncio.to_thread(
|
||||
engine.extract_region_text, image,
|
||||
region.x, region.y, region.width, region.height,
|
||||
)
|
||||
blocks.append({
|
||||
"pageNumber": page_idx,
|
||||
"x": region.x,
|
||||
"y": region.y,
|
||||
"width": region.width,
|
||||
"height": region.height,
|
||||
"polygon": None,
|
||||
"text": text,
|
||||
"annotationId": region.annotationId,
|
||||
})
|
||||
|
||||
total_blocks += len(blocks)
|
||||
yield json.dumps({
|
||||
"type": "page",
|
||||
"pageNumber": page_idx,
|
||||
"blocks": blocks,
|
||||
}) + "\n"
|
||||
|
||||
except Exception:
|
||||
logger.exception("Guided OCR failed on page %d", page_idx)
|
||||
skipped_pages += 1
|
||||
yield json.dumps({
|
||||
"type": "error",
|
||||
"pageNumber": page_idx,
|
||||
"message": f"Guided OCR processing failed on page {page_idx}",
|
||||
}) + "\n"
|
||||
|
||||
finally:
|
||||
del image
|
||||
|
||||
yield json.dumps({
|
||||
"type": "done",
|
||||
"totalBlocks": total_blocks,
|
||||
"skippedPages": skipped_pages,
|
||||
}) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_guided(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
|
||||
)
|
||||
|
||||
async def generate():
|
||||
total_pages = len(images)
|
||||
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
||||
@@ -132,7 +212,6 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
|
||||
for page_idx, image in enumerate(images, start=1):
|
||||
try:
|
||||
engine = kraken_engine if use_kraken else surya_engine
|
||||
blocks = await asyncio.to_thread(
|
||||
engine.extract_page_blocks, image, page_idx, request.language
|
||||
)
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class OcrRegion(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
annotationId: str
|
||||
pageNumber: int
|
||||
x: float
|
||||
y: float
|
||||
width: float
|
||||
height: float
|
||||
|
||||
|
||||
class OcrRequest(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
pdfUrl: str
|
||||
scriptType: str = "UNKNOWN"
|
||||
language: str = "de"
|
||||
regions: list[OcrRegion] | None = None
|
||||
|
||||
|
||||
class OcrBlock(BaseModel):
|
||||
@@ -19,3 +31,4 @@ class OcrBlock(BaseModel):
|
||||
height: float
|
||||
polygon: list[list[float]] | None = None
|
||||
text: str
|
||||
annotationId: str | None = None
|
||||
|
||||
Reference in New Issue
Block a user