feat(ocr): add guided OCR mode using existing annotation regions
When a document has manually drawn annotation boxes, the user can now enable "Nur annotierte Bereiche" in the OCR trigger panel. The engine skips layout detection entirely and runs recognition only within the pre-drawn bounding boxes, preserving manual transcription blocks. - Python: adds OcrRegion model, extends OcrRequest/OcrBlock; guided branch in /ocr/stream groups regions by page and crops each region - Engines: adds extract_region_text() to both Kraken and Surya - Java: adds OcrBlockResult.annotationId, OcrClient.OcrRegion, TriggerOcrDTO.useExistingAnnotations; OcrAsyncRunner dispatches to upsertGuidedBlock when annotationId is present; OcrService threads the flag through to runSingleDocument - TranscriptionService: adds upsertGuidedBlock (creates a new block, updates an existing OCR block, or preserves an existing MANUAL block) - Frontend: guided OCR toggle in OcrTrigger shown when blocks exist; skips destructive-replace confirmation in guided mode Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,7 @@ import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
@@ -52,9 +53,9 @@ class OcrControllerTest {
|
||||
void triggerOcr_returns202_withJobId() throws Exception {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID jobId = UUID.randomUUID();
|
||||
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
|
||||
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER, false);
|
||||
|
||||
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
|
||||
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any(), anyBoolean())).thenReturn(jobId);
|
||||
|
||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
@@ -67,7 +68,7 @@ class OcrControllerTest {
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
|
||||
UUID docId = UUID.randomUUID();
|
||||
when(ocrService.startOcr(eq(docId), any(), any()))
|
||||
when(ocrService.startOcr(eq(docId), any(), any(), anyBoolean()))
|
||||
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
|
||||
|
||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||
|
||||
@@ -60,8 +60,8 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -84,7 +84,7 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -112,12 +112,12 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(1));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -142,7 +142,7 @@ class OcrAsyncRunnerTest {
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -174,18 +174,18 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
List<String> progressMessages = new ArrayList<>();
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2", null))));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of(
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3", null))));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Done(3, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -215,14 +215,14 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Error(1, "failed"));
|
||||
handler.accept(new OcrStreamEvent.Page(2, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -251,13 +251,13 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
handler.accept(new OcrStreamEvent.Start(2));
|
||||
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
|
||||
@@ -15,16 +15,16 @@ class OcrClientDefaultStreamTest {
|
||||
OcrClient client = new OcrClient() {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
||||
return List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null),
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3", null));
|
||||
}
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||
|
||||
assertThat(events).hasSize(4);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
@@ -55,7 +55,7 @@ class OcrClientDefaultStreamTest {
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||
|
||||
assertThat(events).hasSize(2);
|
||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||
|
||||
@@ -142,7 +142,7 @@ class OcrServiceTest {
|
||||
UUID result = ocrService.startOcr(docId, null, userId);
|
||||
|
||||
assertThat(result).isEqualTo(jobId);
|
||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
|
||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -17,7 +17,7 @@ class OcrStreamEventTest {
|
||||
|
||||
@Test
|
||||
void pageRecordHoldsBlocksAndPageNumber() {
|
||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
|
||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test", null);
|
||||
var page = new OcrStreamEvent.Page(0, List.of(block));
|
||||
assertThat(page.pageNumber()).isEqualTo(0);
|
||||
assertThat(page.blocks()).hasSize(1);
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.model.BlockSource;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockVersionRepository;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class TranscriptionServiceGuidedTest {
|
||||
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
TranscriptionBlockVersionRepository versionRepository;
|
||||
AnnotationRepository annotationRepository;
|
||||
AnnotationService annotationService;
|
||||
DocumentService documentService;
|
||||
TranscriptionService service;
|
||||
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID annId = UUID.randomUUID();
|
||||
UUID userId = UUID.randomUUID();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
versionRepository = mock(TranscriptionBlockVersionRepository.class);
|
||||
annotationRepository = mock(AnnotationRepository.class);
|
||||
annotationService = mock(AnnotationService.class);
|
||||
documentService = mock(DocumentService.class);
|
||||
|
||||
service = new TranscriptionService(blockRepository, versionRepository,
|
||||
annotationRepository, annotationService, documentService);
|
||||
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_createsNewBlock_whenAnnotationHasNoBlock() {
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.empty());
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "Hallo", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("Hallo");
|
||||
assertThat(result.getSource()).isEqualTo(BlockSource.OCR);
|
||||
assertThat(result.getAnnotationId()).isEqualTo(annId);
|
||||
verify(blockRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_updatesExistingOcrBlock_whenAnnotationHasOcrBlock() {
|
||||
TranscriptionBlock existing = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.annotationId(annId)
|
||||
.documentId(docId)
|
||||
.text("old text")
|
||||
.source(BlockSource.OCR)
|
||||
.sortOrder(0)
|
||||
.build();
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(existing));
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "new text", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("new text");
|
||||
verify(blockRepository).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void upsertGuidedBlock_doesNotOverwriteManualBlock() {
|
||||
TranscriptionBlock manual = TranscriptionBlock.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.annotationId(annId)
|
||||
.documentId(docId)
|
||||
.text("manually written")
|
||||
.source(BlockSource.MANUAL)
|
||||
.sortOrder(0)
|
||||
.build();
|
||||
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(manual));
|
||||
|
||||
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "ocr result", userId);
|
||||
|
||||
assertThat(result.getText()).isEqualTo("manually written");
|
||||
verify(blockRepository, never()).save(any());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user