test(ocr): collapse shared client setup into ocr_client helper

Each metrics test was repeating the same five-line block — patch
kraken_engine.load_models, patch load_spell_checker, instantiate the
AsyncClient, force _models_ready True, restore it. Lift the lot into a
single async context manager so each test body shrinks to its real
arrange / act / assert intent.

Tests that drive the lifespan directly (models_ready gauge) or stub
asyncio.to_thread for /train (which already patches _models_ready) stay
unchanged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-05-21 17:03:29 +02:00
parent df952861c4
commit ddf6cf4cbc

View File

@@ -4,6 +4,7 @@ Each test that asserts on a counter/gauge value uses a fresh CollectorRegistry
(see decision #3 on issue #652) to keep the metrics isolated between tests. (see decision #3 on issue #652) to keep the metrics isolated between tests.
""" """
import contextlib
import io import io
import zipfile import zipfile
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@@ -17,6 +18,26 @@ from main import app
from metrics import build_metrics from metrics import build_metrics
@contextlib.asynccontextmanager
async def ocr_client(*, raise_app_exceptions: bool = True):
"""Yield an AsyncClient with model-loaders patched and _models_ready forced on.
The shared setup for almost every metrics test: stub the heavy lifecycle
hooks (kraken_engine.load_models, load_spell_checker), flip the readiness
flag so request handlers do not 503, and restore it afterwards.
"""
with patch("main.kraken_engine.load_models"), \
patch("main.load_spell_checker"):
transport = ASGITransport(app=app, raise_app_exceptions=raise_app_exceptions)
async with AsyncClient(transport=transport, base_url="http://test") as client:
import main as main_module
main_module._models_ready = True
try:
yield client
finally:
main_module._models_ready = False
def _minimal_zip() -> bytes: def _minimal_zip() -> bytes:
"""Return a ZIP containing one fake .xml so endpoint validation passes.""" """Return a ZIP containing one fake .xml so endpoint validation passes."""
buf = io.BytesIO() buf = io.BytesIO()
@@ -57,26 +78,18 @@ async def test_metrics_includes_http_request_metrics_after_ocr_call():
mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0,
"polygon": None, "text": "hi", "words": []}] "polygon": None, "text": "hi", "words": []}]
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.surya_engine.extract_blocks", return_value=mock_blocks): patch("main.surya_engine.extract_blocks", return_value=mock_blocks):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
# Models need to be loaded for /ocr to accept requests; force the flag. ocr_response = await client.post("/ocr", json={
import main as main_module "pdfUrl": "http://minio/doc.pdf",
main_module._models_ready = True "scriptType": "TYPEWRITER",
try: "language": "de",
ocr_response = await client.post("/ocr", json={ })
"pdfUrl": "http://minio/doc.pdf", assert ocr_response.status_code == 200, ocr_response.text
"scriptType": "TYPEWRITER",
"language": "de",
})
assert ocr_response.status_code == 200, ocr_response.text
metrics_response = await client.get("/metrics") metrics_response = await client.get("/metrics")
finally:
main_module._models_ready = False
body = metrics_response.text body = metrics_response.text
assert "http_requests_total" in body assert "http_requests_total" in body
@@ -124,20 +137,13 @@ async def test_ocr_jobs_total_incremented_with_kraken_engine_label_for_kurrent(f
mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0,
"polygon": None, "text": "hi", "words": []}] "polygon": None, "text": "hi", "words": []}]
with patch("main.kraken_engine.load_models"), \ with patch("main.correct_text", side_effect=lambda t: t), \
patch("main.load_spell_checker"), \
patch("main.correct_text", side_effect=lambda t: t), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \ patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.kraken_engine.is_available", return_value=True), \ patch("main.kraken_engine.is_available", return_value=True), \
patch("main.kraken_engine.extract_blocks", return_value=mock_blocks): patch("main.kraken_engine.extract_blocks", return_value=mock_blocks):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module await _drive_ocr(client, script_type="HANDWRITING_KURRENT")
main_module._models_ready = True
try:
await _drive_ocr(client, script_type="HANDWRITING_KURRENT")
finally:
main_module._models_ready = False
value = fresh_metrics.ocr_jobs_total.labels( value = fresh_metrics.ocr_jobs_total.labels(
engine="kraken", script_type="HANDWRITING_KURRENT" engine="kraken", script_type="HANDWRITING_KURRENT"
@@ -152,18 +158,11 @@ async def test_ocr_jobs_total_incremented_with_surya_engine_label_for_typewriter
mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0,
"polygon": None, "text": "hi", "words": []}] "polygon": None, "text": "hi", "words": []}]
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.surya_engine.extract_blocks", return_value=mock_blocks): patch("main.surya_engine.extract_blocks", return_value=mock_blocks):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module await _drive_ocr(client, script_type="TYPEWRITER")
main_module._models_ready = True
try:
await _drive_ocr(client, script_type="TYPEWRITER")
finally:
main_module._models_ready = False
value = fresh_metrics.ocr_jobs_total.labels( value = fresh_metrics.ocr_jobs_total.labels(
engine="surya", script_type="TYPEWRITER" engine="surya", script_type="TYPEWRITER"
@@ -178,26 +177,19 @@ async def test_ocr_pages_total_incremented_once_per_page_in_stream(fresh_metrics
mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0,
"polygon": None, "text": "hi", "words": []}] "polygon": None, "text": "hi", "words": []}]
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.surya_engine.extract_page_blocks", return_value=mock_blocks): patch("main.surya_engine.extract_page_blocks", return_value=mock_blocks):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module async with client.stream("POST", "/ocr/stream", json={
main_module._models_ready = True "pdfUrl": "http://minio/doc.pdf",
try: "scriptType": "TYPEWRITER",
async with client.stream("POST", "/ocr/stream", json={ "language": "de",
"pdfUrl": "http://minio/doc.pdf", }) as response:
"scriptType": "TYPEWRITER", assert response.status_code == 200
"language": "de", # Drain the stream so all per-page increments fire.
}) as response: async for _ in response.aiter_lines():
assert response.status_code == 200 pass
# Drain the stream so all per-page increments fire.
async for _ in response.aiter_lines():
pass
finally:
main_module._models_ready = False
value = fresh_metrics.ocr_pages_total.labels(engine="surya")._value.get() value = fresh_metrics.ocr_pages_total.labels(engine="surya")._value.get()
assert value == 3.0 assert value == 3.0
@@ -218,28 +210,21 @@ async def test_ocr_skipped_pages_total_incremented_when_engine_raises_for_a_page
raise RuntimeError("synthetic engine failure") raise RuntimeError("synthetic engine failure")
return good_blocks return good_blocks
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.surya_engine.extract_page_blocks", side_effect=extract_side_effect): patch("main.surya_engine.extract_page_blocks", side_effect=extract_side_effect):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module async with client.stream("POST", "/ocr/stream", json={
main_module._models_ready = True "pdfUrl": "http://minio/doc.pdf",
try: "scriptType": "TYPEWRITER",
async with client.stream("POST", "/ocr/stream", json={ "language": "de",
"pdfUrl": "http://minio/doc.pdf", }) as response:
"scriptType": "TYPEWRITER", assert response.status_code == 200
"language": "de", saw_error = False
}) as response: async for line in response.aiter_lines():
assert response.status_code == 200 if line and '"type": "error"' in line:
saw_error = False saw_error = True
async for line in response.aiter_lines(): assert saw_error
if line and '"type": "error"' in line:
saw_error = True
assert saw_error
finally:
main_module._models_ready = False
assert fresh_metrics.ocr_skipped_pages_total._value.get() == 1.0 assert fresh_metrics.ocr_skipped_pages_total._value.get() == 1.0
# The second page still succeeds. # The second page still succeeds.
@@ -266,18 +251,11 @@ async def test_ocr_words_and_illegible_words_total_sum_across_blocks(fresh_metri
{"text": "Karl", "confidence": 0.95}]}, {"text": "Karl", "confidence": 0.95}]},
] ]
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.surya_engine.extract_blocks", return_value=mock_blocks): patch("main.surya_engine.extract_blocks", return_value=mock_blocks):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module await _drive_ocr(client, script_type="TYPEWRITER")
main_module._models_ready = True
try:
await _drive_ocr(client, script_type="TYPEWRITER")
finally:
main_module._models_ready = False
assert fresh_metrics.ocr_words_total._value.get() == 5.0 assert fresh_metrics.ocr_words_total._value.get() == 5.0
assert fresh_metrics.ocr_illegible_words_total._value.get() == 2.0 assert fresh_metrics.ocr_illegible_words_total._value.get() == 2.0
@@ -296,25 +274,18 @@ async def test_ocr_processing_seconds_histogram_observed_per_page_in_stream(fres
mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, mock_blocks = [{"pageNumber": 1, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0,
"polygon": None, "text": "ok", "words": []}] "polygon": None, "text": "ok", "words": []}]
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.surya_engine.extract_page_blocks", return_value=mock_blocks): patch("main.surya_engine.extract_page_blocks", return_value=mock_blocks):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module async with client.stream("POST", "/ocr/stream", json={
main_module._models_ready = True "pdfUrl": "http://minio/doc.pdf",
try: "scriptType": "TYPEWRITER",
async with client.stream("POST", "/ocr/stream", json={ "language": "de",
"pdfUrl": "http://minio/doc.pdf", }) as response:
"scriptType": "TYPEWRITER", assert response.status_code == 200
"language": "de", async for _ in response.aiter_lines():
}) as response: pass
assert response.status_code == 200
async for _ in response.aiter_lines():
pass
finally:
main_module._models_ready = False
sum_seconds, count = _histogram_count_sum( sum_seconds, count = _histogram_count_sum(
fresh_metrics.ocr_processing_seconds, engine="surya" fresh_metrics.ocr_processing_seconds, engine="surya"
@@ -507,26 +478,19 @@ async def test_ocr_processing_seconds_histogram_observed_per_page_in_guided_stre
{"pageNumber": 2, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, "annotationId": "a2"}, {"pageNumber": 2, "x": 0.0, "y": 0.0, "width": 1.0, "height": 1.0, "annotationId": "a2"},
] ]
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.surya_engine.extract_region_text", return_value="text"): patch("main.surya_engine.extract_region_text", return_value="text"):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module async with client.stream("POST", "/ocr/stream", json={
main_module._models_ready = True "pdfUrl": "http://minio/doc.pdf",
try: "scriptType": "TYPEWRITER",
async with client.stream("POST", "/ocr/stream", json={ "language": "de",
"pdfUrl": "http://minio/doc.pdf", "regions": regions,
"scriptType": "TYPEWRITER", }) as response:
"language": "de", assert response.status_code == 200
"regions": regions, async for _ in response.aiter_lines():
}) as response: pass
assert response.status_code == 200
async for _ in response.aiter_lines():
pass
finally:
main_module._models_ready = False
sum_seconds, count = _histogram_count_sum( sum_seconds, count = _histogram_count_sum(
fresh_metrics.ocr_processing_seconds, engine="surya" fresh_metrics.ocr_processing_seconds, engine="surya"
@@ -549,28 +513,21 @@ async def test_ocr_processing_seconds_histogram_excludes_spell_check_time_in_gui
_time.sleep(0.05) _time.sleep(0.05)
return text return text
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.load_spell_checker"), \
patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
patch("main.preprocess_page", side_effect=lambda img: img), \ patch("main.preprocess_page", side_effect=lambda img: img), \
patch("main.kraken_engine.is_available", return_value=True), \ patch("main.kraken_engine.is_available", return_value=True), \
patch("main.kraken_engine.extract_region_text", return_value="text"), \ patch("main.kraken_engine.extract_region_text", return_value="text"), \
patch("main.correct_text", side_effect=slow_correct): patch("main.correct_text", side_effect=slow_correct):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with ocr_client() as client:
import main as main_module async with client.stream("POST", "/ocr/stream", json={
main_module._models_ready = True "pdfUrl": "http://minio/doc.pdf",
try: "scriptType": "HANDWRITING_KURRENT",
async with client.stream("POST", "/ocr/stream", json={ "language": "de",
"pdfUrl": "http://minio/doc.pdf", "regions": regions,
"scriptType": "HANDWRITING_KURRENT", }) as response:
"language": "de", assert response.status_code == 200
"regions": regions, async for _ in response.aiter_lines():
}) as response: pass
assert response.status_code == 200
async for _ in response.aiter_lines():
pass
finally:
main_module._models_ready = False
sum_seconds, _ = _histogram_count_sum( sum_seconds, _ = _histogram_count_sum(
fresh_metrics.ocr_processing_seconds, engine="kraken" fresh_metrics.ocr_processing_seconds, engine="kraken"
@@ -591,21 +548,13 @@ async def test_ocr_jobs_total_not_incremented_when_pdf_download_fails_in_stream(
async def fail_download(url): async def fail_download(url):
raise RuntimeError("synthetic download failure") raise RuntimeError("synthetic download failure")
with patch("main.kraken_engine.load_models"), \ with patch("main._download_and_convert_pdf", new=fail_download):
patch("main.load_spell_checker"), \ async with ocr_client(raise_app_exceptions=False) as client:
patch("main._download_and_convert_pdf", new=fail_download): response = await client.post("/ocr/stream", json={
transport = ASGITransport(app=app, raise_app_exceptions=False) "pdfUrl": "http://minio/doc.pdf",
async with AsyncClient(transport=transport, base_url="http://test") as client: "scriptType": "TYPEWRITER",
import main as main_module "language": "de",
main_module._models_ready = True })
try:
response = await client.post("/ocr/stream", json={
"pdfUrl": "http://minio/doc.pdf",
"scriptType": "TYPEWRITER",
"language": "de",
})
finally:
main_module._models_ready = False
assert response.status_code == 500 assert response.status_code == 500
assert fresh_metrics.ocr_jobs_total.labels( assert fresh_metrics.ocr_jobs_total.labels(