diff --git a/ocr-service/ensure_blla_model.py b/ocr-service/ensure_blla_model.py index a8aed449..d0fa0305 100644 --- a/ocr-service/ensure_blla_model.py +++ b/ocr-service/ensure_blla_model.py @@ -33,9 +33,12 @@ def _model_is_loadable(path: str) -> bool: vgsl.TorchVGSLModel.load_model(path) return True - except Exception as e: + except (RuntimeError, OSError, ValueError) as e: log.warning("Model at %s failed to load: %s", path, e) return False + except Exception: + log.debug("Unexpected error loading model at %s", path, exc_info=True) + return False def _download_blla() -> str: diff --git a/ocr-service/main.py b/ocr-service/main.py index 63473737..3545f6a4 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -494,8 +494,8 @@ async def segtrain_model( "will train from scratch at 800px", _m.input[2], ) - except Exception: - pass + except Exception as exc: + log.warning("Could not inspect base model height, training from scratch: %s", exc) if use_base_model: cmd += ["-i", blla_model_path, "--resize", "union", "-s", seg_spec] diff --git a/ocr-service/test_ensure_blla_model.py b/ocr-service/test_ensure_blla_model.py new file mode 100644 index 00000000..cc171e92 --- /dev/null +++ b/ocr-service/test_ensure_blla_model.py @@ -0,0 +1,69 @@ +"""Unit tests for ensure_blla_model.main().""" + +from unittest.mock import MagicMock, call, patch + +import ensure_blla_model + + +# ─── Model already loadable ─────────────────────────────────────────────────── + + +def test_main_returns_early_when_model_is_loadable(): + """When the model exists and loads cleanly, no download or rename occurs.""" + with ( + patch("os.path.exists", return_value=True), + patch.object(ensure_blla_model, "_model_is_loadable", return_value=True), + patch.object(ensure_blla_model, "_download_blla") as mock_download, + patch("os.rename") as mock_rename, + ): + ensure_blla_model.main() + + mock_download.assert_not_called() + mock_rename.assert_not_called() + + +# ─── Model exists but is incompatible ───────────────────────────────────────── + + +def test_main_replaces_incompatible_model(): + """An incompatible model is renamed and replaced with a fresh download.""" + fake_path = "/app/models/blla.mlmodel" + downloaded_path = "/tmp/downloaded.mlmodel" + + with ( + patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path), + patch("os.path.exists", return_value=True), + patch.object(ensure_blla_model, "_model_is_loadable", return_value=False), + patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path), + patch("os.rename") as mock_rename, + patch("shutil.copy2") as mock_copy, + patch("os.makedirs"), + ): + ensure_blla_model.main() + + mock_rename.assert_called_once_with(fake_path, fake_path + ".incompatible") + mock_copy.assert_called_once_with(downloaded_path, fake_path) + + +# ─── Model missing ──────────────────────────────────────────────────────────── + + +def test_main_downloads_when_model_missing(): + """When the model file doesn't exist at all, it is downloaded without rename.""" + fake_path = "/app/models/blla.mlmodel" + downloaded_path = "/tmp/downloaded.mlmodel" + + with ( + patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path), + patch("os.path.exists", return_value=False), + patch.object(ensure_blla_model, "_model_is_loadable") as mock_loadable, + patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path), + patch("os.rename") as mock_rename, + patch("shutil.copy2") as mock_copy, + patch("os.makedirs"), + ): + ensure_blla_model.main() + + mock_loadable.assert_not_called() + mock_rename.assert_not_called() + mock_copy.assert_called_once_with(downloaded_path, fake_path)