feat(ocr): add image preprocessing pipeline to improve transcription quality on aged documents #255
@@ -82,6 +82,10 @@ public class OcrAsyncRunner {
|
|||||||
ocrJobDocumentRepository.save(jobDoc);
|
ocrJobDocumentRepository.save(jobDoc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case OcrStreamEvent.Preprocessing preprocessing -> {
|
||||||
|
updateProgress(job, "PREPROCESSING_PAGE:" + preprocessing.pageNumber()
|
||||||
|
+ ":" + totalPages.get());
|
||||||
|
}
|
||||||
case OcrStreamEvent.Page page -> {
|
case OcrStreamEvent.Page page -> {
|
||||||
for (OcrBlockResult block : page.blocks()) {
|
for (OcrBlockResult block : page.blocks()) {
|
||||||
createSingleBlock(documentId, block, userId,
|
createSingleBlock(documentId, block, userId,
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ public sealed interface OcrStreamEvent {
|
|||||||
|
|
||||||
record Start(int totalPages) implements OcrStreamEvent {}
|
record Start(int totalPages) implements OcrStreamEvent {}
|
||||||
|
|
||||||
|
record Preprocessing(int pageNumber) implements OcrStreamEvent {}
|
||||||
|
|
||||||
record Page(int pageNumber, List<OcrBlockResult> blocks) implements OcrStreamEvent {}
|
record Page(int pageNumber, List<OcrBlockResult> blocks) implements OcrStreamEvent {}
|
||||||
|
|
||||||
record Error(int pageNumber, String message) implements OcrStreamEvent {}
|
record Error(int pageNumber, String message) implements OcrStreamEvent {}
|
||||||
|
|||||||
@@ -232,6 +232,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
switch (type) {
|
switch (type) {
|
||||||
case "start" -> handler.accept(
|
case "start" -> handler.accept(
|
||||||
new OcrStreamEvent.Start(node.path("totalPages").asInt()));
|
new OcrStreamEvent.Start(node.path("totalPages").asInt()));
|
||||||
|
case "preprocessing" -> handler.accept(
|
||||||
|
new OcrStreamEvent.Preprocessing(node.path("pageNumber").asInt()));
|
||||||
case "page" -> {
|
case "page" -> {
|
||||||
int pageNumber = node.path("pageNumber").asInt();
|
int pageNumber = node.path("pageNumber").asInt();
|
||||||
List<OcrBlockResult> blocks = NDJSON_MAPPER.convertValue(
|
List<OcrBlockResult> blocks = NDJSON_MAPPER.convertValue(
|
||||||
|
|||||||
@@ -230,6 +230,42 @@ class OcrAsyncRunnerTest {
|
|||||||
assertThat(job.getProgressMessage()).isEqualTo("DONE:0:1");
|
assertThat(job.getProgressMessage()).isEqualTo("DONE:0:1");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void runSingleDocument_updatesProgressOnPreprocessingEvent() {
|
||||||
|
UUID jobId = UUID.randomUUID();
|
||||||
|
UUID docId = UUID.randomUUID();
|
||||||
|
UUID userId = UUID.randomUUID();
|
||||||
|
OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
|
||||||
|
OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
|
||||||
|
.jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
|
||||||
|
Document doc = Document.builder().id(docId).filePath("test.pdf")
|
||||||
|
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
|
||||||
|
|
||||||
|
when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
|
||||||
|
when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||||
|
when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
|
||||||
|
.thenReturn(Optional.of(jobDoc));
|
||||||
|
when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||||
|
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||||
|
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||||
|
|
||||||
|
List<String> progressMessages = new ArrayList<>();
|
||||||
|
doAnswer(inv -> {
|
||||||
|
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||||
|
handler.accept(new OcrStreamEvent.Start(5));
|
||||||
|
handler.accept(new OcrStreamEvent.Preprocessing(1));
|
||||||
|
progressMessages.add(job.getProgressMessage());
|
||||||
|
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||||
|
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||||
|
return null;
|
||||||
|
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||||
|
|
||||||
|
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||||
|
|
||||||
|
// Must have totalPages=5 from the preceding Start event, not 0
|
||||||
|
assertThat(progressMessages.get(0)).isEqualTo("PREPROCESSING_PAGE:1:5");
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void runSingleDocument_logsStreamErrorAtWarnWithoutSettingJobFailed() {
|
void runSingleDocument_logsStreamErrorAtWarnWithoutSettingJobFailed() {
|
||||||
UUID jobId = UUID.randomUUID();
|
UUID jobId = UUID.randomUUID();
|
||||||
|
|||||||
@@ -37,11 +37,19 @@ class OcrStreamEventTest {
|
|||||||
assertThat(done.skippedPages()).isEqualTo(2);
|
assertThat(done.skippedPages()).isEqualTo(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void preprocessingRecordHoldsPageNumber() {
|
||||||
|
var preprocessing = new OcrStreamEvent.Preprocessing(4);
|
||||||
|
assertThat(preprocessing.pageNumber()).isEqualTo(4);
|
||||||
|
assertThat(preprocessing).isInstanceOf(OcrStreamEvent.class);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void patternMatchingWorksOnSealedInterface() {
|
void patternMatchingWorksOnSealedInterface() {
|
||||||
OcrStreamEvent event = new OcrStreamEvent.Start(3);
|
OcrStreamEvent event = new OcrStreamEvent.Start(3);
|
||||||
String result = switch (event) {
|
String result = switch (event) {
|
||||||
case OcrStreamEvent.Start s -> "start:" + s.totalPages();
|
case OcrStreamEvent.Start s -> "start:" + s.totalPages();
|
||||||
|
case OcrStreamEvent.Preprocessing pre -> "preprocessing:" + pre.pageNumber();
|
||||||
case OcrStreamEvent.Page p -> "page:" + p.pageNumber();
|
case OcrStreamEvent.Page p -> "page:" + p.pageNumber();
|
||||||
case OcrStreamEvent.Error e -> "error:" + e.pageNumber();
|
case OcrStreamEvent.Error e -> "error:" + e.pageNumber();
|
||||||
case OcrStreamEvent.Done d -> "done:" + d.totalBlocks();
|
case OcrStreamEvent.Done d -> "done:" + d.totalBlocks();
|
||||||
|
|||||||
@@ -83,6 +83,25 @@ class RestClientOcrClientStreamTest {
|
|||||||
assertThat(events).hasSize(3);
|
assertThat(events).hasSize(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void parseNdjsonStream_dispatchesPreprocessingEvent() {
|
||||||
|
String ndjson = """
|
||||||
|
{"type":"start","totalPages":3}
|
||||||
|
{"type":"preprocessing","pageNumber":1}
|
||||||
|
{"type":"page","pageNumber":1,"blocks":[]}
|
||||||
|
{"type":"done","totalBlocks":0,"skippedPages":0}
|
||||||
|
""";
|
||||||
|
InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));
|
||||||
|
|
||||||
|
List<OcrStreamEvent> events = new ArrayList<>();
|
||||||
|
RestClientOcrClient.parseNdjsonStream(stream, events::add);
|
||||||
|
|
||||||
|
assertThat(events).hasSize(4);
|
||||||
|
assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Preprocessing.class);
|
||||||
|
var preprocessing = (OcrStreamEvent.Preprocessing) events.get(1);
|
||||||
|
assertThat(preprocessing.pageNumber()).isEqualTo(1);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void parseNdjsonStream_ignoresUnknownEventTypes() {
|
void parseNdjsonStream_ignoresUnknownEventTypes() {
|
||||||
String ndjson = """
|
String ndjson = """
|
||||||
|
|||||||
@@ -95,6 +95,8 @@ services:
|
|||||||
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
|
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
|
||||||
RECOGNITION_BATCH_SIZE: "16"
|
RECOGNITION_BATCH_SIZE: "16"
|
||||||
DETECTOR_BATCH_SIZE: "8"
|
DETECTOR_BATCH_SIZE: "8"
|
||||||
|
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
|
||||||
|
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
|
||||||
networks:
|
networks:
|
||||||
- archive-net
|
- archive-net
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|||||||
@@ -526,6 +526,7 @@
|
|||||||
"ocr_status_creating_blocks": "{count} Textblöcke erkannt — erstelle Transkription…",
|
"ocr_status_creating_blocks": "{count} Textblöcke erkannt — erstelle Transkription…",
|
||||||
"ocr_status_done_blocks": "{count} Blöcke erstellt",
|
"ocr_status_done_blocks": "{count} Blöcke erstellt",
|
||||||
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
||||||
|
"ocr_status_preprocessing_page": "Seite {current} von {total} wird aufbereitet…",
|
||||||
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
||||||
"ocr_status_error": "OCR fehlgeschlagen",
|
"ocr_status_error": "OCR fehlgeschlagen",
|
||||||
"ocr_trigger_no_annotations": "Zeichnen Sie zuerst Bereiche auf dem Dokument ein.",
|
"ocr_trigger_no_annotations": "Zeichnen Sie zuerst Bereiche auf dem Dokument ein.",
|
||||||
|
|||||||
@@ -526,6 +526,7 @@
|
|||||||
"ocr_status_creating_blocks": "{count} text blocks detected — creating transcription…",
|
"ocr_status_creating_blocks": "{count} text blocks detected — creating transcription…",
|
||||||
"ocr_status_done_blocks": "{count} blocks created",
|
"ocr_status_done_blocks": "{count} blocks created",
|
||||||
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
||||||
|
"ocr_status_preprocessing_page": "Preparing page {current} of {total}…",
|
||||||
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
||||||
"ocr_status_error": "OCR failed",
|
"ocr_status_error": "OCR failed",
|
||||||
"ocr_trigger_no_annotations": "Draw regions on the document first.",
|
"ocr_trigger_no_annotations": "Draw regions on the document first.",
|
||||||
|
|||||||
@@ -526,6 +526,7 @@
|
|||||||
"ocr_status_creating_blocks": "{count} bloques de texto detectados — creando transcripción…",
|
"ocr_status_creating_blocks": "{count} bloques de texto detectados — creando transcripción…",
|
||||||
"ocr_status_done_blocks": "{count} bloques creados",
|
"ocr_status_done_blocks": "{count} bloques creados",
|
||||||
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
||||||
|
"ocr_status_preprocessing_page": "Preparando página {current} de {total}…",
|
||||||
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
||||||
"ocr_status_error": "OCR fallido",
|
"ocr_status_error": "OCR fallido",
|
||||||
"ocr_trigger_no_annotations": "Dibuje regiones en el documento primero.",
|
"ocr_trigger_no_annotations": "Dibuje regiones en el documento primero.",
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ vi.mock('$lib/paraglide/messages.js', () => ({
|
|||||||
`${count} Blöcke erstellt, ${skipped} Seite(n) übersprungen`,
|
`${count} Blöcke erstellt, ${skipped} Seite(n) übersprungen`,
|
||||||
ocr_status_analyzing_page: ({ current, total }: { current: string; total: string }) =>
|
ocr_status_analyzing_page: ({ current, total }: { current: string; total: string }) =>
|
||||||
`Seite ${current} von ${total} wird analysiert…`,
|
`Seite ${current} von ${total} wird analysiert…`,
|
||||||
|
ocr_status_preprocessing_page: ({ current, total }: { current: string; total: string }) =>
|
||||||
|
`Seite ${current} von ${total} wird aufbereitet…`,
|
||||||
ocr_status_error: () => 'OCR fehlgeschlagen'
|
ocr_status_error: () => 'OCR fehlgeschlagen'
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
@@ -68,6 +70,19 @@ describe('translateOcrProgress', () => {
|
|||||||
expect(result.totalPages).toBe(5);
|
expect(result.totalPages).toBe(5);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('translates PREPROCESSING_PAGE with current and total', () => {
|
||||||
|
const result = translateOcrProgress('PREPROCESSING_PAGE:3:10');
|
||||||
|
expect(result.message).toBe('Seite 3 von 10 wird aufbereitet…');
|
||||||
|
expect(result.currentPage).toBe(3);
|
||||||
|
expect(result.totalPages).toBe(10);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('PREPROCESSING_PAGE with no colon parts gracefully falls back to zero', () => {
|
||||||
|
const result = translateOcrProgress('PREPROCESSING_PAGE');
|
||||||
|
expect(result.currentPage).toBe(0);
|
||||||
|
expect(result.totalPages).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
it('translates ERROR', () => {
|
it('translates ERROR', () => {
|
||||||
expect(translateOcrProgress('ERROR').message).toBe('OCR fehlgeschlagen');
|
expect(translateOcrProgress('ERROR').message).toBe('OCR fehlgeschlagen');
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -48,6 +48,18 @@ export function translateOcrProgress(code: string): OcrProgressResult {
|
|||||||
totalPages: total
|
totalPages: total
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
case 'PREPROCESSING_PAGE': {
|
||||||
|
const current = parseInt(parts[1] ?? '0', 10);
|
||||||
|
const total = parseInt(parts[2] ?? '0', 10);
|
||||||
|
return {
|
||||||
|
message: m.ocr_status_preprocessing_page({
|
||||||
|
current: String(current),
|
||||||
|
total: String(total)
|
||||||
|
}),
|
||||||
|
currentPage: current,
|
||||||
|
totalPages: total
|
||||||
|
};
|
||||||
|
}
|
||||||
case 'ERROR':
|
case 'ERROR':
|
||||||
return { message: m.ocr_status_error() };
|
return { message: m.ocr_status_error() };
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ FROM python:3.11.9-slim
|
|||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# curl for healthcheck; libgomp1 for PyTorch CPU threading; libvips for kraken PDF support
|
# curl for healthcheck; libgomp1 for PyTorch CPU threading; libvips for kraken PDF support
|
||||||
|
# libglib2.0-0 is required by opencv-python-headless on Debian slim
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
curl \
|
curl \
|
||||||
libgomp1 \
|
libgomp1 \
|
||||||
libvips42 \
|
libvips42 \
|
||||||
|
libglib2.0-0 \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# PyTorch CPU-only — separate layer; the whl/cpu index strips all CUDA variants (~2 GB saved)
|
# PyTorch CPU-only — separate layer; the whl/cpu index strips all CUDA variants (~2 GB saved)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from confidence import apply_confidence_markers, get_threshold
|
|||||||
from engines import kraken as kraken_engine
|
from engines import kraken as kraken_engine
|
||||||
from engines import surya as surya_engine
|
from engines import surya as surya_engine
|
||||||
from models import OcrBlock, OcrRequest
|
from models import OcrBlock, OcrRequest
|
||||||
|
from preprocessing import preprocess_page
|
||||||
|
|
||||||
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
|
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
|
||||||
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
||||||
@@ -86,6 +87,10 @@ async def run_ocr(request: OcrRequest):
|
|||||||
|
|
||||||
images = await _download_and_convert_pdf(request.pdfUrl)
|
images = await _download_and_convert_pdf(request.pdfUrl)
|
||||||
|
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
images[i] = await asyncio.to_thread(preprocess_page, img)
|
||||||
|
del img
|
||||||
|
|
||||||
script_type = request.scriptType.upper()
|
script_type = request.scriptType.upper()
|
||||||
|
|
||||||
if script_type == "HANDWRITING_KURRENT":
|
if script_type == "HANDWRITING_KURRENT":
|
||||||
@@ -157,6 +162,8 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||||
|
image = await asyncio.to_thread(preprocess_page, image)
|
||||||
blocks = []
|
blocks = []
|
||||||
for region in page_regions:
|
for region in page_regions:
|
||||||
text = await asyncio.to_thread(
|
text = await asyncio.to_thread(
|
||||||
@@ -214,6 +221,8 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
|
|
||||||
for page_idx, image in enumerate(images, start=1):
|
for page_idx, image in enumerate(images, start=1):
|
||||||
try:
|
try:
|
||||||
|
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||||
|
image = await asyncio.to_thread(preprocess_page, image)
|
||||||
blocks = await asyncio.to_thread(
|
blocks = await asyncio.to_thread(
|
||||||
engine.extract_page_blocks, image, page_idx, request.language
|
engine.extract_page_blocks, image, page_idx, request.language
|
||||||
)
|
)
|
||||||
|
|||||||
50
ocr-service/preprocessing.py
Normal file
50
ocr-service/preprocessing.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Image preprocessing pipeline for aged document OCR quality improvement."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CLAHE_CLIP_LIMIT = float(os.environ.get("OCR_CLAHE_CLIP_LIMIT", "2.0"))
|
||||||
|
CLAHE_TILE_SIZE = int(os.environ.get("OCR_CLAHE_TILE_SIZE", "8"))
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_page(image: Image.Image) -> Image.Image:
    """Enhance a page image for OCR on aged documents.

    Pipeline: convert to LAB, keep only the lightness (L) channel, apply CLAHE
    (configured by ``CLAHE_CLIP_LIMIT`` / ``CLAHE_TILE_SIZE``), then smooth with
    a 3x3 Gaussian blur. The result is a single-channel (mode ``"L"``) image
    with the same dimensions as the input.

    Args:
        image: The page as a PIL image. Modes other than RGB/RGBA (for example
            grayscale ``"L"``, palette ``"P"``, bilevel ``"1"`` or ``"CMYK"``)
            are converted to RGB first so they are enhanced too instead of
            always hitting the fallback path.

    Returns:
        The enhanced grayscale image, or the original ``image`` object
        unchanged if ``cv2.error``, ``ValueError`` or ``MemoryError`` is raised
        (a warning is logged in that case).

    Raises:
        Any other exception type is treated as a programming error and
        propagates to the caller.
    """
    try:
        # RGB/RGBA inputs are handed to OpenCV exactly as before; every other
        # mode is normalised to RGB because COLOR_RGB2LAB cannot interpret a
        # 2-D (grayscale/palette/bilevel) array and would misread CMYK planes.
        rgb_image = image if image.mode in ("RGB", "RGBA") else image.convert("RGB")

        img_array = np.array(rgb_image)
        lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
        del img_array

        # Copy the L plane so the full LAB buffer can be released right away;
        # page scans are large and intermediates are dropped eagerly throughout.
        l_channel = lab[:, :, 0].copy()
        del lab

        clahe = cv2.createCLAHE(
            clipLimit=CLAHE_CLIP_LIMIT,
            tileGridSize=(CLAHE_TILE_SIZE, CLAHE_TILE_SIZE),
        )
        l_clahe = clahe.apply(l_channel)
        del l_channel

        blurred = cv2.GaussianBlur(l_clahe, (3, 3), 0)
        del l_clahe

        result = Image.fromarray(blurred, mode="L")
        del blurred

        return result
    except (cv2.error, ValueError, MemoryError) as e:
        logger.warning(
            "preprocess_page failed (falling back to original): %s: %s",
            type(e).__name__,
            e,
        )
        # Always hand back the caller's original object, never the RGB copy.
        return image
|
||||||
@@ -8,3 +8,4 @@ pillow>=10.2.0,<11.0.0
|
|||||||
pypdfium2==4.30.0
|
pypdfium2==4.30.0
|
||||||
pyvips>=2.2.0
|
pyvips>=2.2.0
|
||||||
httpx==0.28.1
|
httpx==0.28.1
|
||||||
|
opencv-python-headless==4.11.0.86
|
||||||
|
|||||||
82
ocr-service/test_preprocessing.py
Normal file
82
ocr-service/test_preprocessing.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Tests for the image preprocessing pipeline."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
|
def _make_yellowed_image(width=100, height=100):
|
||||||
|
"""Dark, faded yellowed page: L values in a narrow low range with spatial noise.
|
||||||
|
|
||||||
|
Very dark (R≈30, G≈20, B≈10) → L_cv ≈ 10-30 in OpenCV uint8 LAB space (sRGB gamma applied).
|
||||||
|
The per-pixel noise gives each CLAHE tile a non-trivial histogram to equalize,
|
||||||
|
which stretches the narrow dark range toward [0-255] and reliably increases mean L.
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
arr = np.zeros((height, width, 3), dtype=np.uint8)
|
||||||
|
arr[:, :, 0] = np.clip(30 + rng.integers(-8, 9, (height, width)), 0, 255)
|
||||||
|
arr[:, :, 1] = np.clip(20 + rng.integers(-5, 6, (height, width)), 0, 255)
|
||||||
|
arr[:, :, 2] = np.clip(10 + rng.integers(-3, 4, (height, width)), 0, 255)
|
||||||
|
return Image.fromarray(arr.astype(np.uint8), mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreprocessPage:
|
||||||
|
def test_output_has_same_dimensions_as_input(self):
|
||||||
|
from preprocessing import preprocess_page
|
||||||
|
|
||||||
|
img = Image.new("RGB", (150, 200))
|
||||||
|
result = preprocess_page(img)
|
||||||
|
|
||||||
|
assert result.size == img.size
|
||||||
|
|
||||||
|
def test_l_channel_mean_increases_on_yellowed_image(self):
|
||||||
|
"""CLAHE equalizes the dark narrow-range histogram toward [0-255], raising mean L."""
|
||||||
|
from preprocessing import preprocess_page
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
img = _make_yellowed_image()
|
||||||
|
|
||||||
|
arr_before = np.array(img)
|
||||||
|
lab_before = cv2.cvtColor(arr_before, cv2.COLOR_RGB2LAB)
|
||||||
|
l_mean_before = float(lab_before[:, :, 0].mean())
|
||||||
|
|
||||||
|
result = preprocess_page(img)
|
||||||
|
|
||||||
|
# Output is grayscale (mode "L"); its values ARE the CLAHE-enhanced L channel
|
||||||
|
l_mean_after = float(np.array(result).mean())
|
||||||
|
|
||||||
|
assert l_mean_after > l_mean_before
|
||||||
|
|
||||||
|
def test_does_not_crash_on_sub_tile_size_image(self):
|
||||||
|
"""A 1×1 image is smaller than the CLAHE tile (8×8) in both axes.
|
||||||
|
preprocess_page must not raise — it either succeeds or falls back silently."""
|
||||||
|
from preprocessing import preprocess_page
|
||||||
|
|
||||||
|
img = Image.new("RGB", (1, 1), color=(128, 100, 80))
|
||||||
|
result = preprocess_page(img)
|
||||||
|
|
||||||
|
assert isinstance(result, Image.Image)
|
||||||
|
|
||||||
|
def test_falls_back_to_pixel_identical_original_on_cv2_error(self):
|
||||||
|
"""When cv2 raises a known error, preprocess_page returns the unmodified original image."""
|
||||||
|
from preprocessing import preprocess_page
|
||||||
|
|
||||||
|
img = Image.new("RGB", (80, 60), color=(123, 45, 67))
|
||||||
|
original_pixels = list(img.getdata())
|
||||||
|
|
||||||
|
with patch("preprocessing.cv2.cvtColor", side_effect=ValueError("bad input")):
|
||||||
|
result = preprocess_page(img)
|
||||||
|
|
||||||
|
result_pixels = list(result.getdata())
|
||||||
|
assert result_pixels == original_pixels
|
||||||
|
|
||||||
|
def test_unexpected_exception_propagates(self):
|
||||||
|
"""A RuntimeError (programming error) must propagate — not be swallowed by the cv2 fallback."""
|
||||||
|
from preprocessing import preprocess_page
|
||||||
|
|
||||||
|
img = Image.new("RGB", (80, 60))
|
||||||
|
|
||||||
|
with patch("preprocessing.cv2.cvtColor", side_effect=RuntimeError("unexpected")):
|
||||||
|
with pytest.raises(RuntimeError, match="unexpected"):
|
||||||
|
preprocess_page(img)
|
||||||
@@ -301,3 +301,83 @@ async def test_ocr_stream_applies_confidence_markers(mock_images):
|
|||||||
page = [l for l in lines if l["type"] == "page"][0]
|
page = [l for l in lines if l["type"] == "page"][0]
|
||||||
assert page["blocks"][0]["text"] == "Lieber [unleserlich]"
|
assert page["blocks"][0]["text"] == "Lieber [unleserlich]"
|
||||||
assert "words" not in page["blocks"][0]
|
assert "words" not in page["blocks"][0]
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Preprocessing integration ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_stream_emits_preprocessing_event_per_page_before_page_event(mock_images):
    """generate() must emit a preprocessing event for each page before its page event.

    preprocess_page is patched to an identity function so only the event
    sequence of the NDJSON stream is under test, not the image processing.
    """
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya, \
         patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
        mock_surya.extract_page_blocks.return_value = []

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
    preprocessing_events = [l for l in lines if l["type"] == "preprocessing"]
    page_events = [l for l in lines if l["type"] == "page"]

    assert len(preprocessing_events) == 3
    assert preprocessing_events[0]["pageNumber"] == 1
    assert preprocessing_events[1]["pageNumber"] == 2
    assert preprocessing_events[2]["pageNumber"] == 3

    # Each preprocessing event must pair with, and come before, the page event
    # for the same page. Comparing page numbers alone would not catch a stream
    # that emits the page result first, so the stream positions are compared too.
    for pre, page in zip(preprocessing_events, page_events):
        assert pre["pageNumber"] == page["pageNumber"]
        assert lines.index(pre) < lines.index(page)

    assert mock_preprocess.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guided_stream_preprocesses_once_per_page_not_per_region(mock_images):
|
||||||
|
"""generate_guided() must call preprocess_page once per page, not once per region."""
|
||||||
|
regions = [
|
||||||
|
{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a1"},
|
||||||
|
{"pageNumber": 1, "x": 0.5, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a2"},
|
||||||
|
{"pageNumber": 2, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, "annotationId": "a3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images[:2]), \
|
||||||
|
patch("main._models_ready", True), \
|
||||||
|
patch("main.surya_engine") as mock_surya, \
|
||||||
|
patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
|
||||||
|
mock_surya.extract_region_text.return_value = "text"
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post("/ocr/stream", json={
|
||||||
|
"pdfUrl": "http://minio/test.pdf",
|
||||||
|
"scriptType": "TYPEWRITER",
|
||||||
|
"regions": regions,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# 2 pages, each preprocessed once — not 3 (once per region)
|
||||||
|
assert mock_preprocess.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ocr_endpoint_preprocesses_silently_without_emitting_events(mock_images):
|
||||||
|
"""/ocr endpoint preprocesses each page without emitting preprocessing events."""
|
||||||
|
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
|
||||||
|
patch("main._models_ready", True), \
|
||||||
|
patch("main.surya_engine") as mock_surya, \
|
||||||
|
patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
|
||||||
|
mock_surya.extract_blocks.return_value = []
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post("/ocr", json={
|
||||||
|
"pdfUrl": "http://minio/test.pdf",
|
||||||
|
"scriptType": "TYPEWRITER",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert mock_preprocess.call_count == 3
|
||||||
|
|||||||
Reference in New Issue
Block a user