diff --git a/ocr-service/test_metrics.py b/ocr-service/test_metrics.py index 55108fb8..7d154102 100644 --- a/ocr-service/test_metrics.py +++ b/ocr-service/test_metrics.py @@ -388,3 +388,37 @@ async def test_ocr_training_runs_total_incremented_with_segmentation_success_lab assert fresh_metrics.ocr_training_runs_total.labels( kind="segmentation", outcome="success" )._value.get() == 1.0 + + +@pytest.mark.asyncio +async def test_ocr_model_accuracy_gauge_set_per_kind_after_successful_training(fresh_metrics): + """After /train and /segtrain succeed, ocr_model_accuracy{kind=...} reflects the result.""" + recognition_accuracy = 0.917 + segmentation_accuracy = 0.834 + + async def fake_recognition_to_thread(func, *args, **kwargs): + return _fake_training_result(accuracy=recognition_accuracy) + + async def fake_segmentation_to_thread(func, *args, **kwargs): + return _fake_training_result(accuracy=segmentation_accuracy) + + with patch("main.TRAINING_TOKEN", "secret-token"), \ + patch("main._models_ready", True): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + with patch("main.asyncio.to_thread", side_effect=fake_recognition_to_thread): + rec_resp = await client.post( + "/train", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + headers={"X-Training-Token": "secret-token"}, + ) + assert rec_resp.status_code == 200 + with patch("main.asyncio.to_thread", side_effect=fake_segmentation_to_thread): + seg_resp = await client.post( + "/segtrain", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + headers={"X-Training-Token": "secret-token"}, + ) + assert seg_resp.status_code == 200 + + assert fresh_metrics.ocr_model_accuracy.labels(kind="recognition")._value.get() == pytest.approx(recognition_accuracy) + assert fresh_metrics.ocr_model_accuracy.labels(kind="segmentation")._value.get() == pytest.approx(segmentation_accuracy)