feat(ocr): add NDJSON streaming endpoint POST /ocr/stream

Streams one JSON line per completed page instead of buffering the
entire result. Emits start/page/error/done events. On per-page
failure, logs the traceback but yields a generic error message and
continues with the next page. Adds X-Accel-Buffering: no and
Cache-Control: no-cache headers for reverse-proxy compatibility.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 09:57:57 +02:00
parent b7d5f71ef7
commit 97c6cf6a65
2 changed files with 236 additions and 0 deletions

View File

@@ -2,12 +2,14 @@
import asyncio
import io
import json
import logging
from contextlib import asynccontextmanager
import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from confidence import apply_confidence_markers, get_threshold
@@ -82,6 +84,77 @@ async def run_ocr(request: OcrRequest):
return [OcrBlock(**b) for b in blocks]
@app.post("/ocr/stream")
async def run_ocr_stream(request: OcrRequest):
    """Run OCR on a PDF with NDJSON streaming — one JSON line per completed page.

    The response body is a sequence of newline-terminated JSON objects:

    * ``{"type": "start", "totalPages": N}`` — emitted once, first.
    * ``{"type": "page", "pageNumber": i, "blocks": [...]}`` — one per page
      that was processed successfully (``i`` is the zero-based page index).
    * ``{"type": "error", "pageNumber": i, "message": ...}`` — one per page
      whose processing raised; the message is generic, the traceback is
      only written to the log.
    * ``{"type": "done", "totalBlocks": B, "skippedPages": S}`` — emitted last.

    Raises:
        HTTPException: 503 if the models are not loaded yet; 400 if Kurrent
            handwriting is requested but the Kraken model is unavailable.
    """
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    script_type = request.scriptType.upper()
    use_kraken = script_type == "HANDWRITING_KURRENT"

    # Fail fast: reject a request we cannot serve before paying for the
    # PDF download and page rendering.
    if use_kraken and not kraken_engine.is_available():
        raise HTTPException(
            status_code=400,
            detail="Kraken model not available — cannot process Kurrent script",
        )

    threshold = get_threshold(script_type)
    images = await _download_and_convert_pdf(request.pdfUrl)

    async def generate():
        total_pages = len(images)
        yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"

        engine = kraken_engine if use_kraken else surya_engine
        total_blocks = 0
        skipped_pages = 0

        for page_idx, image in enumerate(images):
            try:
                # The engine call is blocking; run it in a worker thread so
                # the event loop stays responsive.
                blocks = await asyncio.to_thread(
                    engine.extract_page_blocks, image, page_idx, request.language
                )
                for block in blocks:
                    if block.get("words"):
                        block["text"] = apply_confidence_markers(block["words"], threshold)
                    block.pop("words", None)
                # Serialise inside the try so an unserialisable block is
                # reported as a page error instead of aborting the stream.
                page_line = json.dumps({
                    "type": "page",
                    "pageNumber": page_idx,
                    "blocks": blocks,
                }) + "\n"
            except Exception:
                logger.exception("OCR failed on page %d", page_idx)
                skipped_pages += 1
                yield json.dumps({
                    "type": "error",
                    "pageNumber": page_idx,
                    "message": f"OCR processing failed on page {page_idx}",
                }) + "\n"
            else:
                # Only count blocks for pages that are actually emitted.
                total_blocks += len(blocks)
                yield page_line
            finally:
                # Dropping only the loop variable frees nothing while the
                # list still references the page, so clear the slot as well.
                images[page_idx] = None
                del image

        yield json.dumps({
            "type": "done",
            "totalBlocks": total_blocks,
            "skippedPages": skipped_pages,
        }) + "\n"

    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        # Ask reverse proxies and caches not to buffer or store the stream.
        headers={
            "X-Accel-Buffering": "no",
            "Cache-Control": "no-cache",
        },
    )
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:

163
ocr-service/test_stream.py Normal file
View File

@@ -0,0 +1,163 @@
"""Tests for the NDJSON streaming OCR endpoint POST /ocr/stream."""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from main import app
@pytest.fixture
def mock_images():
    """Provide three blank 100x200 RGB PIL images standing in for PDF pages."""
    from PIL import Image

    pages = []
    for _ in range(3):
        pages.append(Image.new("RGB", (100, 200)))
    return pages
