feat(ocr): add guided OCR mode using existing annotation regions

When a document has manually drawn annotation boxes, the user can now
enable "Nur annotierte Bereiche" in the OCR trigger panel. The engine
skips layout detection entirely and runs recognition only within the
pre-drawn bounding boxes, preserving manual transcription blocks.

- Python: adds OcrRegion model, extends OcrRequest/OcrBlock; guided
  branch in /ocr/stream groups by page and crops each region
- Engines: add extract_region_text() to both Kraken and Surya
- Java: adds OcrBlockResult.annotationId, OcrClient.OcrRegion,
  TriggerOcrDTO.useExistingAnnotations; OcrAsyncRunner dispatches to
  upsertGuidedBlock when annotationId is present; OcrService threads
  the flag through to runSingleDocument
- TranscriptionService: adds upsertGuidedBlock (creates, updates OCR,
  or preserves MANUAL blocks)
- Frontend: guided OCR toggle in OcrTrigger shown when blocks exist;
  skips destructive-replace confirmation in guided mode

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 15:57:54 +02:00
parent 9b2f91ee59
commit ee58b63517
25 changed files with 380 additions and 55 deletions

View File

@@ -51,7 +51,8 @@ public class OcrController {
@RequestBody TriggerOcrDTO dto,
Authentication authentication) {
UUID userId = resolveUserId(authentication);
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId,
Boolean.TRUE.equals(dto.getUseExistingAnnotations()));
return Map.of("jobId", jobId);
}

View File

@@ -10,4 +10,5 @@ import org.raddatz.familienarchiv.model.ScriptType;
@AllArgsConstructor
// Request body for triggering an OCR run on a single document.
public class TriggerOcrDTO {
/** Script type requested for this run; the controller forwards it to OcrService.startOcr as the script-type override. */
private ScriptType scriptType;
/**
 * Opt-in flag for guided OCR: when true, recognition runs on the document's existing
 * annotation regions. Defaults to false. The field is a boxed Boolean, so null can still
 * be supplied explicitly; the controller evaluates it with Boolean.TRUE.equals(...), which
 * treats null the same as false.
 */
private Boolean useExistingAnnotations = false;
}

View File

