diff --git a/ocr-service/test_metrics.py b/ocr-service/test_metrics.py index 3ed02c31..6aa047e0 100644 --- a/ocr-service/test_metrics.py +++ b/ocr-service/test_metrics.py @@ -390,6 +390,52 @@ async def test_ocr_training_runs_total_incremented_with_segmentation_success_lab )._value.get() == 1.0 +@pytest.mark.asyncio +async def test_ocr_training_runs_total_incremented_with_recognition_success_label_for_train_sender(fresh_metrics): + """/train-sender success increments ocr_training_runs_total{kind=recognition, outcome=success}.""" + async def fake_to_thread(func, *args, **kwargs): + return _fake_training_result() + + with patch("main.TRAINING_TOKEN", "secret-token"), \ + patch("main._models_ready", True), \ + patch("main.asyncio.to_thread", side_effect=fake_to_thread): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/train-sender", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + data={"output_model_path": "/app/models/sender_test.mlmodel"}, + headers={"X-Training-Token": "secret-token"}, + ) + + assert response.status_code == 200, response.text + assert fresh_metrics.ocr_training_runs_total.labels( + kind="recognition", outcome="success" + )._value.get() == 1.0 + + +@pytest.mark.asyncio +async def test_ocr_model_accuracy_gauge_stays_default_when_training_returns_no_accuracy(fresh_metrics): + """When the runner returns accuracy=None, ocr_model_accuracy must remain at its default 0.""" + async def fake_to_thread(func, *args, **kwargs): + return {"loss": None, "accuracy": None, "cer": None, "epochs": 5} + + with patch("main.TRAINING_TOKEN", "secret-token"), \ + patch("main._models_ready", True), \ + patch("main.asyncio.to_thread", side_effect=fake_to_thread): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/train", + files={"file": ("training.zip", _minimal_zip(), "application/zip")}, + headers={"X-Training-Token": "secret-token"}, + ) + + assert response.status_code == 200 + # Gauge was never .set() — accessing the label child still creates it with default 0.0. + assert fresh_metrics.ocr_model_accuracy.labels( + kind="recognition" + )._value.get() == 0.0 + + @pytest.mark.asyncio async def test_ocr_model_accuracy_gauge_set_per_kind_after_successful_training(fresh_metrics): """After /train and /segtrain succeed, ocr_model_accuracy{kind=...} reflects the result."""