feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232
@@ -51,7 +51,8 @@ public class OcrController {
|
|||||||
@RequestBody TriggerOcrDTO dto,
|
@RequestBody TriggerOcrDTO dto,
|
||||||
Authentication authentication) {
|
Authentication authentication) {
|
||||||
UUID userId = resolveUserId(authentication);
|
UUID userId = resolveUserId(authentication);
|
||||||
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId);
|
UUID jobId = ocrService.startOcr(documentId, dto.getScriptType(), userId,
|
||||||
|
Boolean.TRUE.equals(dto.getUseExistingAnnotations()));
|
||||||
return Map.of("jobId", jobId);
|
return Map.of("jobId", jobId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,4 +10,5 @@ import org.raddatz.familienarchiv.model.ScriptType;
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class TriggerOcrDTO {
|
public class TriggerOcrDTO {
|
||||||
private ScriptType scriptType;
|
private ScriptType scriptType;
|
||||||
|
private Boolean useExistingAnnotations = false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
|
|||||||
|
|
||||||
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
|
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
|
||||||
|
|
||||||
|
Optional<TranscriptionBlock> findByAnnotationId(UUID annotationId);
|
||||||
|
|
||||||
int countByDocumentId(UUID documentId);
|
int countByDocumentId(UUID documentId);
|
||||||
|
|
||||||
@Query("""
|
@Query("""
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ public class OcrAsyncRunner {
|
|||||||
|
|
||||||
@Async
|
@Async
|
||||||
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
|
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
|
||||||
|
runSingleDocument(jobId, documentId, userId, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Async
|
||||||
|
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId, boolean useExistingAnnotations) {
|
||||||
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
|
OcrJob job = ocrJobRepository.findById(jobId).orElse(null);
|
||||||
if (job == null) return;
|
if (job == null) return;
|
||||||
|
|
||||||
@@ -49,7 +54,18 @@ public class OcrAsyncRunner {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
updateProgress(job, "LOADING");
|
updateProgress(job, "LOADING");
|
||||||
clearExistingBlocks(documentId);
|
|
||||||
|
List<OcrClient.OcrRegion> regions = null;
|
||||||
|
if (useExistingAnnotations) {
|
||||||
|
regions = annotationService.listAnnotations(documentId).stream()
|
||||||
|
.map(a -> new OcrClient.OcrRegion(
|
||||||
|
a.getId().toString(), a.getPageNumber(),
|
||||||
|
a.getX(), a.getY(), a.getWidth(), a.getHeight()))
|
||||||
|
.toList();
|
||||||
|
} else {
|
||||||
|
clearExistingBlocks(documentId);
|
||||||
|
}
|
||||||
|
|
||||||
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
||||||
|
|
||||||
AtomicInteger blockCounter = new AtomicInteger(0);
|
AtomicInteger blockCounter = new AtomicInteger(0);
|
||||||
@@ -57,7 +73,7 @@ public class OcrAsyncRunner {
|
|||||||
AtomicInteger skippedPages = new AtomicInteger(0);
|
AtomicInteger skippedPages = new AtomicInteger(0);
|
||||||
AtomicInteger totalPages = new AtomicInteger(0);
|
AtomicInteger totalPages = new AtomicInteger(0);
|
||||||
|
|
||||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), event -> {
|
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
|
||||||
switch (event) {
|
switch (event) {
|
||||||
case OcrStreamEvent.Start start -> {
|
case OcrStreamEvent.Start start -> {
|
||||||
totalPages.set(start.totalPages());
|
totalPages.set(start.totalPages());
|
||||||
@@ -204,14 +220,21 @@ public class OcrAsyncRunner {
|
|||||||
|
|
||||||
void createSingleBlock(UUID documentId, OcrBlockResult block,
|
void createSingleBlock(UUID documentId, OcrBlockResult block,
|
||||||
UUID userId, String fileHash, int sortOrder) {
|
UUID userId, String fileHash, int sortOrder) {
|
||||||
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
|
if (block.annotationId() != null) {
|
||||||
block.pageNumber(), block.x(), block.y(),
|
// Guided mode — annotation already exists; upsert the text block only
|
||||||
block.width(), block.height(), OCR_ANNOTATION_COLOR);
|
transcriptionService.upsertGuidedBlock(
|
||||||
|
documentId, UUID.fromString(block.annotationId()), block.text(), userId);
|
||||||
|
} else {
|
||||||
|
// Normal mode — create a new annotation and a new OCR block
|
||||||
|
CreateAnnotationDTO annotationDTO = new CreateAnnotationDTO(
|
||||||
|
block.pageNumber(), block.x(), block.y(),
|
||||||
|
block.width(), block.height(), OCR_ANNOTATION_COLOR);
|
||||||
|
|
||||||
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
|
DocumentAnnotation annotation = annotationService.createOcrAnnotation(
|
||||||
documentId, annotationDTO, userId, fileHash, block.polygon());
|
documentId, annotationDTO, userId, fileHash, block.polygon());
|
||||||
|
|
||||||
transcriptionService.createOcrBlock(documentId, annotation.getId(),
|
transcriptionService.createOcrBlock(documentId, annotation.getId(),
|
||||||
block.text(), sortOrder, userId);
|
block.text(), sortOrder, userId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,5 +12,6 @@ public record OcrBlockResult(
|
|||||||
double width,
|
double width,
|
||||||
double height,
|
double height,
|
||||||
List<List<Double>> polygon,
|
List<List<Double>> polygon,
|
||||||
String text
|
String text,
|
||||||
|
String annotationId // null in normal mode; set in guided mode to link back to existing annotation
|
||||||
) {}
|
) {}
|
||||||
|
|||||||
@@ -10,6 +10,14 @@ import java.util.function.Consumer;
|
|||||||
public interface OcrClient {
|
public interface OcrClient {
|
||||||
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
|
List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType);
|
||||||
|
|
||||||
|
/**
 * A pre-drawn annotation region to use as guidance for OCR.
 * When regions are provided, the OCR engine crops to each region and
 * runs recognition only within that area, skipping full-page layout detection.
 *
 * @param annotationId id of the annotation this region was built from, as a string
 * @param pageNumber   page the region lies on
 * @param x            x coordinate of the region, as stored on the annotation
 * @param y            y coordinate of the region, as stored on the annotation
 * @param width        width of the region, as stored on the annotation
 * @param height       height of the region, as stored on the annotation
 */
record OcrRegion(
        String annotationId,
        int pageNumber,
        double x,
        double y,
        double width,
        double height
) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
|
* Send a training ZIP to the OCR service for fine-tuning the Kurrent model.
|
||||||
*
|
*
|
||||||
@@ -32,8 +40,12 @@ public interface OcrClient {
|
|||||||
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
||||||
* this method. The default exists only for backward compatibility during migration
|
* this method. The default exists only for backward compatibility during migration
|
||||||
* — it calls extractBlocks() and synthesizes events from the collected result.
|
* — it calls extractBlocks() and synthesizes events from the collected result.
|
||||||
|
*
|
||||||
|
* @param regions optional list of pre-drawn annotation regions; when non-null,
|
||||||
|
* the OCR service runs in guided mode (crop + recognize per region)
|
||||||
*/
|
*/
|
||||||
default void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
|
default void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||||
|
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||||
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
|
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
|
||||||
|
|
||||||
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
|
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
|
||||||
|
|||||||
@@ -52,6 +52,11 @@ public class OcrService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
|
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId) {
|
||||||
|
return startOcr(documentId, scriptTypeOverride, userId, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public UUID startOcr(UUID documentId, ScriptType scriptTypeOverride, UUID userId,
|
||||||
|
boolean useExistingAnnotations) {
|
||||||
Document doc = documentService.getDocumentById(documentId);
|
Document doc = documentService.getDocumentById(documentId);
|
||||||
|
|
||||||
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
|
if (doc.getStatus() == DocumentStatus.PLACEHOLDER) {
|
||||||
@@ -85,7 +90,7 @@ public class OcrService {
|
|||||||
.build();
|
.build();
|
||||||
ocrJobDocumentRepository.save(jobDoc);
|
ocrJobDocumentRepository.save(jobDoc);
|
||||||
|
|
||||||
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId);
|
ocrAsyncRunner.runSingleDocument(job.getId(), documentId, userId, useExistingAnnotations);
|
||||||
return job.getId();
|
return job.getId();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -175,13 +175,18 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void streamBlocks(String pdfUrl, ScriptType scriptType, Consumer<OcrStreamEvent> handler) {
|
public void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||||
|
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||||
String body;
|
String body;
|
||||||
try {
|
try {
|
||||||
body = NDJSON_MAPPER.writeValueAsString(Map.of(
|
var requestMap = new java.util.LinkedHashMap<String, Object>();
|
||||||
"pdfUrl", pdfUrl,
|
requestMap.put("pdfUrl", pdfUrl);
|
||||||
"scriptType", scriptType.name(),
|
requestMap.put("scriptType", scriptType.name());
|
||||||
"language", "de"));
|
requestMap.put("language", "de");
|
||||||
|
if (regions != null && !regions.isEmpty()) {
|
||||||
|
requestMap.put("regions", regions);
|
||||||
|
}
|
||||||
|
body = NDJSON_MAPPER.writeValueAsString(requestMap);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to serialize OCR request", e);
|
throw new RuntimeException("Failed to serialize OCR request", e);
|
||||||
}
|
}
|
||||||
@@ -199,7 +204,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
|
|
||||||
if (response.statusCode() == 404) {
|
if (response.statusCode() == 404) {
|
||||||
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
|
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
|
||||||
OcrClient.super.streamBlocks(pdfUrl, scriptType, handler);
|
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,10 +264,11 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
double width,
|
double width,
|
||||||
double height,
|
double height,
|
||||||
List<List<Double>> polygon,
|
List<List<Double>> polygon,
|
||||||
String text
|
String text,
|
||||||
|
String annotationId
|
||||||
) {
|
) {
|
||||||
OcrBlockResult toResult() {
|
OcrBlockResult toResult() {
|
||||||
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text);
|
return new OcrBlockResult(pageNumber, x, y, width, height, polygon, text, annotationId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,6 +94,27 @@ public class TranscriptionService {
|
|||||||
return saved;
|
return saved;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Upsert an OCR transcription block for a pre-existing annotation (guided OCR mode).
|
||||||
|
* If the annotation already has a MANUAL block, it is left unchanged.
|
||||||
|
* If it has an OCR block, the text is updated in-place.
|
||||||
|
* If it has no block yet, a new OCR block is created.
|
||||||
|
*/
|
||||||
|
@Transactional
|
||||||
|
public TranscriptionBlock upsertGuidedBlock(UUID documentId, UUID annotationId,
|
||||||
|
String text, UUID userId) {
|
||||||
|
return blockRepository.findByAnnotationId(annotationId).map(existing -> {
|
||||||
|
if (existing.getSource() == BlockSource.MANUAL) {
|
||||||
|
return existing; // never overwrite manual transcription
|
||||||
|
}
|
||||||
|
existing.setText(sanitizeText(text));
|
||||||
|
existing.setUpdatedBy(userId);
|
||||||
|
TranscriptionBlock saved = blockRepository.save(existing);
|
||||||
|
saveVersion(saved, userId);
|
||||||
|
return saved;
|
||||||
|
}).orElseGet(() -> createOcrBlock(documentId, annotationId, text, 0, userId));
|
||||||
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
|
public TranscriptionBlock updateBlock(UUID documentId, UUID blockId,
|
||||||
UpdateTranscriptionBlockDTO dto, UUID userId) {
|
UpdateTranscriptionBlockDTO dto, UUID userId) {
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import java.util.List;
|
|||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||||
@@ -52,9 +53,9 @@ class OcrControllerTest {
|
|||||||
void triggerOcr_returns202_withJobId() throws Exception {
|
void triggerOcr_returns202_withJobId() throws Exception {
|
||||||
UUID docId = UUID.randomUUID();
|
UUID docId = UUID.randomUUID();
|
||||||
UUID jobId = UUID.randomUUID();
|
UUID jobId = UUID.randomUUID();
|
||||||
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER);
|
TriggerOcrDTO dto = new TriggerOcrDTO(ScriptType.TYPEWRITER, false);
|
||||||
|
|
||||||
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any())).thenReturn(jobId);
|
when(ocrService.startOcr(eq(docId), eq(ScriptType.TYPEWRITER), any(), anyBoolean())).thenReturn(jobId);
|
||||||
|
|
||||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||||
.contentType(MediaType.APPLICATION_JSON)
|
.contentType(MediaType.APPLICATION_JSON)
|
||||||
@@ -67,7 +68,7 @@ class OcrControllerTest {
|
|||||||
@WithMockUser(authorities = "WRITE_ALL")
|
@WithMockUser(authorities = "WRITE_ALL")
|
||||||
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
|
void triggerOcr_returns400_whenDocumentNotUploaded() throws Exception {
|
||||||
UUID docId = UUID.randomUUID();
|
UUID docId = UUID.randomUUID();
|
||||||
when(ocrService.startOcr(eq(docId), any(), any()))
|
when(ocrService.startOcr(eq(docId), any(), any(), anyBoolean()))
|
||||||
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
|
.thenThrow(DomainException.badRequest(ErrorCode.OCR_DOCUMENT_NOT_UPLOADED, "Not uploaded"));
|
||||||
|
|
||||||
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
mockMvc.perform(post("/api/documents/{id}/ocr", docId)
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ class OcrAsyncRunnerTest {
|
|||||||
|
|
||||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2")));
|
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
|
||||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ class OcrAsyncRunnerTest {
|
|||||||
|
|
||||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test")));
|
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
|
||||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||||
|
|
||||||
@@ -112,12 +112,12 @@ class OcrAsyncRunnerTest {
|
|||||||
|
|
||||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||||
doAnswer(inv -> {
|
doAnswer(inv -> {
|
||||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||||
handler.accept(new OcrStreamEvent.Start(1));
|
handler.accept(new OcrStreamEvent.Start(1));
|
||||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||||
return null;
|
return null;
|
||||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||||
|
|
||||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||||
|
|
||||||
@@ -142,7 +142,7 @@ class OcrAsyncRunnerTest {
|
|||||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||||
|
|
||||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any());
|
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||||
|
|
||||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||||
|
|
||||||
@@ -174,18 +174,18 @@ class OcrAsyncRunnerTest {
|
|||||||
|
|
||||||
List<String> progressMessages = new ArrayList<>();
|
List<String> progressMessages = new ArrayList<>();
|
||||||
doAnswer(inv -> {
|
doAnswer(inv -> {
|
||||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||||
handler.accept(new OcrStreamEvent.Start(3));
|
handler.accept(new OcrStreamEvent.Start(3));
|
||||||
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
||||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1"),
|
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
|
||||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2"))));
|
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "L2", null))));
|
||||||
progressMessages.add(job.getProgressMessage());
|
progressMessages.add(job.getProgressMessage());
|
||||||
handler.accept(new OcrStreamEvent.Page(1, List.of(
|
handler.accept(new OcrStreamEvent.Page(1, List.of(
|
||||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3"))));
|
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "L3", null))));
|
||||||
progressMessages.add(job.getProgressMessage());
|
progressMessages.add(job.getProgressMessage());
|
||||||
handler.accept(new OcrStreamEvent.Done(3, 0));
|
handler.accept(new OcrStreamEvent.Done(3, 0));
|
||||||
return null;
|
return null;
|
||||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||||
|
|
||||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||||
|
|
||||||
@@ -215,14 +215,14 @@ class OcrAsyncRunnerTest {
|
|||||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||||
|
|
||||||
doAnswer(inv -> {
|
doAnswer(inv -> {
|
||||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||||
handler.accept(new OcrStreamEvent.Start(3));
|
handler.accept(new OcrStreamEvent.Start(3));
|
||||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||||
handler.accept(new OcrStreamEvent.Error(1, "failed"));
|
handler.accept(new OcrStreamEvent.Error(1, "failed"));
|
||||||
handler.accept(new OcrStreamEvent.Page(2, List.of()));
|
handler.accept(new OcrStreamEvent.Page(2, List.of()));
|
||||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||||
return null;
|
return null;
|
||||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||||
|
|
||||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||||
|
|
||||||
@@ -251,13 +251,13 @@ class OcrAsyncRunnerTest {
|
|||||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||||
|
|
||||||
doAnswer(inv -> {
|
doAnswer(inv -> {
|
||||||
Consumer<OcrStreamEvent> handler = inv.getArgument(2);
|
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||||
handler.accept(new OcrStreamEvent.Start(2));
|
handler.accept(new OcrStreamEvent.Start(2));
|
||||||
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
|
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
|
||||||
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||||
return null;
|
return null;
|
||||||
}).when(ocrClient).streamBlocks(any(), any(), any());
|
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||||
|
|
||||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||||
|
|
||||||
|
|||||||
@@ -15,16 +15,16 @@ class OcrClientDefaultStreamTest {
|
|||||||
OcrClient client = new OcrClient() {
|
OcrClient client = new OcrClient() {
|
||||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) {
|
||||||
return List.of(
|
return List.of(
|
||||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1"),
|
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2"),
|
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null),
|
||||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3", null));
|
||||||
}
|
}
|
||||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||||
};
|
};
|
||||||
|
|
||||||
List<OcrStreamEvent> events = new ArrayList<>();
|
List<OcrStreamEvent> events = new ArrayList<>();
|
||||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||||
|
|
||||||
assertThat(events).hasSize(4);
|
assertThat(events).hasSize(4);
|
||||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||||
@@ -55,7 +55,7 @@ class OcrClientDefaultStreamTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
List<OcrStreamEvent> events = new ArrayList<>();
|
List<OcrStreamEvent> events = new ArrayList<>();
|
||||||
client.streamBlocks("http://test", ScriptType.TYPEWRITER, events::add);
|
client.streamBlocks("http://test", ScriptType.TYPEWRITER, null, events::add);
|
||||||
|
|
||||||
assertThat(events).hasSize(2);
|
assertThat(events).hasSize(2);
|
||||||
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
assertThat(events.get(0)).isInstanceOf(OcrStreamEvent.Start.class);
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class OcrServiceTest {
|
|||||||
UUID result = ocrService.startOcr(docId, null, userId);
|
UUID result = ocrService.startOcr(docId, null, userId);
|
||||||
|
|
||||||
assertThat(result).isEqualTo(jobId);
|
assertThat(result).isEqualTo(jobId);
|
||||||
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId);
|
verify(ocrAsyncRunner).runSingleDocument(jobId, docId, userId, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class OcrStreamEventTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void pageRecordHoldsBlocksAndPageNumber() {
|
void pageRecordHoldsBlocksAndPageNumber() {
|
||||||
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test");
|
var block = new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.1, null, "Test", null);
|
||||||
var page = new OcrStreamEvent.Page(0, List.of(block));
|
var page = new OcrStreamEvent.Page(0, List.of(block));
|
||||||
assertThat(page.pageNumber()).isEqualTo(0);
|
assertThat(page.pageNumber()).isEqualTo(0);
|
||||||
assertThat(page.blocks()).hasSize(1);
|
assertThat(page.blocks()).hasSize(1);
|
||||||
|
|||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package org.raddatz.familienarchiv.service;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.raddatz.familienarchiv.model.BlockSource;
|
||||||
|
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||||
|
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||||
|
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||||
|
import org.raddatz.familienarchiv.repository.TranscriptionBlockVersionRepository;
|
||||||
|
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
class TranscriptionServiceGuidedTest {
|
||||||
|
|
||||||
|
TranscriptionBlockRepository blockRepository;
|
||||||
|
TranscriptionBlockVersionRepository versionRepository;
|
||||||
|
AnnotationRepository annotationRepository;
|
||||||
|
AnnotationService annotationService;
|
||||||
|
DocumentService documentService;
|
||||||
|
TranscriptionService service;
|
||||||
|
|
||||||
|
UUID docId = UUID.randomUUID();
|
||||||
|
UUID annId = UUID.randomUUID();
|
||||||
|
UUID userId = UUID.randomUUID();
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||||
|
versionRepository = mock(TranscriptionBlockVersionRepository.class);
|
||||||
|
annotationRepository = mock(AnnotationRepository.class);
|
||||||
|
annotationService = mock(AnnotationService.class);
|
||||||
|
documentService = mock(DocumentService.class);
|
||||||
|
|
||||||
|
service = new TranscriptionService(blockRepository, versionRepository,
|
||||||
|
annotationRepository, annotationService, documentService);
|
||||||
|
|
||||||
|
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||||
|
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void upsertGuidedBlock_createsNewBlock_whenAnnotationHasNoBlock() {
|
||||||
|
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "Hallo", userId);
|
||||||
|
|
||||||
|
assertThat(result.getText()).isEqualTo("Hallo");
|
||||||
|
assertThat(result.getSource()).isEqualTo(BlockSource.OCR);
|
||||||
|
assertThat(result.getAnnotationId()).isEqualTo(annId);
|
||||||
|
verify(blockRepository).save(any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void upsertGuidedBlock_updatesExistingOcrBlock_whenAnnotationHasOcrBlock() {
|
||||||
|
TranscriptionBlock existing = TranscriptionBlock.builder()
|
||||||
|
.id(UUID.randomUUID())
|
||||||
|
.annotationId(annId)
|
||||||
|
.documentId(docId)
|
||||||
|
.text("old text")
|
||||||
|
.source(BlockSource.OCR)
|
||||||
|
.sortOrder(0)
|
||||||
|
.build();
|
||||||
|
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(existing));
|
||||||
|
|
||||||
|
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "new text", userId);
|
||||||
|
|
||||||
|
assertThat(result.getText()).isEqualTo("new text");
|
||||||
|
verify(blockRepository).save(any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void upsertGuidedBlock_doesNotOverwriteManualBlock() {
|
||||||
|
TranscriptionBlock manual = TranscriptionBlock.builder()
|
||||||
|
.id(UUID.randomUUID())
|
||||||
|
.annotationId(annId)
|
||||||
|
.documentId(docId)
|
||||||
|
.text("manually written")
|
||||||
|
.source(BlockSource.MANUAL)
|
||||||
|
.sortOrder(0)
|
||||||
|
.build();
|
||||||
|
when(blockRepository.findByAnnotationId(annId)).thenReturn(Optional.of(manual));
|
||||||
|
|
||||||
|
TranscriptionBlock result = service.upsertGuidedBlock(docId, annId, "ocr result", userId);
|
||||||
|
|
||||||
|
assertThat(result.getText()).isEqualTo("manually written");
|
||||||
|
verify(blockRepository, never()).save(any());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -531,6 +531,8 @@
|
|||||||
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
||||||
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
||||||
"ocr_status_error": "OCR fehlgeschlagen",
|
"ocr_status_error": "OCR fehlgeschlagen",
|
||||||
|
"ocr_use_existing_annotations": "Nur annotierte Bereiche",
|
||||||
|
"ocr_use_existing_annotations_hint": "OCR wird nur innerhalb der bereits markierten Bereiche ausgeführt — keine neue Layout-Erkennung.",
|
||||||
"transcription_block_review": "Als geprüft markieren",
|
"transcription_block_review": "Als geprüft markieren",
|
||||||
"transcription_block_unreview": "Markierung aufheben",
|
"transcription_block_unreview": "Markierung aufheben",
|
||||||
"transcription_reviewed_count": "{reviewed} von {total} geprüft",
|
"transcription_reviewed_count": "{reviewed} von {total} geprüft",
|
||||||
|
|||||||
@@ -531,6 +531,8 @@
|
|||||||
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
||||||
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
||||||
"ocr_status_error": "OCR failed",
|
"ocr_status_error": "OCR failed",
|
||||||
|
"ocr_use_existing_annotations": "Annotated regions only",
|
||||||
|
"ocr_use_existing_annotations_hint": "OCR runs only within the already marked regions — no new layout detection.",
|
||||||
"transcription_block_review": "Mark as reviewed",
|
"transcription_block_review": "Mark as reviewed",
|
||||||
"transcription_block_unreview": "Unmark as reviewed",
|
"transcription_block_unreview": "Unmark as reviewed",
|
||||||
"transcription_reviewed_count": "{reviewed} of {total} reviewed",
|
"transcription_reviewed_count": "{reviewed} of {total} reviewed",
|
||||||
|
|||||||
@@ -531,6 +531,8 @@
|
|||||||
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
||||||
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
||||||
"ocr_status_error": "OCR fallido",
|
"ocr_status_error": "OCR fallido",
|
||||||
|
"ocr_use_existing_annotations": "Solo regiones anotadas",
|
||||||
|
"ocr_use_existing_annotations_hint": "El OCR se ejecuta solo dentro de las regiones ya marcadas — sin nueva detección de diseño.",
|
||||||
"transcription_block_review": "Marcar como revisado",
|
"transcription_block_review": "Marcar como revisado",
|
||||||
"transcription_block_unreview": "Desmarcar como revisado",
|
"transcription_block_unreview": "Desmarcar como revisado",
|
||||||
"transcription_reviewed_count": "{reviewed} de {total} revisados",
|
"transcription_reviewed_count": "{reviewed} de {total} revisados",
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ import ScriptTypeSelect from './ScriptTypeSelect.svelte';
|
|||||||
interface Props {
|
interface Props {
|
||||||
existingBlockCount: number;
|
existingBlockCount: number;
|
||||||
storedScriptType: string;
|
storedScriptType: string;
|
||||||
onTrigger: (scriptType: string) => void;
|
annotationCount?: number;
|
||||||
|
onTrigger: (scriptType: string, useExistingAnnotations: boolean) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
let { existingBlockCount, storedScriptType, onTrigger }: Props = $props();
|
let { existingBlockCount, storedScriptType, annotationCount = 0, onTrigger }: Props = $props();
|
||||||
|
|
||||||
const { confirm } = getConfirmService();
|
const { confirm } = getConfirmService();
|
||||||
|
|
||||||
@@ -18,10 +19,12 @@ let selectedScriptType: string = $state(
|
|||||||
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
|
untrack(() => (storedScriptType && storedScriptType !== 'UNKNOWN' ? storedScriptType : ''))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let useExistingAnnotations: boolean = $state(false);
|
||||||
|
|
||||||
async function handleClick() {
|
async function handleClick() {
|
||||||
if (!selectedScriptType) return;
|
if (!selectedScriptType) return;
|
||||||
|
|
||||||
if (existingBlockCount > 0) {
|
if (!useExistingAnnotations && existingBlockCount > 0) {
|
||||||
const confirmed = await confirm({
|
const confirmed = await confirm({
|
||||||
title: m.ocr_confirm_title(),
|
title: m.ocr_confirm_title(),
|
||||||
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
|
body: m.ocr_confirm_body({ count: String(existingBlockCount) }),
|
||||||
@@ -31,12 +34,27 @@ async function handleClick() {
|
|||||||
if (!confirmed) return;
|
if (!confirmed) return;
|
||||||
}
|
}
|
||||||
|
|
||||||
onTrigger(selectedScriptType);
|
onTrigger(selectedScriptType, useExistingAnnotations);
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div class="flex flex-col gap-3">
|
<div class="flex flex-col gap-3">
|
||||||
<ScriptTypeSelect bind:value={selectedScriptType} />
|
<ScriptTypeSelect bind:value={selectedScriptType} />
|
||||||
|
{#if annotationCount > 0}
|
||||||
|
<div class="flex flex-col gap-1">
|
||||||
|
<label class="flex cursor-pointer items-center gap-2">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
bind:checked={useExistingAnnotations}
|
||||||
|
class="h-4 w-4 cursor-pointer rounded-sm border-brand-navy/30 accent-brand-navy"
|
||||||
|
/>
|
||||||
|
<span class="font-sans text-sm font-medium text-brand-navy">
|
||||||
|
{m.ocr_use_existing_annotations()}
|
||||||
|
</span>
|
||||||
|
</label>
|
||||||
|
<p class="pl-6 text-xs text-ink-3">{m.ocr_use_existing_annotations_hint()}</p>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
disabled={!selectedScriptType}
|
disabled={!selectedScriptType}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type Props = {
|
|||||||
onSaveBlock: (blockId: string, text: string) => Promise<void>;
|
onSaveBlock: (blockId: string, text: string) => Promise<void>;
|
||||||
onDeleteBlock: (blockId: string) => Promise<void>;
|
onDeleteBlock: (blockId: string) => Promise<void>;
|
||||||
onReviewToggle: (blockId: string) => Promise<void>;
|
onReviewToggle: (blockId: string) => Promise<void>;
|
||||||
onTriggerOcr?: (scriptType: string) => void;
|
onTriggerOcr?: (scriptType: string, useExistingAnnotations: boolean) => void;
|
||||||
canWrite?: boolean;
|
canWrite?: boolean;
|
||||||
trainingLabels?: string[];
|
trainingLabels?: string[];
|
||||||
onToggleTrainingLabel?: (label: string, enrolled: boolean) => Promise<void>;
|
onToggleTrainingLabel?: (label: string, enrolled: boolean) => Promise<void>;
|
||||||
@@ -370,6 +370,7 @@ $effect(() => {
|
|||||||
<div class="mt-3 max-w-xs">
|
<div class="mt-3 max-w-xs">
|
||||||
<OcrTrigger
|
<OcrTrigger
|
||||||
existingBlockCount={blocks.length}
|
existingBlockCount={blocks.length}
|
||||||
|
annotationCount={blocks.length}
|
||||||
storedScriptType={storedScriptType}
|
storedScriptType={storedScriptType}
|
||||||
onTrigger={onTriggerOcr}
|
onTrigger={onTriggerOcr}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -144,14 +144,14 @@ let ocrErrorMessage = $state('');
|
|||||||
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
|
let ocrPollTimer = $state<ReturnType<typeof setInterval> | null>(null);
|
||||||
let ocrSkippedPages = $state(0);
|
let ocrSkippedPages = $state(0);
|
||||||
|
|
||||||
async function triggerOcr(scriptType: string) {
|
async function triggerOcr(scriptType: string, useExistingAnnotations: boolean = false) {
|
||||||
ocrRunning = true;
|
ocrRunning = true;
|
||||||
ocrErrorMessage = '';
|
ocrErrorMessage = '';
|
||||||
try {
|
try {
|
||||||
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
|
const res = await fetch(`/api/documents/${doc.id}/ocr`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({ scriptType })
|
body: JSON.stringify({ scriptType, useExistingAnnotations })
|
||||||
});
|
});
|
||||||
if (res.ok) {
|
if (res.ok) {
|
||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
|
|||||||
@@ -79,6 +79,29 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
|||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
    """Recognize the text inside one normalized region of ``image`` with Kraken.

    Used for guided OCR: instead of analysing the whole page, only the given
    bounding box is cropped out and passed through Kraken baseline
    segmentation and recognition.

    Args:
        image: Page image exposing ``.size`` and ``.crop()`` (presumably a PIL
            image, as in the other engine functions -- confirm against caller).
        x: Left edge of the region, normalized to [0, 1] of the page width.
        y: Top edge of the region, normalized to [0, 1] of the page height.
        w: Region width, normalized to [0, 1] of the page width.
        h: Region height, normalized to [0, 1] of the page height.

    Returns:
        The predictions of all recognized lines joined by single spaces, or
        ``""`` when the region covers no pixels after clamping to the page.

    Raises:
        RuntimeError: If a non-empty region is requested while the Kraken
            model has not been loaded.
    """
    pw, ph = image.size

    # Clamp BOTH corners to the page. The upper-left corner used to be clamped
    # only from below, so an origin beyond the page produced an inverted box.
    x1 = min(pw, max(0, int(x * pw)))
    y1 = min(ph, max(0, int(y * ph)))
    x2 = min(pw, max(0, int((x + w) * pw)))
    y2 = min(ph, max(0, int((y + h) * ph)))

    # A zero-area or inverted box has nothing to recognize. Cropping or
    # segmenting it would raise and, in guided mode, fail every other region
    # on the same page, so it is answered with empty text instead.
    if x2 <= x1 or y2 <= y1:
        return ""

    from kraken import blla, rpred

    if _model is None:
        raise RuntimeError("Kraken model is not loaded")

    crop = image.crop((x1, y1, x2, y2))

    baseline_seg = blla.segment(crop)
    pred_it = rpred.rpred(_model, crop, baseline_seg)
    return " ".join(r.prediction for r in pred_it)
|
||||||
|
|
||||||
|
|
||||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||||
"""Run Kraken segmentation + recognition on a list of PIL images.
|
"""Run Kraken segmentation + recognition on a list of PIL images.
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,25 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
|||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
    """Recognize the text inside one normalized region of ``image`` with Surya.

    Used for guided OCR: instead of analysing the whole page, only the given
    bounding box is cropped out and handed to the Surya recognition predictor
    (with the detection predictor locating lines inside the crop).

    Args:
        image: Page image exposing ``.size`` and ``.crop()`` (presumably a PIL
            image, as in the other engine functions -- confirm against caller).
        x: Left edge of the region, normalized to [0, 1] of the page width.
        y: Top edge of the region, normalized to [0, 1] of the page height.
        w: Region width, normalized to [0, 1] of the page width.
        h: Region height, normalized to [0, 1] of the page height.

    Returns:
        The text of all recognized lines joined by single spaces, or ``""``
        when the region covers no pixels after clamping to the page.
    """
    pw, ph = image.size

    # Clamp BOTH corners to the page. The upper-left corner used to be clamped
    # only from below, so an origin beyond the page produced an inverted box.
    x1 = min(pw, max(0, int(x * pw)))
    y1 = min(ph, max(0, int(y * ph)))
    x2 = min(pw, max(0, int((x + w) * pw)))
    y2 = min(ph, max(0, int((y + h) * ph)))

    # A zero-area or inverted box has nothing to recognize. Cropping or
    # predicting on it would raise and, in guided mode, fail every other
    # region on the same page, so it is answered with empty text instead.
    if x2 <= x1 or y2 <= y1:
        return ""

    load_models()

    crop = image.crop((x1, y1, x2, y2))

    predictions = _recognition_predictor([crop], det_predictor=_detection_predictor)
    return " ".join(line.text for line in predictions[0].text_lines)
|
||||||
|
|
||||||
|
|
||||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||||
"""Run Surya OCR on a list of PIL images (one per page).
|
"""Run Surya OCR on a list of PIL images (one per page).
|
||||||
|
|
||||||
|
|||||||
@@ -108,7 +108,12 @@ async def run_ocr(request: OcrRequest):
|
|||||||
|
|
||||||
@app.post("/ocr/stream")
|
@app.post("/ocr/stream")
|
||||||
async def run_ocr_stream(request: OcrRequest):
|
async def run_ocr_stream(request: OcrRequest):
|
||||||
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page."""
|
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.
|
||||||
|
|
||||||
|
When request.regions is provided, runs in guided mode: each region is cropped
|
||||||
|
and recognized individually, skipping full-page layout detection. The response
|
||||||
|
blocks include the annotationId from the region.
|
||||||
|
"""
|
||||||
if not _models_ready:
|
if not _models_ready:
|
||||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||||
|
|
||||||
@@ -123,6 +128,81 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
detail="Kraken model not available — cannot process Kurrent script",
|
detail="Kraken model not available — cannot process Kurrent script",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
engine = kraken_engine if use_kraken else surya_engine
|
||||||
|
|
||||||
|
if request.regions:
|
||||||
|
# Guided mode: recognize only the user-drawn annotation regions
|
||||||
|
regions_by_page: dict[int, list] = {}
|
||||||
|
for region in request.regions:
|
||||||
|
regions_by_page.setdefault(region.pageNumber, []).append(region)
|
||||||
|
|
||||||
|
async def generate_guided():
|
||||||
|
total_pages = len(images)
|
||||||
|
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
||||||
|
|
||||||
|
total_blocks = 0
|
||||||
|
skipped_pages = 0
|
||||||
|
|
||||||
|
for page_idx, image in enumerate(images, start=1):
|
||||||
|
page_regions = regions_by_page.get(page_idx, [])
|
||||||
|
if not page_regions:
|
||||||
|
yield json.dumps({
|
||||||
|
"type": "page",
|
||||||
|
"pageNumber": page_idx,
|
||||||
|
"blocks": [],
|
||||||
|
}) + "\n"
|
||||||
|
del image
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
blocks = []
|
||||||
|
for region in page_regions:
|
||||||
|
text = await asyncio.to_thread(
|
||||||
|
engine.extract_region_text, image,
|
||||||
|
region.x, region.y, region.width, region.height,
|
||||||
|
)
|
||||||
|
blocks.append({
|
||||||
|
"pageNumber": page_idx,
|
||||||
|
"x": region.x,
|
||||||
|
"y": region.y,
|
||||||
|
"width": region.width,
|
||||||
|
"height": region.height,
|
||||||
|
"polygon": None,
|
||||||
|
"text": text,
|
||||||
|
"annotationId": region.annotationId,
|
||||||
|
})
|
||||||
|
|
||||||
|
total_blocks += len(blocks)
|
||||||
|
yield json.dumps({
|
||||||
|
"type": "page",
|
||||||
|
"pageNumber": page_idx,
|
||||||
|
"blocks": blocks,
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Guided OCR failed on page %d", page_idx)
|
||||||
|
skipped_pages += 1
|
||||||
|
yield json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"pageNumber": page_idx,
|
||||||
|
"message": f"Guided OCR processing failed on page {page_idx}",
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
del image
|
||||||
|
|
||||||
|
yield json.dumps({
|
||||||
|
"type": "done",
|
||||||
|
"totalBlocks": total_blocks,
|
||||||
|
"skippedPages": skipped_pages,
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_guided(),
|
||||||
|
media_type="application/x-ndjson",
|
||||||
|
headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
|
||||||
|
)
|
||||||
|
|
||||||
async def generate():
|
async def generate():
|
||||||
total_pages = len(images)
|
total_pages = len(images)
|
||||||
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
||||||
@@ -132,7 +212,6 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
|
|
||||||
for page_idx, image in enumerate(images, start=1):
|
for page_idx, image in enumerate(images, start=1):
|
||||||
try:
|
try:
|
||||||
engine = kraken_engine if use_kraken else surya_engine
|
|
||||||
blocks = await asyncio.to_thread(
|
blocks = await asyncio.to_thread(
|
||||||
engine.extract_page_blocks, image, page_idx, request.language
|
engine.extract_page_blocks, image, page_idx, request.language
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,12 +1,24 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class OcrRegion(BaseModel):
    """One annotation region submitted with an OCR request for guided mode."""

    # Allow fields to be populated by their declared names.
    model_config = ConfigDict(populate_by_name=True)

    # Identifier of the annotation this region belongs to; guided mode copies
    # it onto the block produced for the region.
    annotationId: str
    # Page the region lies on; guided mode matches it against page indices
    # that start at 1.
    pageNumber: int
    # Bounding box of the region. Guided mode passes these values to the
    # engine's region recognizer, which treats them as normalized to [0, 1].
    x: float
    y: float
    width: float
    height: float
|
||||||
|
|
||||||
|
|
||||||
class OcrRequest(BaseModel):
|
class OcrRequest(BaseModel):
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
pdfUrl: str
|
pdfUrl: str
|
||||||
scriptType: str = "UNKNOWN"
|
scriptType: str = "UNKNOWN"
|
||||||
language: str = "de"
|
language: str = "de"
|
||||||
|
regions: list[OcrRegion] | None = None
|
||||||
|
|
||||||
|
|
||||||
class OcrBlock(BaseModel):
|
class OcrBlock(BaseModel):
|
||||||
@@ -19,3 +31,4 @@ class OcrBlock(BaseModel):
|
|||||||
height: float
|
height: float
|
||||||
polygon: list[list[float]] | None = None
|
polygon: list[list[float]] | None = None
|
||||||
text: str
|
text: str
|
||||||
|
annotationId: str | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user