@@ -14,6 +14,8 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
// Derived query: the block with the given id, restricted to the given document.
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
// Derived query: the block attached to the given annotation. Guided OCR uses this lookup
// to decide whether to create, update, or preserve the block for a pre-drawn region.
// NOTE(review): the Optional return presumes at most one block per annotation — confirm
// a uniqueness constraint exists, otherwise the lookup fails on multiple matches.
Optional<TranscriptionBlock> findByAnnotationId(UUID annotationId);
// Derived query: number of blocks that belong to the given document.
int countByDocumentId(UUID documentId);
@Query("""

View File

@@ -32,6 +32,11 @@ public class OcrAsyncRunner {
/**
 * Runs OCR for a single document with guided mode disabled.
 * Convenience overload that delegates to the four-argument variant with
 * {@code useExistingAnnotations = false}.
 *
 * @param jobId      id of the OCR job to run
 * @param documentId id of the document to process
 * @param userId     id of the user the resulting blocks are attributed to
 */
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
// NOTE(review): this is a plain in-class call, so it presumably does not go through the
// Spring proxy again and executes on the thread already running this method — confirm.
runSingleDocument(jobId, documentId, userId, false);
}
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId, boolean useExistingAnnotations) {
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
if (job == null) return;
@@ -49,7 +54,18 @@ public class OcrAsyncRunner {
try {
updateProgress(job, "LOADING");
clearExistingBlocks(documentId);
List<OcrClient.OcrRegion> regions = null;
if (useExistingAnnotations) {
regions = annotationService.listAnnotations(documentId).stream()
.map(a -> new OcrClient.OcrRegion(
a.getId().toString(), a.getPageNumber(),
a.getX(), a.getY(), a.getWidth(), a.getHeight()))
.toList();
} else {
clearExistingBlocks(documentId);
}
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
AtomicInteger blockCounter = new AtomicInteger(0);
@@ -57,7 +73,7 @@ public class OcrAsyncRunner {
AtomicInteger skippedPages = new AtomicInteger(0);
AtomicInteger totalPages = new AtomicInteger(0);
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
switch (event) {
case OcrStreamEvent.Start start -> {
totalPages.set(start.totalPages());
@@ -204,14 +220,21 @@ public class OcrAsyncRunner {
/**
 * Persists one OCR result block for the given document.
 *
 * <p>Guided mode ({@code block.annotationId() != null}): the annotation already exists, so
 * only its transcription block is upserted. {@code fileHash} and {@code sortOrder} are not
 * used on this path.
 *
 * <p>Normal mode: a new OCR annotation is created from the block geometry and a new OCR
 * transcription block is attached to it at {@code sortOrder}.
 *
 * @param documentId document the block belongs to
 * @param block      OCR result (geometry, text, optional annotation id)
 * @param userId     user the created or updated entities are attributed to
 * @param fileHash   hash of the source file, stored with a newly created annotation
 * @param sortOrder  position of a newly created block (normal mode only)
 * @throws IllegalArgumentException if the block carries an annotation id that is not a
 *                                  valid UUID string
 */
void createSingleBlock(UUID documentId, OcrBlockResult block,
        UUID userId, String fileHash, int sortOrder) {
    if (block.annotationId() != null) {
        // Guided mode — annotation already exists; upsert the text block only
        transcriptionService.upsertGuidedBlock(
                documentId, UUID.fromString(block.annotationId()), block.text(), userId);
        return;
    }

    // Normal mode — create a new annotation and a new OCR block
    CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
            block.pageNumber(), block.x(), block.y(),
            block.width(), block.height(), OCR_ANNOTATION_COLOR);
    DocumentAnnotation annotation = annotationService.createOcrAnnotation(
            documentId, annotationDTO, userId, fileHash, block.polygon());
    transcriptionService.createOcrBlock(documentId, annotation.getId(),
            block.text(), sortOrder, userId);
}
}

View File

@@ -12,5 +12,6 @@ public record OcrBlockResult(
double width,
double height,
List<List<Double>> polygon,
String text
String text,
String annotationId // null in normal mode; set in guided mode to link back to existing annotation
) {}

View File

@@ -10,6 +10,14 @@ import java.util.function.Consumer;
public interface OcrClient {
/**
 * Runs OCR on the PDF at {@code pdfUrl} for the given script type and returns all
 * recognized blocks as one list. The default {@code streamBlocks} implementation in this
 * interface calls this method and synthesizes stream events from the result.
 */
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
/**
 * A pre-drawn annotation region to use as guidance for OCR.
 * When regions are provided, the OCR engine crops to each region and
 * runs recognition only within that area, skipping full-page layout detection.
 *
 * @param annotationId id of the existing annotation as a string; guided-mode results carry
 *                     it back so the recognized text can be linked to that annotation
 * @param pageNumber   page the region is on (NOTE(review): indexing base is not visible
 *                     here — confirm it matches what the OCR service expects)
 * @param x            horizontal position of the region
 * @param y            vertical position of the region
 * @param width        width of the region
 * @param height       height of the region (NOTE(review): x/y/width/height presumably use
 *                     the same coordinate system as the stored annotations — confirm units)
 */
record OcrRegion(String annotationId, int pageNumber,
double x, double y, double width, double height) {}
/**
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
*
@@ -32,8 +40,12 @@ public interface OcrClient {
* Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration
* — it calls extractBlocks() and synthesizes events from the collected result.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
*/
default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();

View File

@@ -52,6 +52,11 @@ public class OcrService {
}
/**
 * Starts an OCR job for one document with guided mode disabled.
 * Convenience overload that delegates to the four-argument variant with
 * {@code useExistingAnnotations = false}.
 *
 * @param documentId         id of the document to process
 * @param scriptTypeOverride script type passed through to the four-argument variant
 * @param userId             id of the user starting the job
 * @return the value returned by the four-argument variant (the id of the created job)
 */
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
return startOcr(documentId, scriptTypeOverride, userId, false);
}
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId,
boolean useExistingAnnotations) {
Document doc = documentService.getDocumentById(documentId);
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
@@ -85,7 +90,7 @@ public class OcrService {
.build();
ocrJobDocumentRepository.save(jobDoc);
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId, useExistingAnnotations);
return job.getId();
}
}

View File

@@ -175,13 +175,18 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
}
@Override
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
public void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
String body;
try {
body = NDJSON_MAPPER.writeValueAsString(Map.of(
"pdfUrl", pdfUrl,
"scriptType", scriptType.name(),
"language", "de"));
var requestMap = new java.util.LinkedHashMap<String, Object>();
requestMap.put("pdfUrl", pdfUrl);
requestMap.put("scriptType", scriptType.name());
requestMap.put("language", "de");
if (regions != null && !regions.isEmpty()) {
requestMap.put("regions", regions);
}
body = NDJSON_MAPPER.writeValueAsString(requestMap);
} catch (IOException e) {
throw new RuntimeException("Failed to serialize OCR request", e);
}
@@ -199,7 +204,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (response.statusCode() == 404) {
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
return;
}
@@ -259,10 +264,11 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
double width,
double height,
List<List<Double>> polygon,
String text
String text,
String annotationId
) {
/**
 * Converts this record into an {@link OcrBlockResult}, copying every component.
 * {@code annotationId} is passed through unchanged; it is null in normal mode and set in
 * guided mode so the block can be linked back to its existing annotation.
 */
OcrBlockResult toResult() {
    return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text, annotationId);
}
}
}

View File

@@ -94,6 +94,27 @@ public class TranscriptionService {
return saved;
}
/**
 * Guided-OCR upsert: stores recognized text for an annotation that already exists.
 * <ul>
 *   <li>no block on the annotation yet: a new OCR block is created (sort order 0);</li>
 *   <li>existing OCR block: its text is replaced and a new version is recorded;</li>
 *   <li>existing MANUAL block: returned untouched.</li>
 * </ul>
 * NOTE(review): the lookup is by annotation id only and does not check that the block
 * found belongs to {@code documentId} — confirm this is intended.
 *
 * @param documentId   document the annotation belongs to (used when creating a new block)
 * @param annotationId annotation whose block is created or updated
 * @param text         recognized text; sanitized before it is stored on an existing block
 * @param userId       user recorded as the author of the change
 * @return the created, updated, or preserved block
 */
@Transactional
public TranscriptionBlock upsertGuidedBlock(UUID documentId, UUID annotationId,
        String text, UUID userId) {
    return blockRepository.findByAnnotationId(annotationId)
            .map(current -> overwriteUnlessManual(current, text, userId))
            .orElseGet(() -> createOcrBlock(documentId, annotationId, text, 0, userId));
}

/** Replaces the text of a non-MANUAL block and records a version; MANUAL blocks are returned as-is. */
private TranscriptionBlock overwriteUnlessManual(TranscriptionBlock current, String text, UUID userId) {
    if (current.getSource() == BlockSource.MANUAL) {
        // never overwrite manual transcription
        return current;
    }
    current.setText(sanitizeText(text));
    current.setUpdatedBy(userId);
    TranscriptionBlock persisted = blockRepository.save(current);
    saveVersion(persisted, userId);
    return persisted;
}
@Transactional
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
UpdateTranscriptionBlockDTO dto, UUID userId) {

View File

@@ -24,6 +24,7 @@ import java.util.List;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -52,9 +53,9 @@ class OcrControllerTest {
void triggerOcr_returns202_withJobId() throws Exception {
UUID docId = UUID.randomUUID();
UUID jobId = UUID.randomUUID();
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER, false);
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any(), anyBoolean())).thenReturn(jobId);
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
.contentType(MediaType.APPLICATION_JSON)
@@ -67,7 +68,7 @@ class OcrControllerTest {
@WithMockUser(authorities = "WRITE_ALL")
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
UUID docId = UUID.randomUUID();
when(ocrService.startOcr(eq(docId), any(), any()))
when(ocrService.startOcr(eq(docId), any(), any(), anyBoolean()))
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
mockMvc.perform(post("/api/documents/{id}/ocr", docId)

View File

@@ -60,8 +60,8 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -84,7 +84,7 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -112,12 +112,12 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -142,7 +142,7 @@ class OcrAsyncRunnerTest {
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -174,18 +174,18 @@ class OcrAsyncRunnerTest {
List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2", null))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Page(1, List.of(
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3", null))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Done(3, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -215,14 +215,14 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Error(1, "failed"));
handler.accept(new OcrStreamEvent.Page(2, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -251,13 +251,13 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(2));
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

View File

@@ -15,16 +15,16 @@ class OcrClientDefaultStreamTest {
OcrClient client = new OcrClient() {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
return List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3", null));
}
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
assertThat(events).hasSize(4);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
@@ -55,7 +55,7 @@ class OcrClientDefaultStreamTest {
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
assertThat(events).hasSize(2);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);

View File

@@ -142,7 +142,7 @@ class OcrServiceTest {
UUID result = ocrService.startOcr(docId, null, userId);
assertThat(result).isEqualTo(jobId);
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId, false);
}
@Test

View File

@@ -17,7 +17,7 @@ class OcrStreamEventTest {
@Test
void pageRecordHoldsBlocksAndPageNumber() {
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test", null);
var page = new OcrStreamEvent.Page(0, List.of(block));
assertThat(page.pageNumber()).isEqualTo(0);
assertThat(page.blocks()).hasSize(1);

View File

@@ -0,0 +1,93 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockVersionRepository;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
 * Unit tests for {@code TranscriptionService.upsertGuidedBlock}, run against mocked
 * collaborators. Covers the three outcomes of a guided upsert: create, update an OCR
 * block, and preserve a MANUAL block.
 */
class TranscriptionServiceGuidedTest {

    UUID docId = UUID.randomUUID();
    UUID annId = UUID.randomUUID();
    UUID userId = UUID.randomUUID();

    TranscriptionBlockRepository blockRepository;
    TranscriptionBlockVersionRepository versionRepository;
    AnnotationRepository annotationRepository;
    AnnotationService annotationService;
    DocumentService documentService;

    TranscriptionService service;

    @BeforeEach
    void setUp() {
        blockRepository = mock(TranscriptionBlockRepository.class);
        versionRepository = mock(TranscriptionBlockVersionRepository.class);
        annotationRepository = mock(AnnotationRepository.class);
        annotationService = mock(AnnotationService.class);
        documentService = mock(DocumentService.class);

        service = new TranscriptionService(blockRepository, versionRepository,
                annotationRepository, annotationService, documentService);

        // Both repositories hand back whatever entity they are asked to save.
        when(blockRepository.save(any())).thenAnswer(call -> call.getArgument(0));
        when(versionRepository.save(any())).thenAnswer(call -> call.getArgument(0));
    }

    /** Builds a block on {@code annId} with the given text and source and stubs the lookup to return it. */
    private TranscriptionBlock givenExistingBlock(String text, BlockSource source) {
        TranscriptionBlock block = TranscriptionBlock.builder()
                .id(UUID.randomUUID())
                .annotationId(annId)
                .documentId(docId)
                .text(text)
                .source(source)
                .sortOrder(0)
                .build();
        when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(block));
        return block;
    }

    @Test
    void upsertGuidedBlock_createsNewBlock_whenAnnotationHasNoBlock() {
        when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.empty());

        TranscriptionBlock created = service.upsertGuidedBlock(docId, annId, "Hallo", userId);

        assertThat(created.getText()).isEqualTo("Hallo");
        assertThat(created.getSource()).isEqualTo(BlockSource.OCR);
        assertThat(created.getAnnotationId()).isEqualTo(annId);
        verify(blockRepository).save(any());
    }

    @Test
    void upsertGuidedBlock_updatesExistingOcrBlock_whenAnnotationHasOcrBlock() {
        givenExistingBlock("old text", BlockSource.OCR);

        TranscriptionBlock updated = service.upsertGuidedBlock(docId, annId, "new text", userId);

        assertThat(updated.getText()).isEqualTo("new text");
        verify(blockRepository).save(any());
    }

    @Test
    void upsertGuidedBlock_doesNotOverwriteManualBlock() {
        givenExistingBlock("manually written", BlockSource.MANUAL);

        TranscriptionBlock untouched = service.upsertGuidedBlock(docId, annId, "ocr result", userId);

        assertThat(untouched.getText()).isEqualTo("manually written");
        verify(blockRepository, never()).save(any());
    }
}