feat(ocr): integrate preprocessing into stream and batch endpoints
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from confidence import apply_confidence_markers, get_threshold
|
|||||||
from engines import kraken as kraken_engine
|
from engines import kraken as kraken_engine
|
||||||
from engines import surya as surya_engine
|
from engines import surya as surya_engine
|
||||||
from models import OcrBlock, OcrRequest
|
from models import OcrBlock, OcrRequest
|
||||||
|
from preprocessing import preprocess_page
|
||||||
|
|
||||||
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
|
TRAINING_TOKEN = os.environ.get("TRAINING_TOKEN", "")
|
||||||
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
KRAKEN_MODEL_PATH = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
||||||
@@ -86,6 +87,10 @@ async def run_ocr(request: OcrRequest):
|
|||||||
|
|
||||||
images = await _download_and_convert_pdf(request.pdfUrl)
|
images = await _download_and_convert_pdf(request.pdfUrl)
|
||||||
|
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
images[i] = await asyncio.to_thread(preprocess_page, img)
|
||||||
|
del img
|
||||||
|
|
||||||
script_type = request.scriptType.upper()
|
script_type = request.scriptType.upper()
|
||||||
|
|
||||||
if script_type == "HANDWRITING_KURRENT":
|
if script_type == "HANDWRITING_KURRENT":
|
||||||
@@ -157,6 +162,8 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||||
|
image = await asyncio.to_thread(preprocess_page, image)
|
||||||
blocks = []
|
blocks = []
|
||||||
for region in page_regions:
|
for region in page_regions:
|
||||||
text = await asyncio.to_thread(
|
text = await asyncio.to_thread(
|
||||||
@@ -214,6 +221,8 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
|
|
||||||
for page_idx, image in enumerate(images, start=1):
|
for page_idx, image in enumerate(images, start=1):
|
||||||
try:
|
try:
|
||||||
|
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||||
|
image = await asyncio.to_thread(preprocess_page, image)
|
||||||
blocks = await asyncio.to_thread(
|
blocks = await asyncio.to_thread(
|
||||||
engine.extract_page_blocks, image, page_idx, request.language
|
engine.extract_page_blocks, image, page_idx, request.language
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -301,3 +301,83 @@ async def test_ocr_stream_applies_confidence_markers(mock_images):
|
|||||||
page = [l for l in lines if l["type"] == "page"][0]
|
page = [l for l in lines if l["type"] == "page"][0]
|
||||||
assert page["blocks"][0]["text"] == "Lieber [unleserlich]"
|
assert page["blocks"][0]["text"] == "Lieber [unleserlich]"
|
||||||
assert "words" not in page["blocks"][0]
|
assert "words" not in page["blocks"][0]
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Preprocessing integration ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_emits_preprocessing_event_per_page_before_page_event(mock_images):
|
||||||
|
"""generate() must emit a preprocessing event for each page before its page event."""
|
||||||
|
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
|
||||||
|
patch("main._models_ready", True), \
|
||||||
|
patch("main.surya_engine") as mock_surya, \
|
||||||
|
patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
|
||||||
|
mock_surya.extract_page_blocks.return_value = []
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post("/ocr/stream", json={
|
||||||
|
"pdfUrl": "http://minio/test.pdf",
|
||||||
|
"scriptType": "TYPEWRITER",
|
||||||
|
})
|
||||||
|
|
||||||
|
lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
|
||||||
|
preprocessing_events = [l for l in lines if l["type"] == "preprocessing"]
|
||||||
|
page_events = [l for l in lines if l["type"] == "page"]
|
||||||
|
|
||||||
|
assert len(preprocessing_events) == 3
|
||||||
|
assert preprocessing_events[0]["pageNumber"] == 1
|
||||||
|
assert preprocessing_events[1]["pageNumber"] == 2
|
||||||
|
assert preprocessing_events[2]["pageNumber"] == 3
|
||||||
|
|
||||||
|
# Each preprocessing event must come immediately before the corresponding page event
|
||||||
|
for pre, page in zip(preprocessing_events, page_events):
|
||||||
|
assert pre["pageNumber"] == page["pageNumber"]
|
||||||
|
|
||||||
|
assert mock_preprocess.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guided_stream_preprocesses_once_per_page_not_per_region(mock_images):
|
||||||
|
"""generate_guided() must call preprocess_page once per page, not once per region."""
|
||||||
|
regions = [
|
||||||
|
{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a1"},
|
||||||
|
{"pageNumber": 1, "x": 0.5, "y": 0.0, "width": 0.5, "height": 0.5, "annotationId": "a2"},
|
||||||
|
{"pageNumber": 2, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, "annotationId": "a3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images[:2]), \
|
||||||
|
patch("main._models_ready", True), \
|
||||||
|
patch("main.surya_engine") as mock_surya, \
|
||||||
|
patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
|
||||||
|
mock_surya.extract_region_text.return_value = "text"
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post("/ocr/stream", json={
|
||||||
|
"pdfUrl": "http://minio/test.pdf",
|
||||||
|
"scriptType": "TYPEWRITER",
|
||||||
|
"regions": regions,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# 2 pages, each preprocessed once — not 3 (once per region)
|
||||||
|
assert mock_preprocess.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ocr_endpoint_preprocesses_silently_without_emitting_events(mock_images):
|
||||||
|
"""/ocr endpoint preprocesses each page without emitting preprocessing events."""
|
||||||
|
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
|
||||||
|
patch("main._models_ready", True), \
|
||||||
|
patch("main.surya_engine") as mock_surya, \
|
||||||
|
patch("main.preprocess_page", side_effect=lambda img: img) as mock_preprocess:
|
||||||
|
mock_surya.extract_blocks.return_value = []
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post("/ocr", json={
|
||||||
|
"pdfUrl": "http://minio/test.pdf",
|
||||||
|
"scriptType": "TYPEWRITER",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert mock_preprocess.call_count == 3
|
||||||
|
|||||||
Reference in New Issue
Block a user