feat(ocr): add image preprocessing pipeline to improve transcription quality on aged documents #255
@@ -82,6 +82,10 @@ public class OcrAsyncRunner {
|
||||
ocrJobDocumentRepository.save(jobDoc);
|
||||
}
|
||||
}
|
||||
case OcrStreamEvent.Preprocessing preprocessing -> {
|
||||
updateProgress(job, "PREPROCESSING_PAGE:" + preprocessing.pageNumber()
|
||||
+ ":" + totalPages.get());
|
||||
}
|
||||
case OcrStreamEvent.Page page -> {
|
||||
for (OcrBlockResult block : page.blocks()) {
|
||||
createSingleBlock(documentId, block, userId,
|
||||
|
||||
@@ -6,6 +6,8 @@ public sealed interface OcrStreamEvent {
|
||||
|
||||
record Start(int totalPages) implements OcrStreamEvent {}
|
||||
|
||||
record Preprocessing(int pageNumber) implements OcrStreamEvent {}
|
||||
|
||||
record Page(int pageNumber, List<OcrBlockResult> blocks) implements OcrStreamEvent {}
|
||||
|
||||
record Error(int pageNumber, String message) implements OcrStreamEvent {}
|
||||
|
||||
@@ -232,6 +232,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
switch (type) {
|
||||
case "start" -> handler.accept(
|
||||
new OcrStreamEvent.Start(node.path("totalPages").asInt()));
|
||||
case "preprocessing" -> handler.accept(
|
||||
new OcrStreamEvent.Preprocessing(node.path("pageNumber").asInt()));
|
||||
case "page" -> {
|
||||
int pageNumber = node.path("pageNumber").asInt();
|
||||
List<OcrBlockResult> blocks = NDJSON_MAPPER.convertValue(
|
||||
|
||||
@@ -230,6 +230,42 @@ class OcrAsyncRunnerTest {
|
||||
assertThat(job.getProgressMessage()).isEqualTo("DONE:0:1");
|
||||
}
|
||||
|
||||
@Test
void runSingleDocument_updatesProgressOnPreprocessingEvent() {
    // Arrange: one pending job that contains a single pending document.
    UUID jobId = UUID.randomUUID();
    UUID docId = UUID.randomUUID();
    UUID userId = UUID.randomUUID();
    OcrJob job = OcrJob.builder().id(jobId).totalDocuments(1).status(OcrJobStatus.PENDING).build();
    OcrJobDocument jobDoc = OcrJobDocument.builder().id(UUID.randomUUID())
            .jobId(jobId).documentId(docId).status(OcrDocumentStatus.PENDING).build();
    Document doc = Document.builder().id(docId).filePath("test.pdf")
            .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();

    // Repositories echo saved entities back; the presigned URL is stubbed.
    when(ocrJobRepository.findById(jobId)).thenReturn(Optional.of(job));
    when(ocrJobRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    when(ocrJobDocumentRepository.findByJobIdAndDocumentId(jobId, docId))
            .thenReturn(Optional.of(jobDoc));
    when(ocrJobDocumentRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    when(documentService.getDocumentById(docId)).thenReturn(doc);
    when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");

    // Drive the stream handler synchronously: Start(totalPages=5) is delivered
    // before Preprocessing(1), so the runner's progress message should carry
    // the total page count from the Start event.
    List<String> progressMessages = new ArrayList<>();
    doAnswer(inv -> {
        Consumer<OcrStreamEvent> handler = inv.getArgument(3);
        handler.accept(new OcrStreamEvent.Start(5));
        handler.accept(new OcrStreamEvent.Preprocessing(1));
        // Capture the message immediately after the Preprocessing event,
        // before later events (Page/Done) overwrite it.
        progressMessages.add(job.getProgressMessage());
        handler.accept(new OcrStreamEvent.Page(1, List.of()));
        handler.accept(new OcrStreamEvent.Done(0, 0));
        return null;
    }).when(ocrClient).streamBlocks(any(), any(), any(), any());

    ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

    // Must have totalPages=5 from the preceding Start event, not 0
    assertThat(progressMessages.get(0)).isEqualTo("PREPROCESSING_PAGE:1:5");
}
|
||||
|
||||
@Test
|
||||
void runSingleDocument_logsStreamErrorAtWarnWithoutSettingJobFailed() {
|
||||
UUID jobId = UUID.randomUUID();
|
||||
|
||||
@@ -37,11 +37,19 @@ class OcrStreamEventTest {
|
||||
assertThat(done.skippedPages()).isEqualTo(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void preprocessingRecordHoldsPageNumber() {
|
||||
var preprocessing = new OcrStreamEvent.Preprocessing(4);
|
||||
assertThat(preprocessing.pageNumber()).isEqualTo(4);
|
||||
assertThat(preprocessing).isInstanceOf(OcrStreamEvent.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void patternMatchingWorksOnSealedInterface() {
|
||||
OcrStreamEvent event = new OcrStreamEvent.Start(3);
|
||||
String result = switch (event) {
|
||||
case OcrStreamEvent.Start s -> "start:" + s.totalPages();
|
||||
case OcrStreamEvent.Preprocessing pre -> "preprocessing:" + pre.pageNumber();
|
||||
case OcrStreamEvent.Page p -> "page:" + p.pageNumber();
|
||||
case OcrStreamEvent.Error e -> "error:" + e.pageNumber();
|
||||
case OcrStreamEvent.Done d -> "done:" + d.totalBlocks();
|
||||
|
||||
@@ -83,6 +83,25 @@ class RestClientOcrClientStreamTest {
|
||||
assertThat(events).hasSize(3);
|
||||
}
|
||||
|
||||
@Test
void parseNdjsonStream_dispatchesPreprocessingEvent() {
    // One JSON object per line, mirroring the ocr-service NDJSON stream contract.
    String ndjson = """
{"type":"start","totalPages":3}
{"type":"preprocessing","pageNumber":1}
{"type":"page","pageNumber":1,"blocks":[]}
{"type":"done","totalBlocks":0,"skippedPages":0}
""";
    InputStream stream = new ByteArrayInputStream(ndjson.getBytes(StandardCharsets.UTF_8));

    List<OcrStreamEvent> events = new ArrayList<>();
    RestClientOcrClient.parseNdjsonStream(stream, events::add);

    // Expected order: start, preprocessing, page, done — preprocessing is second.
    assertThat(events).hasSize(4);
    assertThat(events.get(1)).isInstanceOf(OcrStreamEvent.Preprocessing.class);
    var preprocessing = (OcrStreamEvent.Preprocessing) events.get(1);
    assertThat(preprocessing.pageNumber()).isEqualTo(1);
}
|
||||
|
||||
@Test
|
||||
void parseNdjsonStream_ignoresUnknownEventTypes() {
|
||||
String ndjson = """
|
||||
|
||||
@@ -95,6 +95,8 @@ services:
|
||||
OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5"
|
||||
RECOGNITION_BATCH_SIZE: "16"
|
||||
DETECTOR_BATCH_SIZE: "8"
|
||||
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
|
||||
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
|
||||
networks:
|
||||
- archive-net
|
||||
healthcheck:
|
||||
|
||||
@@ -526,6 +526,7 @@
|
||||
"ocr_status_creating_blocks": "{count} Textblöcke erkannt — erstelle Transkription…",
|
||||
"ocr_status_done_blocks": "{count} Blöcke erstellt",
|
||||
"ocr_status_analyzing_page": "Seite {current} von {total} wird analysiert…",
|
||||
"ocr_status_preprocessing_page": "Seite {current} von {total} wird aufbereitet…",
|
||||
"ocr_status_done_skipped": "{count} Blöcke erstellt, {skipped} Seite(n) übersprungen",
|
||||
"ocr_status_error": "OCR fehlgeschlagen",
|
||||
"ocr_trigger_no_annotations": "Zeichnen Sie zuerst Bereiche auf dem Dokument ein.",
|
||||
|
||||
@@ -526,6 +526,7 @@
|
||||
"ocr_status_creating_blocks": "{count} text blocks detected — creating transcription…",
|
||||
"ocr_status_done_blocks": "{count} blocks created",
|
||||
"ocr_status_analyzing_page": "Analyzing page {current} of {total}…",
|
||||
"ocr_status_preprocessing_page": "Preparing page {current} of {total}…",
|
||||
"ocr_status_done_skipped": "{count} blocks created, {skipped} page(s) skipped",
|
||||
"ocr_status_error": "OCR failed",
|
||||
"ocr_trigger_no_annotations": "Draw regions on the document first.",
|
||||
|
||||
@@ -526,6 +526,7 @@
|
||||
"ocr_status_creating_blocks": "{count} bloques de texto detectados — creando transcripción…",
|
||||
"ocr_status_done_blocks": "{count} bloques creados",
|
||||
"ocr_status_analyzing_page": "Analizando página {current} de {total}…",
|
||||
"ocr_status_preprocessing_page": "Preparando página {current} de {total}…",
|
||||
"ocr_status_done_skipped": "{count} bloques creados, {skipped} página(s) omitida(s)",
|
||||
"ocr_status_error": "OCR fallido",
|
||||
"ocr_trigger_no_annotations": "Dibuje regiones en el documento primero.",
|
||||
|
||||
@@ -12,6 +12,8 @@ vi.mock('$lib/paraglide/messages.js', () => ({
|
||||
`${count} Blöcke erstellt, ${skipped} Seite(n) übersprungen`,
|
||||
ocr_status_analyzing_page: ({ current, total }: { current: string; total: string }) =>
|
||||
`Seite ${current} von ${total} wird analysiert…`,
|
||||
ocr_status_preprocessing_page: ({ current, total }: { current: string; total: string }) =>
|
||||
`Seite ${current} von ${total} wird aufbereitet…`,
|
||||
ocr_status_error: () => 'OCR fehlgeschlagen'
|
||||
}
|
||||
}));
|
||||
@@ -68,6 +70,19 @@ describe('translateOcrProgress', () => {
|
||||
expect(result.totalPages).toBe(5);
|
||||
});
|
||||
|
||||
// PREPROCESSING_PAGE:<current>:<total> → localized "preparing page" message
// plus parsed page counters for progress-bar rendering.
it('translates PREPROCESSING_PAGE with current and total', () => {
	const result = translateOcrProgress('PREPROCESSING_PAGE:3:10');
	expect(result.message).toBe('Seite 3 von 10 wird aufbereitet…');
	expect(result.currentPage).toBe(3);
	expect(result.totalPages).toBe(10);
});

// A bare code without the ":current:total" suffix must not throw —
// both counters default to 0.
it('PREPROCESSING_PAGE with no colon parts gracefully falls back to zero', () => {
	const result = translateOcrProgress('PREPROCESSING_PAGE');
	expect(result.currentPage).toBe(0);
	expect(result.totalPages).toBe(0);
});
|
||||
|
||||
it('translates ERROR', () => {
|
||||
expect(translateOcrProgress('ERROR').message).toBe('OCR fehlgeschlagen');
|
||||
});
|
||||
|
||||
@@ -48,6 +48,18 @@ export function translateOcrProgress(code: string): OcrProgressResult {
|
||||
totalPages: total
|
||||
};
|
||||
}
|
||||
case 'PREPROCESSING_PAGE': {
|
||||
const current = parseInt(parts[1] ?? '0', 10);
|
||||
const total = parseInt(parts[2] ?? '0', 10);
|
||||
return {
|
||||
message: m.ocr_status_preprocessing_page({
|
||||
current: String(current),
|
||||
total: String(total)
|
||||
}),
|
||||
currentPage: current,
|
||||
totalPages: total
|
||||
};
|
||||
}
|
||||
case 'ERROR':
|
||||
return { message: m.ocr_status_error() };
|
||||
default:
|
||||
|
||||
@@ -3,10 +3,12 @@ FROM python:3.11.9-slim
|
||||
WORKDIR /app
|
||||
|
||||
# curl for healthcheck; libgomp1 for PyTorch CPU threading; libvips for kraken PDF support
|
||||
# libglib2.0-0 is required by opencv-python-headless on Debian slim
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
libgomp1 \
|
||||
libvips42 \
|
||||
libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# PyTorch CPU-only — separate layer; the whl/cpu index strips all CUDA variants (~2 GB saved)
|
||||
|
||||
@@ -25,6 +25,7 @@ from confidence import apply_confidence_markers, get_threshold
|
||||
from engines import kraken as kraken_engine
|
||||
from engines import surya as surya_engine
|
||||
from models import OcrBlock, OcrRequest
|
||||
from preprocessing import preprocess_page
|
||||
|
||||
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
|
||||
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
||||
@@ -86,6 +87,10 @@ async def run_ocr(request: OcrRequest):
|
||||
|
||||
images = await _download_and_convert_pdf(request.pdfUrl)
|
||||
|
||||
for i, img in enumerate(images):
|
||||
images[i] = await asyncio.to_thread(preprocess_page, img)
|
||||
del img
|
||||
|
||||
script_type = request.scriptType.upper()
|
||||
|
||||
if script_type == "HANDWRITING_KURRENT":
|
||||
@@ -157,6 +162,8 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
continue
|
||||
|
||||
try:
|
||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||
image = await asyncio.to_thread(preprocess_page, image)
|
||||
blocks = []
|
||||
for region in page_regions:
|
||||
text = await asyncio.to_thread(
|
||||
@@ -214,6 +221,8 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
|
||||
for page_idx, image in enumerate(images, start=1):
|
||||
try:
|
||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||
image = await asyncio.to_thread(preprocess_page, image)
|
||||
blocks = await asyncio.to_thread(
|
||||
engine.extract_page_blocks, image, page_idx, request.language
|
||||
)
|
||||
|
||||
50
ocr-service/preprocessing.py
Normal file
50
ocr-service/preprocessing.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Image preprocessing pipeline for aged document OCR quality improvement."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CLAHE_CLIP_LIMIT = float(os.environ.get("OCR_CLAHE_CLIP_LIMIT", "2.0"))
|
||||
CLAHE_TILE_SIZE = int(os.environ.get("OCR_CLAHE_TILE_SIZE", "8"))
|
||||
|
||||
|
||||
def preprocess_page(image: Image.Image) -> Image.Image:
    """Enhance a scanned page for OCR: CLAHE contrast boost + light Gaussian denoise.

    The page is converted to LAB colour space and CLAHE (contrast-limited
    adaptive histogram equalization) is applied to the lightness channel,
    which lifts faded ink on aged/yellowed paper. The result is returned as
    a grayscale (mode "L") image.

    Args:
        image: Source page. Any PIL mode is accepted; non-RGB inputs
            (e.g. "L", "RGBA", "P") are normalized to RGB first.

    Returns:
        A new grayscale PIL image with the same dimensions, or the original
        image unchanged if preprocessing fails.

    Falls back silently to the original image if cv2, numpy, or memory errors occur.
    Unexpected exceptions (programming errors) are allowed to propagate.
    """
    try:
        # Normalize the mode so cvtColor's RGB2LAB precondition (3-channel
        # uint8) holds. Previously a grayscale/RGBA/palette scan always raised
        # cv2.error and went through the fallback unprocessed.
        rgb = image if image.mode == "RGB" else image.convert("RGB")

        # Intermediates are deleted eagerly: a 300-DPI page is tens of MB per
        # copy, and several pages may be in flight concurrently.
        img_array = np.array(rgb)
        lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
        del img_array

        # .copy() detaches the L channel so the full LAB array can be freed.
        l_channel = lab[:, :, 0].copy()
        del lab

        clahe = cv2.createCLAHE(
            clipLimit=CLAHE_CLIP_LIMIT,
            tileGridSize=(CLAHE_TILE_SIZE, CLAHE_TILE_SIZE),
        )
        l_clahe = clahe.apply(l_channel)
        del l_channel

        # A small 3x3 blur smooths the high-frequency noise CLAHE can amplify.
        blurred = cv2.GaussianBlur(l_clahe, (3, 3), 0)
        del l_clahe

        result = Image.fromarray(blurred, mode="L")
        del blurred

        return result
    except (cv2.error, ValueError, MemoryError) as e:
        logger.warning(
            "preprocess_page failed (falling back to original): %s: %s",
            type(e).__name__,
            e,
        )
        return image
|
||||
@@ -8,3 +8,4 @@ pillow>=10.2.0,<11.0.0
|
||||
pypdfium2==4.30.0
|
||||
pyvips>=2.2.0
|
||||
httpx==0.28.1
|
||||
opencv-python-headless==4.11.0.86
|
||||
|
||||
82
ocr-service/test_preprocessing.py
Normal file
82
ocr-service/test_preprocessing.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for the image preprocessing pipeline."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _make_yellowed_image(width=100, height=100):
    """Synthesize a dark, faded, yellowed page with per-pixel noise.

    The channel bases (R≈30, G≈20, B≈10) put the OpenCV LAB L channel in a
    narrow low range (~80-100), and the jitter gives every CLAHE tile a
    non-trivial histogram, so equalization reliably raises the mean L value.
    """
    rng = np.random.default_rng(42)
    # (base value, jitter half-width) per RGB channel, in R, G, B order —
    # the rng.integers call order matches the channel order so the noise
    # stream is deterministic for seed 42.
    channel_specs = ((30, 8), (20, 5), (10, 3))
    channels = [
        np.clip(base + rng.integers(-spread, spread + 1, (height, width)), 0, 255)
        for base, spread in channel_specs
    ]
    page = np.stack(channels, axis=-1).astype(np.uint8)
    return Image.fromarray(page, mode="RGB")
|
||||
|
||||
|
||||
class TestPreprocessPage:
    # NOTE: `preprocessing` is imported inside each test so that a broken
    # cv2/numpy installation fails an individual test, not module collection.

    def test_output_has_same_dimensions_as_input(self):
        from preprocessing import preprocess_page

        # PIL .size is (width, height); preprocessing must never resize a page.
        img = Image.new("RGB", (150, 200))
        result = preprocess_page(img)

        assert result.size == img.size

    def test_l_channel_mean_increases_on_yellowed_image(self):
        """CLAHE equalizes the dark narrow-range histogram toward [0-255], raising mean L."""
        from preprocessing import preprocess_page
        import cv2

        img = _make_yellowed_image()

        # Baseline: mean of the L channel in OpenCV's uint8 LAB space.
        arr_before = np.array(img)
        lab_before = cv2.cvtColor(arr_before, cv2.COLOR_RGB2LAB)
        l_mean_before = float(lab_before[:, :, 0].mean())

        result = preprocess_page(img)

        # Output is grayscale (mode "L"); its values ARE the CLAHE-enhanced L channel
        l_mean_after = float(np.array(result).mean())

        assert l_mean_after > l_mean_before

    def test_does_not_crash_on_sub_tile_size_image(self):
        """A 1×1 image is smaller than the CLAHE tile (8×8) in both axes.
        preprocess_page must not raise — it either succeeds or falls back silently."""
        from preprocessing import preprocess_page

        img = Image.new("RGB", (1, 1), color=(128, 100, 80))
        result = preprocess_page(img)

        assert isinstance(result, Image.Image)

    def test_falls_back_to_pixel_identical_original_on_cv2_error(self):
        """When cv2 raises a known error, preprocess_page returns the unmodified original image."""
        from preprocessing import preprocess_page

        img = Image.new("RGB", (80, 60), color=(123, 45, 67))
        original_pixels = list(img.getdata())

        # ValueError is in the function's caught-exception tuple → fallback path.
        with patch("preprocessing.cv2.cvtColor", side_effect=ValueError("bad input")):
            result = preprocess_page(img)

        result_pixels = list(result.getdata())
        assert result_pixels == original_pixels

    def test_unexpected_exception_propagates(self):
        """A RuntimeError (programming error) must propagate — not be swallowed by the cv2 fallback."""
        from preprocessing import preprocess_page

        img = Image.new("RGB", (80, 60))

        # RuntimeError is NOT in the caught tuple, so it must bubble up.
        with patch("preprocessing.cv2.cvtColor", side_effect=RuntimeError("unexpected")):
            with pytest.raises(RuntimeError, match="unexpected"):
                preprocess_page(img)
|
||||
@@ -301,3 +301,83 @@ async def test_ocr_stream_applies_confidence_markers(mock_images):
|
||||
page = [l for l in lines if l["type"] == "page"][0]
|
||||
assert page["blocks"][0]["text"] == "Lieber [unleserlich]"
|
||||
assert "words" not in page["blocks"][0]
|
||||
|
||||
|
||||
# ─── Preprocessing integration ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_emits_preprocessing_event_per_page_before_page_event(mock_images):
    """generate() must emit a preprocessing event for each page before its page event."""
    # preprocess_page is stubbed as identity so only event emission is exercised,
    # not the CLAHE pipeline itself.
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as mock_surya, \
            patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
        mock_surya.extract_page_blocks.return_value = []

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

        # NDJSON: one JSON object per non-empty line.
        lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
        preprocessing_events = [l for l in lines if l["type"] == "preprocessing"]
        page_events = [l for l in lines if l["type"] == "page"]

        # Assumes the mock_images fixture yields exactly 3 pages — verify fixture.
        assert len(preprocessing_events) == 3
        assert preprocessing_events[0]["pageNumber"] == 1
        assert preprocessing_events[1]["pageNumber"] == 2
        assert preprocessing_events[2]["pageNumber"] == 3

        # Each preprocessing event must come immediately before the corresponding page event
        for pre, page in zip(preprocessing_events, page_events):
            assert pre["pageNumber"] == page["pageNumber"]

        assert mock_preprocess.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_guided_stream_preprocesses_once_per_page_not_per_region(mock_images):
    """generate_guided() must call preprocess_page once per page, not once per region."""
    # Two regions on page 1 and one on page 2 — a per-region implementation
    # would call preprocess_page 3 times instead of 2.
    regions = [
        {"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a1"},
        {"pageNumber": 1, "x": 0.5, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a2"},
        {"pageNumber": 2, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, "annotationId": "a3"},
    ]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images[:2]), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as mock_surya, \
            patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
        mock_surya.extract_region_text.return_value = "text"

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
                "regions": regions,
            })

        assert response.status_code == 200
        # 2 pages, each preprocessed once — not 3 (once per region)
        assert mock_preprocess.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ocr_endpoint_preprocesses_silently_without_emitting_events(mock_images):
    """/ocr endpoint preprocesses each page without emitting preprocessing events."""
    # The non-streaming endpoint has no NDJSON channel, so preprocessing must
    # still run (once per page) but produce no observable events.
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
            patch("main._models_ready", True), \
            patch("main.surya_engine") as mock_surya, \
            patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
        mock_surya.extract_blocks.return_value = []

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

        assert response.status_code == 200
        # Assumes the mock_images fixture yields exactly 3 pages — verify fixture.
        assert mock_preprocess.call_count == 3
|
||||
|
||||
Reference in New Issue
Block a user