feat(ocr): add image preprocessing pipeline to improve transcription quality on aged documents #255

Merged
marcel merged 7 commits from feat/issue-252-ocr-preprocessing into main 2026-04-17 15:50:37 +02:00
18 changed files with 327 additions and 0 deletions

View File

@@ -82,6 +82,10 @@ public class OcrAsyncRunner {
ocrJobDocumentRepository.save(jobDoc);
}
}
case OcrStreamEvent.Preprocessing preprocessing -> {
updateProgress(job, "PREPROCESSING_PAGE:" + preprocessing.pageNumber()
+ ":" + totalPages.get());
}
case OcrStreamEvent.Page page -> {
for (OcrBlockResult block : page.blocks()) {
createSingleBlock(documentId, block, userId,

View File

@@ -6,6 +6,8 @@ public sealed interface OcrStreamEvent {
record Start(int totalPages) implements OcrStreamEvent {}
record Preprocessing(int pageNumber) implements OcrStreamEvent {}
record Page(int pageNumber, List<OcrBlockResult> blocks) implements OcrStreamEvent {}
record Error(int pageNumber, String message) implements OcrStreamEvent {}

View File

@@ -232,6 +232,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
switch (type) {
case "start" -> handler.accept(
new OcrStreamEvent.Start(node.path("totalPages").asInt()));
case "preprocessing" -> handler.accept(
new OcrStreamEvent.Preprocessing(node.path("pageNumber").asInt()));
case "page" -> {
int pageNumber = node.path("pageNumber").asInt();
List<OcrBlockResult> blocks = NDJSON_MAPPER.convertValue(

View File

@@ -230,6 +230,42 @@ class OcrAsyncRunnerTest {
assertThat(job.getProgressMessage()).isEqualTo("DONE:0:1");
}
@Test
void runSingleDocument_updatesProgressOnPreprocessingEvent() {
    // Arrange: a job with exactly one pending document.
    UUID jobId = UUID.randomUUID();
    UUID docId = UUID.randomUUID();
    UUID userId = UUID.randomUUID();
    OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
    OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
        .jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
    Document doc = Document.builder().id(docId).filePath("test.pdf")
        .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
    // Repositories echo back whatever is saved so the runner's in-memory state is observable.
    when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
    when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
        .thenReturn(Optional.of(jobDoc));
    when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    when(documentService.getDocumentById(docId)).thenReturn(doc);
    when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
    List<String> progressMessages = new ArrayList<>();
    // Simulate the OCR stream: capture the event handler (4th argument of streamBlocks)
    // and drive it manually. The progress message is sampled IMMEDIATELY after the
    // Preprocessing event so we assert on the mid-stream value, not the final DONE state.
    doAnswer(inv -> {
        Consumer<OcrStreamEvent> handler = inv.getArgument(3);
        handler.accept(new OcrStreamEvent.Start(5));
        handler.accept(new OcrStreamEvent.Preprocessing(1));
        progressMessages.add(job.getProgressMessage());
        handler.accept(new OcrStreamEvent.Page(1, List.of()));
        handler.accept(new OcrStreamEvent.Done(0, 0));
        return null;
    }).when(ocrClient).streamBlocks(any(), any(), any(), any());
    ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
    // Must have totalPages=5 from the preceding Start event, not 0
    assertThat(progressMessages.get(0)).isEqualTo("PREPROCESSING_PAGE:1:5");
}
@Test
void runSingleDocument_logsStreamErrorAtWarnWithoutSettingJobFailed() {
UUID jobId = UUID.randomUUID();

View File

@@ -37,11 +37,19 @@ class OcrStreamEventTest {
assertThat(done.skippedPages()).isEqualTo(2);
}
@Test
void preprocessingRecordHoldsPageNumber() {
    // The Preprocessing record must carry its page number and belong to the sealed hierarchy.
    OcrStreamEvent event = new OcrStreamEvent.Preprocessing(4);
    assertThat(event).isInstanceOf(OcrStreamEvent.class);
    assertThat(((OcrStreamEvent.Preprocessing) event).pageNumber()).isEqualTo(4);
}
@Test
void patternMatchingWorksOnSealedInterface() {
OcrStreamEvent event = new OcrStreamEvent.Start(3);
String result = switch (event) {
case OcrStreamEvent.Start s -> "start:" + s.totalPages();
case OcrStreamEvent.Preprocessing pre -> "preprocessing:" + pre.pageNumber();
case OcrStreamEvent.Page p -> "page:" + p.pageNumber();
case OcrStreamEvent.Error e -> "error:" + e.pageNumber();
case OcrStreamEvent.Done d -> "done:" + d.totalBlocks();

View File

@@ -83,6 +83,25 @@ class RestClientOcrClientStreamTest {
assertThat(events).hasSize(3);
}
@Test
void parseNdjsonStream_dispatchesPreprocessingEvent() {
    // Four NDJSON lines in protocol order: start, preprocessing, page, done.
    String payload = """
{"type":"start","totalPages":3}
{"type":"preprocessing","pageNumber":1}
{"type":"page","pageNumber":1,"blocks":[]}
{"type":"done","totalBlocks":0,"skippedPages":0}
""";
    InputStream body = new ByteArrayInputStream(payload.getBytes(StandardCharsets.UTF_8));
    List<OcrStreamEvent> received = new ArrayList<>();
    RestClientOcrClient.parseNdjsonStream(body, received::add);
    // All four lines must be dispatched; the second one as a Preprocessing event for page 1.
    assertThat(received).hasSize(4);
    OcrStreamEvent second = received.get(1);
    assertThat(second).isInstanceOf(OcrStreamEvent.Preprocessing.class);
    assertThat(((OcrStreamEvent.Preprocessing) second).pageNumber()).isEqualTo(1);
}
@Test
void parseNdjsonStream_ignoresUnknownEventTypes() {
String ndjson = """

View File

@@ -95,6 +95,8 @@ services:
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
RECOGNITION_BATCH_SIZE: "16"
DETECTOR_BATCH_SIZE: "8"
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
networks:
- archive-net
healthcheck:

View File

@@ -526,6 +526,7 @@
"ocr_status_creating_blocks": "{count} Textblöcke erkannt — erstelle Transkription…",
"ocr_status_done_blocks": "{count} Blöcke erstellt",
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
"ocr_status_preprocessing_page": "Seite {current} von {total} wird aufbereitet…",
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
"ocr_status_error": "OCR fehlgeschlagen",
"ocr_trigger_no_annotations": "Zeichnen Sie zuerst Bereiche auf dem Dokument ein.",

View File

@@ -526,6 +526,7 @@
"ocr_status_creating_blocks": "{count} text blocks detected — creating transcription…",
"ocr_status_done_blocks": "{count} blocks created",
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
"ocr_status_preprocessing_page": "Preparing page {current} of {total}…",
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
"ocr_status_error": "OCR failed",
"ocr_trigger_no_annotations": "Draw regions on the document first.",

View File

@@ -526,6 +526,7 @@
"ocr_status_creating_blocks": "{count} bloques de texto detectados — creando transcripción…",
"ocr_status_done_blocks": "{count} bloques creados",
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
"ocr_status_preprocessing_page": "Preparando página {current} de {total}…",
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
"ocr_status_error": "OCR fallido",
"ocr_trigger_no_annotations": "Dibuje regiones en el documento primero.",

View File

@@ -12,6 +12,8 @@ vi.mock('$lib/paraglide/messages.js', () => ({
`${count} Blöcke erstellt, ${skipped} Seite(n) übersprungen`,
ocr_status_analyzing_page: ({ current, total }: { current: string; total: string }) =>
`Seite ${current} von ${total} wird analysiert…`,
ocr_status_preprocessing_page: ({ current, total }: { current: string; total: string }) =>
`Seite ${current} von ${total} wird aufbereitet…`,
ocr_status_error: () => 'OCR fehlgeschlagen'
}
}));
@@ -68,6 +70,19 @@ describe('translateOcrProgress', () => {
expect(result.totalPages).toBe(5);
});
it('translates PREPROCESSING_PAGE with current and total', () => {
	// Code format is PREPROCESSING_PAGE:<current>:<total>.
	const { message, currentPage, totalPages } = translateOcrProgress('PREPROCESSING_PAGE:3:10');
	expect(message).toBe('Seite 3 von 10 wird aufbereitet…');
	expect(currentPage).toBe(3);
	expect(totalPages).toBe(10);
});
it('PREPROCESSING_PAGE with no colon parts gracefully falls back to zero', () => {
	// Missing ':current:total' segments must parse as 0 — never NaN or a throw.
	const fallback = translateOcrProgress('PREPROCESSING_PAGE');
	expect(fallback.currentPage).toBe(0);
	expect(fallback.totalPages).toBe(0);
});
it('translates ERROR', () => {
expect(translateOcrProgress('ERROR').message).toBe('OCR fehlgeschlagen');
});

View File

@@ -48,6 +48,18 @@ export function translateOcrProgress(code: string): OcrProgressResult {
totalPages: total
};
}
case 'PREPROCESSING_PAGE': {
const current = parseInt(parts[1] ?? '0', 10);
const total = parseInt(parts[2] ?? '0', 10);
return {
message: m.ocr_status_preprocessing_page({
current: String(current),
total: String(total)
}),
currentPage: current,
totalPages: total
};
}
case 'ERROR':
return { message: m.ocr_status_error() };
default:

View File

@@ -3,10 +3,12 @@ FROM python:3.11.9-slim
WORKDIR /app
# curl for healthcheck; libgomp1 for PyTorch CPU threading; libvips for kraken PDF support
# libglib2.0-0 is required by opencv-python-headless on Debian slim
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
libgomp1 \
libvips42 \
libglib2.0-0 \
&& rm -rf /var/lib/apt/lists/*
# PyTorch CPU-only — separate layer; the whl/cpu index strips all CUDA variants (~2 GB saved)

View File

@@ -25,6 +25,7 @@ from confidence import apply_confidence_markers, get_threshold
from engines import kraken as kraken_engine
from engines import surya as surya_engine
from models import OcrBlock, OcrRequest
from preprocessing import preprocess_page
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
@@ -86,6 +87,10 @@ async def run_ocr(request: OcrRequest):
images = await _download_and_convert_pdf(request.pdfUrl)
for i, img in enumerate(images):
images[i] = await asyncio.to_thread(preprocess_page, img)
del img
script_type = request.scriptType.upper()
if script_type == "HANDWRITING_KURRENT":
@@ -157,6 +162,8 @@ async def run_ocr_stream(request: OcrRequest):
continue
try:
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
blocks = []
for region in page_regions:
text = await asyncio.to_thread(
@@ -214,6 +221,8 @@ async def run_ocr_stream(request: OcrRequest):
for page_idx, image in enumerate(images, start=1):
try:
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language
)

View File

@@ -0,0 +1,50 @@
"""Image preprocessing pipeline for aged document OCR quality improvement."""
import logging
import os
import cv2
import numpy as np
from PIL import Image
logger = logging.getLogger(__name__)
CLAHE_CLIP_LIMIT = float(os.environ.get("OCR_CLAHE_CLIP_LIMIT", "2.0"))
CLAHE_TILE_SIZE = int(os.environ.get("OCR_CLAHE_TILE_SIZE", "8"))
def preprocess_page(image: Image.Image) -> Image.Image:
    """Apply CLAHE + grayscale + Gaussian blur to improve OCR quality on aged documents.

    The input is normalized to RGB first: ``cv2.COLOR_RGB2LAB`` requires a
    3-channel array, so grayscale ("L"), palette ("P"), and RGBA pages would
    otherwise raise ``cv2.error`` and silently skip enhancement entirely.

    Falls back silently to the original image if cv2, numpy, or memory errors occur.
    Unexpected exceptions (programming errors) are allowed to propagate.

    :param image: source page image in any PIL mode.
    :return: single-channel ("L") enhanced image, or the unmodified input on failure.
    """
    try:
        # Normalize mode so the LAB conversion always sees 3 channels; PDF
        # renderers commonly emit "L", "P", or "RGBA" depending on the source.
        rgb = image if image.mode == "RGB" else image.convert("RGB")
        img_array = np.array(rgb)
        lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
        # Intermediates are deleted eagerly: scanned pages can be large.
        del img_array
        l_channel = lab[:, :, 0].copy()
        del lab
        clahe = cv2.createCLAHE(
            clipLimit=CLAHE_CLIP_LIMIT,
            tileGridSize=(CLAHE_TILE_SIZE, CLAHE_TILE_SIZE),
        )
        l_clahe = clahe.apply(l_channel)
        del l_channel
        # Light 3x3 blur suppresses noise that CLAHE tends to amplify.
        blurred = cv2.GaussianBlur(l_clahe, (3, 3), 0)
        del l_clahe
        result = Image.fromarray(blurred, mode="L")
        del blurred
        return result
    except (cv2.error, ValueError, MemoryError) as e:
        # Known, recoverable failures degrade gracefully to the raw page so a
        # preprocessing hiccup never fails the whole OCR job.
        logger.warning(
            "preprocess_page failed (falling back to original): %s: %s",
            type(e).__name__,
            e,
        )
        return image

View File

@@ -8,3 +8,4 @@ pillow>=10.2.0,<11.0.0
pypdfium2==4.30.0
pyvips>=2.2.0
httpx==0.28.1
opencv-python-headless==4.11.0.86

View File

@@ -0,0 +1,82 @@
"""Tests for the image preprocessing pipeline."""
import numpy as np
import pytest
from PIL import Image
from unittest.mock import patch
def _make_yellowed_image(width=100, height=100):
    """Dark, faded yellowed page: L values in a narrow low range with spatial noise.

    Very dark (R≈30, G≈20, B≈10) → L_cv ≈ 80-100 in OpenCV uint8 LAB space.
    The per-pixel noise gives each CLAHE tile a non-trivial histogram to equalize,
    which stretches the narrow dark range toward [0-255] and reliably increases mean L.
    """
    rng = np.random.default_rng(42)
    # Same RNG draw order as before: R, then G, then B, with per-channel spreads.
    channels = []
    for base, spread in ((30, 8), (20, 5), (10, 3)):
        noise = rng.integers(-spread, spread + 1, (height, width))
        channels.append(np.clip(base + noise, 0, 255))
    stacked = np.stack(channels, axis=-1).astype(np.uint8)
    return Image.fromarray(stacked, mode="RGB")
class TestPreprocessPage:
    """Unit tests for preprocess_page: shape preservation, enhancement effect,
    degenerate input, and the fallback/propagation contract of its except clause."""

    def test_output_has_same_dimensions_as_input(self):
        # Preprocessing must never resize the page.
        from preprocessing import preprocess_page
        img = Image.new("RGB", (150, 200))
        result = preprocess_page(img)
        assert result.size == img.size

    def test_l_channel_mean_increases_on_yellowed_image(self):
        """CLAHE equalizes the dark narrow-range histogram toward [0-255], raising mean L."""
        from preprocessing import preprocess_page
        import cv2
        img = _make_yellowed_image()
        arr_before = np.array(img)
        lab_before = cv2.cvtColor(arr_before, cv2.COLOR_RGB2LAB)
        l_mean_before = float(lab_before[:, :, 0].mean())
        result = preprocess_page(img)
        # Output is grayscale (mode "L"); its values ARE the CLAHE-enhanced L channel
        l_mean_after = float(np.array(result).mean())
        assert l_mean_after > l_mean_before

    def test_does_not_crash_on_sub_tile_size_image(self):
        """A 1×1 image is smaller than the CLAHE tile (8×8) in both axes.
        preprocess_page must not raise — it either succeeds or falls back silently."""
        from preprocessing import preprocess_page
        img = Image.new("RGB", (1, 1), color=(128, 100, 80))
        result = preprocess_page(img)
        assert isinstance(result, Image.Image)

    def test_falls_back_to_pixel_identical_original_on_cv2_error(self):
        """When cv2 raises a known error, preprocess_page returns the unmodified original image."""
        from preprocessing import preprocess_page
        img = Image.new("RGB", (80, 60), color=(123, 45, 67))
        original_pixels = list(img.getdata())
        # ValueError is one of the types preprocess_page deliberately catches.
        with patch("preprocessing.cv2.cvtColor", side_effect=ValueError("bad input")):
            result = preprocess_page(img)
        result_pixels = list(result.getdata())
        assert result_pixels == original_pixels

    def test_unexpected_exception_propagates(self):
        """A RuntimeError (programming error) must propagate — not be swallowed by the cv2 fallback."""
        from preprocessing import preprocess_page
        img = Image.new("RGB", (80, 60))
        with patch("preprocessing.cv2.cvtColor", side_effect=RuntimeError("unexpected")):
            with pytest.raises(RuntimeError, match="unexpected"):
                preprocess_page(img)

View File

@@ -301,3 +301,83 @@ async def test_ocr_stream_applies_confidence_markers(mock_images):
page = [l for l in lines if l["type"] == "page"][0]
assert page["blocks"][0]["text"] == "Lieber [unleserlich]"
assert "words" not in page["blocks"][0]
# ─── Preprocessing integration ───────────────────────────────────────────────
@pytest.mark.asyncio
async def test_stream_emits_preprocessing_event_per_page_before_page_event(mock_images):
    """generate() must emit a preprocessing event for each page before its page event."""
    # Stub the PDF download, mark models ready, and make preprocessing a no-op
    # pass-through — only the event protocol (not image content) is under test.
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya, \
         patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
        mock_surya.extract_page_blocks.return_value = []
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })
        # NDJSON response: one JSON object per non-empty line.
        lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
        preprocessing_events = [l for l in lines if l["type"] == "preprocessing"]
        page_events = [l for l in lines if l["type"] == "page"]
        # assumes the mock_images fixture yields exactly 3 pages — TODO confirm
        assert len(preprocessing_events) == 3
        assert preprocessing_events[0]["pageNumber"] == 1
        assert preprocessing_events[1]["pageNumber"] == 2
        assert preprocessing_events[2]["pageNumber"] == 3
        # Each preprocessing event must come immediately before the corresponding page event
        for pre, page in zip(preprocessing_events, page_events):
            assert pre["pageNumber"] == page["pageNumber"]
        assert mock_preprocess.call_count == 3
@pytest.mark.asyncio
async def test_guided_stream_preprocesses_once_per_page_not_per_region(mock_images):
    """generate_guided() must call preprocess_page once per page, not once per region."""
    # Two regions on page 1 plus one on page 2: per-region preprocessing would
    # yield 3 calls; the correct per-page behavior yields 2.
    regions = [
        {"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a1"},
        {"pageNumber": 1, "x": 0.5, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a2"},
        {"pageNumber": 2, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, "annotationId": "a3"},
    ]
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images[:2]), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya, \
         patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
        mock_surya.extract_region_text.return_value = "text"
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            # Presence of "regions" routes the request to the guided stream path.
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
                "regions": regions,
            })
        assert response.status_code == 200
        # 2 pages, each preprocessed once — not 3 (once per region)
        assert mock_preprocess.call_count == 2
@pytest.mark.asyncio
async def test_ocr_endpoint_preprocesses_silently_without_emitting_events(mock_images):
    """/ocr endpoint preprocesses each page without emitting preprocessing events."""
    # Non-streaming endpoint: preprocessing still runs per page, but there is no
    # NDJSON stream to carry progress events.
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya, \
         patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
        mock_surya.extract_blocks.return_value = []
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })
        assert response.status_code == 200
        # assumes mock_images yields 3 pages — one preprocess call each
        assert mock_preprocess.call_count == 3