def _make_block(page_idx, text="Test"):
return {
"pageNumber": page_idx,
"x": 0.1, "y": 0.2, "width": 0.8, "height": 0.1,
"polygon": None, "text": text,
"words": [{"text": text, "confidence": 0.95}],
}
# ─── P3: start event with total pages ────────────────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_start_event_with_total_pages(mock_images):
    """The first NDJSON line is a start event carrying the page count."""
    payload = {
        "pdfUrl": "http://minio/test.pdf",
        "scriptType": "TYPEWRITER",
        "language": "de",
    }
    download_patch = patch(
        "main._download_and_convert_pdf",
        new_callable=AsyncMock,
        return_value=mock_images,
    )

    with download_patch, patch("main._models_ready", True), patch("main.surya_engine") as surya:
        surya.extract_page_blocks.return_value = [_make_block(0)]

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as http:
            reply = await http.post("/ocr/stream", json=payload)

    # Parse the NDJSON body, ignoring blank lines.
    events = []
    for raw in reply.text.strip().split("\n"):
        if raw.strip():
            events.append(json.loads(raw))

    assert events[0] == {"type": "start", "totalPages": 3}
# ─── P4: page events per completed page ──────────────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_page_event_per_page_with_blocks(mock_images):
    """Each of the three pages produces its own page event with one block."""

    def blocks_for_page(image, page_idx, language="de"):
        # One block per page, labelled with the page index.
        return [_make_block(page_idx, f"Page {page_idx}")]

    payload = {
        "pdfUrl": "http://minio/test.pdf",
        "scriptType": "TYPEWRITER",
    }
    download_patch = patch(
        "main._download_and_convert_pdf",
        new_callable=AsyncMock,
        return_value=mock_images,
    )

    with download_patch, patch("main._models_ready", True), patch("main.surya_engine") as surya:
        surya.extract_page_blocks.side_effect = blocks_for_page

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as http:
            reply = await http.post("/ocr/stream", json=payload)

    # Parse the NDJSON body, ignoring blank lines.
    events = []
    for raw in reply.text.strip().split("\n"):
        if raw.strip():
            events.append(json.loads(raw))

    page_events = [event for event in events if event["type"] == "page"]
    assert len(page_events) == 3
    for expected_number, event in enumerate(page_events):
        assert event["pageNumber"] == expected_number
    assert len(page_events[0]["blocks"]) == 1
# ─── P5: done event with total blocks and skipped ────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_done_with_total_blocks(mock_images):
    """The done event reports the block total across pages and zero skips."""

    def page_blocks(image, page_idx, language="de"):
        # Two *distinct* blocks per page. `[block] * 2` would alias a single
        # dict, so the endpoint's in-place mutation (rewriting "text",
        # popping "words") would only ever be exercised on one object.
        return [_make_block(page_idx) for _ in range(2)]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
    done = [l for l in lines if l["type"] == "done"][0]
    # 3 pages x 2 blocks each.
    assert done["totalBlocks"] == 6
    assert done["skippedPages"] == 0
# ─── P6: error event on page failure, continues ──────────────────────────────
@pytest.mark.asyncio
async def test_ocr_stream_emits_error_event_on_page_failure_and_continues(mock_images):
    """A failing page yields a generic error event and later pages still run."""
    call_count = 0

    def page_blocks(image, page_idx, language="de"):
        nonlocal call_count
        call_count += 1
        if page_idx == 1:
            raise RuntimeError("Engine crashed on page 1")
        return [_make_block(page_idx)]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

    lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
    types = [l["type"] for l in lines]
    assert "error" in types

    error_event = [l for l in lines if l["type"] == "error"][0]
    assert error_event["pageNumber"] == 1
    # Error message must be generic, not the raw traceback
    assert "Engine crashed" not in error_event["message"]

    # The engine was invoked for every page, including the one after the failure
    assert call_count == 3

    # Processing continued: pages 0 and 2 have page events
    page_events = [l for l in lines if l["type"] == "page"]
    assert len(page_events) == 2
    assert [l["pageNumber"] for l in page_events] == [0, 2]

    done = [l for l in lines if l["type"] == "done"][0]
    assert done["skippedPages"] == 1
# ─── P7: old /ocr endpoint still works ───────────────────────────────────────
@pytest.mark.asyncio
async def test_old_ocr_endpoint_still_returns_flat_list(mock_images):
    """POST /ocr keeps answering 200 with a flat JSON list of blocks."""
    payload = {
        "pdfUrl": "http://minio/test.pdf",
        "scriptType": "TYPEWRITER",
    }
    download_patch = patch(
        "main._download_and_convert_pdf",
        new_callable=AsyncMock,
        return_value=mock_images,
    )

    with download_patch, patch("main._models_ready", True), patch("main.surya_engine") as surya:
        surya.extract_blocks.return_value = [_make_block(0), _make_block(1)]

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as http:
            reply = await http.post("/ocr", json=payload)

    assert reply.status_code == 200
    body = reply.json()
    assert isinstance(body, list)
    assert len(body) == 2
    assert body[0]["pageNumber"] == 0