feat(ocr): add NDJSON streaming endpoint POST /ocr/stream
Streams one JSON line per completed page instead of buffering the entire result. Emits start/page/error/done events. On per-page failure, logs the traceback but yields a generic error message and continues with the next page. Adds X-Accel-Buffering: no and Cache-Control: no-cache headers for reverse-proxy compatibility. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -2,12 +2,14 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pypdfium2 as pdfium
|
import pypdfium2 as pdfium
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from confidence import apply_confidence_markers, get_threshold
|
from confidence import apply_confidence_markers, get_threshold
|
||||||
@@ -82,6 +84,77 @@ async def run_ocr(request: OcrRequest):
|
|||||||
return [OcrBlock(**b) for b in blocks]
|
return [OcrBlock(**b) for b in blocks]
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/ocr/stream")
|
||||||
|
async def run_ocr_stream(request: OcrRequest):
|
||||||
|
"""Run OCR on a PDF with NDJSON streaming — one JSON line per completed page."""
|
||||||
|
if not _models_ready:
|
||||||
|
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||||
|
|
||||||
|
images = await _download_and_convert_pdf(request.pdfUrl)
|
||||||
|
script_type = request.scriptType.upper()
|
||||||
|
threshold = get_threshold(script_type)
|
||||||
|
|
||||||
|
use_kraken = script_type == "HANDWRITING_KURRENT"
|
||||||
|
if use_kraken and not kraken_engine.is_available():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Kraken model not available — cannot process Kurrent script",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
total_pages = len(images)
|
||||||
|
yield json.dumps({"type": "start", "totalPages": total_pages}) + "\n"
|
||||||
|
|
||||||
|
total_blocks = 0
|
||||||
|
skipped_pages = 0
|
||||||
|
|
||||||
|
for page_idx, image in enumerate(images):
|
||||||
|
try:
|
||||||
|
engine = kraken_engine if use_kraken else surya_engine
|
||||||
|
blocks = await asyncio.to_thread(
|
||||||
|
engine.extract_page_blocks, image, page_idx, request.language
|
||||||
|
)
|
||||||
|
|
||||||
|
for block in blocks:
|
||||||
|
if block.get("words"):
|
||||||
|
block["text"] = apply_confidence_markers(block["words"], threshold)
|
||||||
|
block.pop("words", None)
|
||||||
|
|
||||||
|
total_blocks += len(blocks)
|
||||||
|
yield json.dumps({
|
||||||
|
"type": "page",
|
||||||
|
"pageNumber": page_idx,
|
||||||
|
"blocks": blocks,
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("OCR failed on page %d", page_idx)
|
||||||
|
skipped_pages += 1
|
||||||
|
yield json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"pageNumber": page_idx,
|
||||||
|
"message": f"OCR processing failed on page {page_idx}",
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
del image
|
||||||
|
|
||||||
|
yield json.dumps({
|
||||||
|
"type": "done",
|
||||||
|
"totalBlocks": total_blocks,
|
||||||
|
"skippedPages": skipped_pages,
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate(),
|
||||||
|
media_type="application/x-ndjson",
|
||||||
|
headers={
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
||||||
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
||||||
|
|||||||
163
ocr-service/test_stream.py
Normal file
163
ocr-service/test_stream.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""Tests for the NDJSON streaming OCR endpoint POST /ocr/stream."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_images():
    """Provide three blank 100x200 RGB PIL images standing in for PDF pages."""
    from PIL import Image

    pages = []
    for _ in range(3):
        pages.append(Image.new("RGB", (100, 200)))
    return pages
||||||
|
|
||||||
|
|
||||||
|
def _make_block(page_idx, text="Test"):
|
||||||
|
return {
|
||||||
|
"pageNumber": page_idx,
|
||||||
|
"x": 0.1, "y": 0.2, "width": 0.8, "height": 0.1,
|
||||||
|
"polygon": None, "text": text,
|
||||||
|
"words": [{"text": text, "confidence": 0.95}],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ─── P3: start event with total pages ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_ocr_stream_emits_start_event_with_total_pages(mock_images):
    """The first NDJSON line must be a start event carrying the page count."""
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_page_blocks.return_value = [_make_block(0)]

        payload = {
            "pdfUrl": "http://minio/test.pdf",
            "scriptType": "TYPEWRITER",
            "language": "de",
        }
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json=payload)

        events = [
            json.loads(raw)
            for raw in response.text.strip().split("\n")
            if raw.strip()
        ]
        assert events[0] == {"type": "start", "totalPages": 3}
||||||
|
|
||||||
|
|
||||||
|
# ─── P4: page events per completed page ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_ocr_stream_emits_page_event_per_page_with_blocks(mock_images):
    """Each page yields exactly one ordered page event carrying its blocks."""
    def page_blocks(image, page_idx, language="de"):
        return [_make_block(page_idx, f"Page {page_idx}")]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

        events = [json.loads(raw) for raw in response.text.strip().split("\n") if raw.strip()]
        pages = [e for e in events if e["type"] == "page"]
        assert len(pages) == 3
        # Page events arrive in document order.
        for expected_idx, event in enumerate(pages):
            assert event["pageNumber"] == expected_idx
        assert len(pages[0]["blocks"]) == 1
||||||
|
|
||||||
|
|
||||||
|
# ─── P5: done event with total blocks and skipped ────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_ocr_stream_emits_done_with_total_blocks(mock_images):
    """The final done event reports aggregate block and skipped-page counts."""
    def page_blocks(image, page_idx, language="de"):
        # Two blocks per page across 3 pages -> 6 total blocks expected.
        return [_make_block(page_idx)] * 2

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

        events = [json.loads(raw) for raw in response.text.strip().split("\n") if raw.strip()]
        done_event = next(e for e in events if e["type"] == "done")
        assert done_event["totalBlocks"] == 6
        assert done_event["skippedPages"] == 0
||||||
|
|
||||||
|
|
||||||
|
# ─── P6: error event on page failure, continues ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_ocr_stream_emits_error_event_on_page_failure_and_continues(mock_images):
    """A page failure yields a generic error event and does not abort the stream.

    Page 1 raises inside the engine; the stream must still produce page
    events for pages 0 and 2 and report one skipped page in the done event.
    """
    # NOTE: the original test kept a `nonlocal call_count` counter that was
    # incremented but never asserted — dead code, removed.
    def page_blocks(image, page_idx, language="de"):
        if page_idx == 1:
            raise RuntimeError("Engine crashed on page 1")
        return [_make_block(page_idx)]

    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_page_blocks.side_effect = page_blocks

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr/stream", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

        lines = [json.loads(line) for line in response.text.strip().split("\n") if line.strip()]
        types = [l["type"] for l in lines]
        assert "error" in types
        error_event = [l for l in lines if l["type"] == "error"][0]
        assert error_event["pageNumber"] == 1
        # Error message must be generic, not the raw traceback
        assert "Engine crashed" not in error_event["message"]

        # Processing continued: pages 0 and 2 have page events
        page_events = [l for l in lines if l["type"] == "page"]
        assert len(page_events) == 2

        done = [l for l in lines if l["type"] == "done"][0]
        assert done["skippedPages"] == 1
||||||
|
|
||||||
|
|
||||||
|
# ─── P7: old /ocr endpoint still works ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_old_ocr_endpoint_still_returns_flat_list(mock_images):
    """POST /ocr keeps its original contract: a flat JSON list of blocks."""
    with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
         patch("main._models_ready", True), \
         patch("main.surya_engine") as mock_surya:
        mock_surya.extract_blocks.return_value = [_make_block(0), _make_block(1)]

        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            response = await client.post("/ocr", json={
                "pdfUrl": "http://minio/test.pdf",
                "scriptType": "TYPEWRITER",
            })

        assert response.status_code == 200
        body = response.json()
        assert isinstance(body, list)
        assert len(body) == 2
        assert body[0]["pageNumber"] == 0
||||||
Reference in New Issue
Block a user