feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232

Merged
marcel merged 40 commits from feat/issue-226-227-ocr-pipeline-polygon into main 2026-04-14 10:31:35 +02:00
25 changed files with 380 additions and 55 deletions
Showing only changes of commit ee58b63517 - Show all commits

View File

@@ -51,7 +51,8 @@ public class OcrController {
@RequestBody TriggerOcrDTO dto,
Authentication authentication) {
UUID userId = resolveUserId(authentication);
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId,
Boolean.TRUE.equals(dto.getUseExistingAnnotations()));
return Map.of("jobId", jobId);
}

View File

@@ -10,4 +10,5 @@ import org.raddatz.familienarchiv.model.ScriptType;
@AllArgsConstructor
public class TriggerOcrDTO {
// Script type requested by the client; the controller forwards it to
// OcrService.startOcr as the script-type override for this run.
private ScriptType scriptType;
// Opt-in flag for guided OCR. The controller reads it with
// Boolean.TRUE.equals(...), so both null and false select the normal
// (full-page) mode; only an explicit true enables guided mode.
private Boolean useExistingAnnotations = false;
}

View File

@@ -14,6 +14,8 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
Optional<TranscriptionBlock> findByAnnotationId(UUID annotationId);
int countByDocumentId(UUID documentId);
@Query("""

View File

@@ -32,6 +32,11 @@ public class OcrAsyncRunner {
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
runSingleDocument(jobId, documentId, userId, false);
}
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId, boolean useExistingAnnotations) {
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
if (job == null) return;
@@ -49,7 +54,18 @@ public class OcrAsyncRunner {
try {
updateProgress(job, "LOADING");
clearExistingBlocks(documentId);
List<OcrClient.OcrRegion> regions = null;
if (useExistingAnnotations) {
regions = annotationService.listAnnotations(documentId).stream()
.map(a -> new OcrClient.OcrRegion(
a.getId().toString(), a.getPageNumber(),
a.getX(), a.getY(), a.getWidth(), a.getHeight()))
.toList();
} else {
clearExistingBlocks(documentId);
}
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
AtomicInteger blockCounter = new AtomicInteger(0);
@@ -57,7 +73,7 @@ public class OcrAsyncRunner {
AtomicInteger skippedPages = new AtomicInteger(0);
AtomicInteger totalPages = new AtomicInteger(0);
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
switch (event) {
case OcrStreamEvent.Start start -> {
totalPages.set(start.totalPages());
@@ -204,14 +220,21 @@ public class OcrAsyncRunner {
void createSingleBlock(UUID documentId, OcrBlockResult block,
UUID userId, String fileHash, int sortOrder) {
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
block.pageNumber(), block.x(), block.y(),
block.width(), block.height(), OCR_ANNOTATION_COLOR);
if (block.annotationId() != null) {
// Guided mode — annotation already exists; upsert the text block only
transcriptionService.upsertGuidedBlock(
documentId, UUID.fromString(block.annotationId()), block.text(), userId);
} else {
// Normal mode — create a new annotation and a new OCR block
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
block.pageNumber(), block.x(), block.y(),
block.width(), block.height(), OCR_ANNOTATION_COLOR);
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
documentId, annotationDTO, userId, fileHash, block.polygon());
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
documentId, annotationDTO, userId, fileHash, block.polygon());
transcriptionService.createOcrBlock(documentId, annotation.getId(),
block.text(), sortOrder, userId);
transcriptionService.createOcrBlock(documentId, annotation.getId(),
block.text(), sortOrder, userId);
}
}
}

View File

@@ -12,5 +12,6 @@ public record OcrBlockResult(
double width,
double height,
List<List<Double>> polygon,
String text
String text,
String annotationId // null in normal mode; set in guided mode to link back to existing annotation
) {}

View File

@@ -10,6 +10,14 @@ import java.util.function.Consumer;
public interface OcrClient {
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
/**
* A pre-drawn annotation region to use as guidance for OCR.
* When regions are provided, the OCR engine crops to each region and
* runs recognition only within that area, skipping full-page layout detection.
*
* @param annotationId id of the existing annotation, as a string; the OCR service
*                     copies it onto the block it returns for this region so the
*                     result can be linked back to the annotation
* @param pageNumber   page the region lies on; the OCR service groups regions by this
*                     value and matches them against its own page index
*                     (NOTE(review): the service enumerates pages starting at 1 —
*                     confirm annotations use the same base)
* @param x            left edge of the region, normalized to the page width ([0, 1])
* @param y            top edge of the region, normalized to the page height ([0, 1])
* @param width        region width, normalized to the page width
* @param height       region height, normalized to the page height
*/
record OcrRegion(String annotationId, int pageNumber,
double x, double y, double width, double height) {}
/**
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
*
@@ -32,8 +40,12 @@ public interface OcrClient {
* Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration
* — it calls extractBlocks() and synthesizes events from the collected result.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
*/
default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();

View File

@@ -52,6 +52,11 @@ public class OcrService {
}
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
return startOcr(documentId, scriptTypeOverride, userId, false);
}
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId,
boolean useExistingAnnotations) {
Document doc = documentService.getDocumentById(documentId);
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
@@ -85,7 +90,7 @@ public class OcrService {
.build();
ocrJobDocumentRepository.save(jobDoc);
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId, useExistingAnnotations);
return job.getId();
}
}

View File

@@ -175,13 +175,18 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
}
@Override
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
public void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
String body;
try {
body = NDJSON_MAPPER.writeValueAsString(Map.of(
"pdfUrl", pdfUrl,
"scriptType", scriptType.name(),
"language", "de"));
var requestMap = new java.util.LinkedHashMap<String, Object>();
requestMap.put("pdfUrl", pdfUrl);
requestMap.put("scriptType", scriptType.name());
requestMap.put("language", "de");
if (regions != null && !regions.isEmpty()) {
requestMap.put("regions", regions);
}
body = NDJSON_MAPPER.writeValueAsString(requestMap);
} catch (IOException e) {
throw new RuntimeException("Failed to serialize OCR request", e);
}
@@ -199,7 +204,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (response.statusCode() == 404) {
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
return;
}
@@ -259,10 +264,11 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
double width,
double height,
List<List<Double>> polygon,
String text
String text,
String annotationId
) {
OcrBlockResult toResult() {
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text);
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text, annotationId);
}
}
}

View File

@@ -94,6 +94,27 @@ public class TranscriptionService {
return saved;
}
/**
 * Upsert an OCR transcription block for a pre-existing annotation (guided OCR mode).
 * <ul>
 *   <li>If the annotation already has a MANUAL block, it is returned unchanged and
 *       nothing is saved.</li>
 *   <li>If it has a non-manual block, the sanitized text replaces the old text
 *       in-place, the block is saved and a new version entry is recorded.</li>
 *   <li>If it has no block yet, a new OCR block is created and appended after the
 *       blocks the document already has.</li>
 * </ul>
 *
 * @param documentId   document the annotation belongs to
 * @param annotationId existing annotation the recognized text is attached to
 * @param text         raw text returned by the OCR service for the annotation's region
 * @param userId       user on whose behalf the OCR run was started
 * @return the untouched manual block, the updated block, or the newly created block
 */
@Transactional
public TranscriptionBlock upsertGuidedBlock(UUID documentId, UUID annotationId,
        String text, UUID userId) {
    return blockRepository.findByAnnotationId(annotationId).map(existing -> {
        if (existing.getSource() == BlockSource.MANUAL) {
            return existing; // never overwrite manual transcription
        }
        existing.setText(sanitizeText(text));
        existing.setUpdatedBy(userId);
        TranscriptionBlock saved = blockRepository.save(existing);
        saveVersion(saved, userId);
        return saved;
    }).orElseGet(() -> {
        // Guided mode keeps the document's existing blocks, so a fixed sort order
        // of 0 would collide with them (and with every other block created in the
        // same run). Use the current block count as the next free position instead.
        int nextSortOrder = blockRepository.countByDocumentId(documentId);
        return createOcrBlock(documentId, annotationId, text, nextSortOrder, userId);
    });
}
@Transactional
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
UpdateTranscriptionBlockDTO dto, UUID userId) {

View File

@@ -24,6 +24,7 @@ import java.util.List;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.when;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -52,9 +53,9 @@ class OcrControllerTest {
void triggerOcr_returns202_withJobId() throws Exception {
UUID docId = UUID.randomUUID();
UUID jobId = UUID.randomUUID();
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER, false);
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any(), anyBoolean())).thenReturn(jobId);
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
.contentType(MediaType.APPLICATION_JSON)
@@ -67,7 +68,7 @@ class OcrControllerTest {
@WithMockUser(authorities = "WRITE_ALL")
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
UUID docId = UUID.randomUUID();
when(ocrService.startOcr(eq(docId), any(), any()))
when(ocrService.startOcr(eq(docId), any(), any(), anyBoolean()))
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
mockMvc.perform(post("/api/documents/{id}/ocr", docId)

View File

@@ -60,8 +60,8 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -84,7 +84,7 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -112,12 +112,12 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -142,7 +142,7 @@ class OcrAsyncRunnerTest {
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -174,18 +174,18 @@ class OcrAsyncRunnerTest {
List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2", null))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Page(1, List.of(
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3", null))));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Done(3, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -215,14 +215,14 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Error(1, "failed"));
handler.accept(new OcrStreamEvent.Page(2, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -251,13 +251,13 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
handler.accept(new OcrStreamEvent.Start(2));
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

View File

@@ -15,16 +15,16 @@ class OcrClientDefaultStreamTest {
OcrClient client = new OcrClient() {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
return List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null),
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3", null));
}
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
assertThat(events).hasSize(4);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
@@ -55,7 +55,7 @@ class OcrClientDefaultStreamTest {
};
List<OcrStreamEvent> events = new ArrayList<>();
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
assertThat(events).hasSize(2);
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);

View File

@@ -142,7 +142,7 @@ class OcrServiceTest {
UUID result = ocrService.startOcr(docId, null, userId);
assertThat(result).isEqualTo(jobId);
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId, false);
}
@Test

View File

@@ -17,7 +17,7 @@ class OcrStreamEventTest {
@Test
void pageRecordHoldsBlocksAndPageNumber() {
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test", null);
var page = new OcrStreamEvent.Page(0, List.of(block));
assertThat(page.pageNumber()).isEqualTo(0);
assertThat(page.blocks()).hasSize(1);

View File

@@ -0,0 +1,93 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockVersionRepository;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
/**
 * Unit tests for {@code TranscriptionService.upsertGuidedBlock}, the guided-OCR
 * path that attaches recognized text to an already existing annotation.
 * All collaborators are Mockito mocks; both repositories echo back whatever
 * entity they are asked to save.
 */
class TranscriptionServiceGuidedTest {

    TranscriptionBlockRepository blockRepository;
    TranscriptionBlockVersionRepository versionRepository;
    AnnotationRepository annotationRepository;
    AnnotationService annotationService;
    DocumentService documentService;
    TranscriptionService service;

    UUID docId = UUID.randomUUID();
    UUID annId = UUID.randomUUID();
    UUID userId = UUID.randomUUID();

    @BeforeEach
    void setUp() {
        blockRepository = mock(TranscriptionBlockRepository.class);
        versionRepository = mock(TranscriptionBlockVersionRepository.class);
        annotationRepository = mock(AnnotationRepository.class);
        annotationService = mock(AnnotationService.class);
        documentService = mock(DocumentService.class);

        service = new TranscriptionService(
                blockRepository,
                versionRepository,
                annotationRepository,
                annotationService,
                documentService);

        // save(...) hands back the entity it was given, like a real repository would
        when(blockRepository.save(any())).thenAnswer(call -> call.getArgument(0));
        when(versionRepository.save(any())).thenAnswer(call -> call.getArgument(0));
    }

    /** Builds a block already attached to {@code annId} with the given source and text. */
    private TranscriptionBlock storedBlock(BlockSource source, String text) {
        return TranscriptionBlock.builder()
                .id(UUID.randomUUID())
                .annotationId(annId)
                .documentId(docId)
                .text(text)
                .source(source)
                .sortOrder(0)
                .build();
    }

    /** No block exists for the annotation yet: a fresh OCR block is created and saved. */
    @Test
    void upsertGuidedBlock_createsNewBlock_whenAnnotationHasNoBlock() {
        when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.empty());

        TranscriptionBlock created = service.upsertGuidedBlock(docId, annId, "Hallo", userId);

        assertThat(created.getText()).isEqualTo("Hallo");
        assertThat(created.getSource()).isEqualTo(BlockSource.OCR);
        assertThat(created.getAnnotationId()).isEqualTo(annId);
        verify(blockRepository).save(any());
    }

    /** An existing OCR block has its text replaced and is saved again. */
    @Test
    void upsertGuidedBlock_updatesExistingOcrBlock_whenAnnotationHasOcrBlock() {
        TranscriptionBlock ocrBlock = storedBlock(BlockSource.OCR, "old text");
        when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(ocrBlock));

        TranscriptionBlock updated = service.upsertGuidedBlock(docId, annId, "new text", userId);

        assertThat(updated.getText()).isEqualTo("new text");
        verify(blockRepository).save(any());
    }

    /** A manual transcription wins over OCR output: text stays and nothing is saved. */
    @Test
    void upsertGuidedBlock_doesNotOverwriteManualBlock() {
        TranscriptionBlock manualBlock = storedBlock(BlockSource.MANUAL, "manually written");
        when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(manualBlock));

        TranscriptionBlock untouched = service.upsertGuidedBlock(docId, annId, "ocr result", userId);

        assertThat(untouched.getText()).isEqualTo("manually written");
        verify(blockRepository, never()).save(any());
    }
}

View File

@@ -531,6 +531,8 @@
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
"ocr_status_error": "OCR fehlgeschlagen",
"ocr_use_existing_annotations": "Nur annotierte Bereiche",
"ocr_use_existing_annotations_hint": "OCR wird nur innerhalb der bereits markierten Bereiche ausgeführt — keine neue Layout-Erkennung.",
"transcription_block_review": "Als geprüft markieren",
"transcription_block_unreview": "Markierung aufheben",
"transcription_reviewed_count": "{reviewed} von {total} geprüft",

View File

@@ -531,6 +531,8 @@
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
"ocr_status_error": "OCR failed",
"ocr_use_existing_annotations": "Annotated regions only",
"ocr_use_existing_annotations_hint": "OCR runs only within the already marked regions — no new layout detection.",
"transcription_block_review": "Mark as reviewed",
"transcription_block_unreview": "Unmark as reviewed",
"transcription_reviewed_count": "{reviewed} of {total} reviewed",

View File

@@ -531,6 +531,8 @@
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
"ocr_status_error": "OCR fallido",
"ocr_use_existing_annotations": "Solo regiones anotadas",
"ocr_use_existing_annotations_hint": "El OCR se ejecuta solo dentro de las regiones ya marcadas — sin nueva detección de diseño.",
"transcription_block_review": "Marcar como revisado",
"transcription_block_unreview": "Desmarcar como revisado",
"transcription_reviewed_count": "{reviewed} de {total} revisados",

View File

@@ -7,10 +7,11 @@ import ScriptTypeSelect from './ScriptTypeSelect.svelte';
interface Props {
existingBlockCount: number;
storedScriptType: string;
onTrigger: (scriptType: string) => void;
annotationCount?: number;
onTrigger: (scriptType: string, useExistingAnnotations: boolean) => void;
}
let { existingBlockCount, storedScriptType, onTrigger }: Props = $props();
let { existingBlockCount, storedScriptType, annotationCount = 0, onTrigger }: Props = $props();
const { confirm } = getConfirmService();
@@ -18,10 +19,12 @@ let selectedScriptType: string = $state(
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
);
let useExistingAnnotations: boolean = $state(false);
async function handleClick() {
if (!selectedScriptType) return;
if (existingBlockCount > 0) {
if (!useExistingAnnotations && existingBlockCount > 0) {
const confirmed = await confirm({
title: m.ocr_confirm_title(),
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
@@ -31,12 +34,27 @@ async function handleClick() {
if (!confirmed) return;
}
onTrigger(selectedScriptType);
onTrigger(selectedScriptType, useExistingAnnotations);
}
</script>
<div class="flex flex-col gap-3">
<ScriptTypeSelect bind:value={selectedScriptType} />
{#if annotationCount > 0}
<div class="flex flex-col gap-1">
<label class="flex cursor-pointer items-center gap-2">
<input
type="checkbox"
bind:checked={useExistingAnnotations}
class="h-4 w-4 cursor-pointer rounded-sm border-brand-navy/30 accent-brand-navy"
/>
<span class="font-sans text-sm font-medium text-brand-navy">
{m.ocr_use_existing_annotations()}
</span>
</label>
<p class="pl-6 text-xs text-ink-3">{m.ocr_use_existing_annotations_hint()}</p>
</div>
{/if}
<button
type="button"
disabled={!selectedScriptType}

View File

@@ -19,7 +19,7 @@ type Props = {
onSaveBlock: (blockId: string, text: string) => Promise<void>;
onDeleteBlock: (blockId: string) => Promise<void>;
onReviewToggle: (blockId: string) => Promise<void>;
onTriggerOcr?: (scriptType: string) => void;
onTriggerOcr?: (scriptType: string, useExistingAnnotations: boolean) => void;
canWrite?: boolean;
trainingLabels?: string[];
onToggleTrainingLabel?: (label: string, enrolled: boolean) => Promise<void>;
@@ -370,6 +370,7 @@ $effect(() => {
<div class="mt-3 max-w-xs">
<OcrTrigger
existingBlockCount={blocks.length}
annotationCount={blocks.length}
storedScriptType={storedScriptType}
onTrigger={onTriggerOcr}
/>

View File

@@ -144,14 +144,14 @@ let ocrErrorMessage = $state('');
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
let ocrSkippedPages = $state(0);
async function triggerOcr(scriptType: string) {
async function triggerOcr(scriptType: string, useExistingAnnotations: boolean = false) {
ocrRunning = true;
ocrErrorMessage = '';
try {
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ scriptType })
body: JSON.stringify({ scriptType, useExistingAnnotations })
});
if (res.ok) {
const data = await res.json();

View File

@@ -79,6 +79,29 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
return blocks
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
    """Crop image to a normalized region and run Kraken recognition on the crop.

    Used for guided OCR — skips full-page layout detection and only processes
    the given bounding box. Coordinates are normalized to [0, 1].

    Args:
        image: page image exposing ``size`` and ``crop`` (a PIL image).
        x: left edge of the region as a fraction of the page width.
        y: top edge of the region as a fraction of the page height.
        w: region width as a fraction of the page width.
        h: region height as a fraction of the page height.

    Returns:
        The recognized lines joined with single spaces, or an empty string
        when the region covers no pixels of the page.

    Raises:
        RuntimeError: if the Kraken model has not been loaded.
    """
    from kraken import blla, rpred

    if _model is None:
        raise RuntimeError("Kraken model is not loaded")

    pw, ph = image.size
    # Convert to pixel coordinates and clamp the box to the page.
    x1 = max(0, int(x * pw))
    y1 = max(0, int(y * ph))
    x2 = min(pw, int((x + w) * pw))
    y2 = min(ph, int((y + h) * ph))

    # A zero-size, sub-pixel or fully off-page region yields an empty or
    # inverted box. Cropping/segmenting that would raise and make the caller
    # discard every region on the page, so report "no text" instead.
    if x2 <= x1 or y2 <= y1:
        return ""

    crop = image.crop((x1, y1, x2, y2))

    # Segment the crop only (not the full page), then recognize its lines.
    baseline_seg = blla.segment(crop)
    pred_it = rpred.rpred(_model, crop, baseline_seg)
    return " ".join(r.prediction for r in pred_it)
def extract_blocks(images: list, language: str = "de") -> list[dict]:
"""Run Kraken segmentation + recognition on a list of PIL images.

View File

@@ -81,6 +81,25 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
return blocks
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
    """Crop image to a normalized region and run Surya recognition on the crop.

    Used for guided OCR — skips full-page layout detection and only processes
    the given bounding box. Coordinates are normalized to [0, 1].

    Args:
        image: page image exposing ``size`` and ``crop`` (a PIL image).
        x: left edge of the region as a fraction of the page width.
        y: top edge of the region as a fraction of the page height.
        w: region width as a fraction of the page width.
        h: region height as a fraction of the page height.

    Returns:
        The text of all recognized lines joined with single spaces, or an
        empty string when the region covers no pixels of the page.
    """
    load_models()

    pw, ph = image.size
    # Convert to pixel coordinates and clamp the box to the page.
    x1 = max(0, int(x * pw))
    y1 = max(0, int(y * ph))
    x2 = min(pw, int((x + w) * pw))
    y2 = min(ph, int((y + h) * ph))

    # A zero-size, sub-pixel or fully off-page region yields an empty or
    # inverted box. Running recognition on that would raise and make the
    # caller discard every region on the page, so report "no text" instead.
    if x2 <= x1 or y2 <= y1:
        return ""

    crop = image.crop((x1, y1, x2, y2))

    predictions = _recognition_predictor([crop], det_predictor=_detection_predictor)
    return " ".join(line.text for line in predictions[0].text_lines)
def extract_blocks(images: list, language: str = "de") -> list[dict]:
"""Run Surya OCR on a list of PIL images (one per page).

View File

@@ -108,7 +108,12 @@ async def run_ocr(request: OcrRequest):
@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page."""
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.
When request.regions is provided, runs in guided mode: each region is cropped
and recognized individually, skipping full-page layout detection. The response
blocks include the annotationId from the region.
"""
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
@@ -123,6 +128,81 @@ async def run_ocr_stream(request: OcrRequest):
detail="Kraken model not available — cannot process Kurrent script",
)
engine = kraken_engine if use_kraken else surya_engine
if request.regions:
# Guided mode: recognize only the user-drawn annotation regions
regions_by_page: dict[int, list] = {}
for region in request.regions:
regions_by_page.setdefault(region.pageNumber, []).append(region)
# Async generator producing the NDJSON stream for guided mode.
# Emits, in order: one "start" line, then one "page" or "error" line per
# page image, then one "done" line. Reads `images`, `regions_by_page` and
# `engine` from the enclosing request handler.
async def generate_guided():
total_pages = len(images)
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
total_blocks = 0
skipped_pages = 0
# Pages are numbered from 1 here; regions are looked up by that number.
for page_idx, image in enumerate(images, start=1):
page_regions = regions_by_page.get(page_idx, [])
if not page_regions:
# No region drawn on this page: still emit an (empty) page event so
# the consumer sees one event per page.
yield json.dumps({
"type": "page",
"pageNumber": page_idx,
"blocks": [],
}) + "\n"
# Drop this loop's reference to the page image before moving on.
del image
continue
try:
blocks = []
for region in page_regions:
# Recognition is blocking work, so it runs in a worker thread.
text = await asyncio.to_thread(
engine.extract_region_text, image,
region.x, region.y, region.width, region.height,
)
# The block reuses the region's own geometry and carries its
# annotationId back; no polygon is produced in guided mode.
blocks.append({
"pageNumber": page_idx,
"x": region.x,
"y": region.y,
"width": region.width,
"height": region.height,
"polygon": None,
"text": text,
"annotationId": region.annotationId,
})
total_blocks += len(blocks)
yield json.dumps({
"type": "page",
"pageNumber": page_idx,
"blocks": blocks,
}) + "\n"
except Exception:
# Any failure is logged with its traceback and reported as an error
# event carrying only a generic message; the page counts as skipped.
logger.exception("Guided OCR failed on page %d", page_idx)
skipped_pages += 1
yield json.dumps({
"type": "error",
"pageNumber": page_idx,
"message": f"Guided OCR processing failed on page {page_idx}",
}) + "\n"
finally:
# Drop this loop's reference to the page image whether or not it succeeded.
del image
# Final summary line with the totals accumulated above.
yield json.dumps({
"type": "done",
"totalBlocks": total_blocks,
"skippedPages": skipped_pages,
}) + "\n"
return StreamingResponse(
generate_guided(),
media_type="application/x-ndjson",
headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
)
async def generate():
total_pages = len(images)
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
@@ -132,7 +212,6 @@ async def run_ocr_stream(request: OcrRequest):
for page_idx, image in enumerate(images, start=1):
try:
engine = kraken_engine if use_kraken else surya_engine
blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language
)

View File

@@ -1,12 +1,24 @@
from pydantic import BaseModel, ConfigDict
# One user-drawn annotation region for guided OCR, as sent in
# OcrRequest.regions. Field names are camelCase to match the JSON payload.
class OcrRegion(BaseModel):
model_config = ConfigDict(populate_by_name=True)
# Id of the existing annotation; copied onto the block produced for this region.
annotationId: str
# Page the region lies on; the streaming endpoint matches it against a
# page index that starts at 1.
pageNumber: int
# Bounding box; the OCR engines treat these as fractions of the page
# width/height (normalized to [0, 1]).
x: float
y: float
width: float
height: float
# Request body of the OCR endpoints.
class OcrRequest(BaseModel):
model_config = ConfigDict(populate_by_name=True)
# URL the PDF is read from.
pdfUrl: str
scriptType: str = "UNKNOWN"
language: str = "de"
# Optional guided-mode regions. When missing/None (or empty) the streaming
# endpoint runs normal full-page OCR; otherwise only these regions are recognized.
regions: list[OcrRegion] | None = None
class OcrBlock(BaseModel):
@@ -19,3 +31,4 @@ class OcrBlock(BaseModel):
height: float
polygon: list[list[float]] | None = None
text: str
annotationId: str | None = None