Adds ensure_blla_model.py, which loads the blla segmentation model with kraken on every container start. If the model is missing or fails to load (e.g. it is in the legacy PyTorch ZIP format, which is incompatible with ketos 7), it re-downloads the correct CoreML protobuf model from Zenodo (DOI 10.5281/zenodo.14602569). The Dockerfile now uses entrypoint.sh, which runs this check before starting uvicorn. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
"""Validate the blla segmentation base model and download it if needed.

Run at container startup before uvicorn. ketos 7 requires the model in
CoreML protobuf or safetensors format — legacy PyTorch ZIP archives
(torch.save output from kraken <4) are not loadable and will be replaced.

The model at ``BLLA_MODEL_PATH`` is checked by loading it with kraken's
VGSL loader. A file that fails to load is moved aside to
``<BLLA_MODEL_PATH>.incompatible``, and a fresh copy is fetched with
``kraken get`` from Zenodo (DOI 10.5281/zenodo.14602569).

Exits non-zero on failure so the container entrypoint can stop start-up
rather than silently starting uvicorn with a broken model.
"""

import glob
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
|
|
log = logging.getLogger(__name__)

# Tag every record with the script name so its lines can be told apart
# from the rest of the container's start-up output.
logging.basicConfig(
    format="%(levelname)s:ensure_blla_model:%(message)s",
    level=logging.INFO,
)

# Where the validated model must end up; overridable through the
# BLLA_MODEL_PATH environment variable.
BLLA_MODEL_PATH = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")

# Zenodo DOI of "General segmentation model for print and handwriting",
# the ketos 7 compatible release fetched with `kraken get`.
BLLA_MODEL_DOI = "10.5281/zenodo.14602569"

# Directory searched for the downloaded model after `kraken get` has run.
HTRMOPO_DIR = os.path.expanduser("~/.local/share/htrmopo")


def _model_is_loadable(path: str) -> bool:
    """Return True if kraken can load the model file at *path*.

    The file is loaded with ``kraken.lib.vgsl.TorchVGSLModel.load_model``.
    Any exception raised by that call is logged as a warning and reported
    as ``False`` so the caller can replace the file.

    Args:
        path: Filesystem path of the model to check.

    Returns:
        True if the model loaded, False if loading raised.

    Raises:
        ImportError: If kraken itself cannot be imported. That is an
            environment problem, not a model problem, and must not cause a
            good model file to be discarded by the caller.
    """
    # Deliberately outside the try block: previously an ImportError was
    # swallowed and reported as "model failed to load", which made the
    # caller rename a perfectly good model to *.incompatible.
    from kraken.lib import vgsl

    try:
        vgsl.TorchVGSLModel.load_model(path)
    except Exception as exc:  # broad on purpose: any load failure => replace
        log.warning("Model at %s failed to load: %s", path, exc)
        return False
    return True


def _download_blla() -> str:
    """Fetch the blla model with ``kraken get`` and return its local path.

    Runs ``kraken get BLLA_MODEL_DOI``. It then looks for a ``blla.mlmodel``
    file exactly one directory level below ``HTRMOPO_DIR``.

    Returns:
        Path of a downloaded ``blla.mlmodel``. When several cached copies
        match, the most recently modified one is returned.

    Exits:
        Calls ``sys.exit(1)`` if ``kraken get`` fails or if no matching file
        is found afterwards.
    """
    log.info("Downloading blla model (DOI %s) ...", BLLA_MODEL_DOI)
    result = subprocess.run(
        ["kraken", "get", BLLA_MODEL_DOI],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        # Some CLIs report errors on stdout; fall back to it so the log is
        # never an empty message.
        log.error(
            "kraken get failed (exit %d): %s",
            result.returncode,
            result.stderr or result.stdout,
        )
        sys.exit(1)

    candidates = glob.glob(os.path.join(HTRMOPO_DIR, "*", "blla.mlmodel"))
    if not candidates:
        log.error("Downloaded blla.mlmodel not found under %s", HTRMOPO_DIR)
        sys.exit(1)

    # Lexicographic order of the cache sub-directory names says nothing
    # about age (NOTE(review): htrmopo's directory naming is assumed to be
    # opaque). Prefer the most recently modified copy over an arbitrary one.
    return max(candidates, key=os.path.getmtime)


def main() -> None:
    """Ensure a loadable blla model exists at ``BLLA_MODEL_PATH``.

    If the file exists and loads, nothing is changed. Otherwise:

    1. Any existing file is moved aside to ``<BLLA_MODEL_PATH>.incompatible``.
    2. A fresh copy is downloaded and installed.
    3. The installed copy is loaded once more to confirm it works.

    Exits with status 1 (via ``sys.exit``) if the download fails or if the
    freshly installed model still cannot be loaded.
    """
    if os.path.exists(BLLA_MODEL_PATH):
        if _model_is_loadable(BLLA_MODEL_PATH):
            log.info("blla model OK: %s", BLLA_MODEL_PATH)
            return
        log.warning(
            "blla model at %s is in an incompatible format — replacing", BLLA_MODEL_PATH
        )
        # os.replace overwrites a stale backup from an earlier run on every
        # platform (os.rename fails on Windows when the target exists).
        os.replace(BLLA_MODEL_PATH, BLLA_MODEL_PATH + ".incompatible")

    # dirname() is "" for a bare filename such as BLLA_MODEL_PATH=blla.mlmodel.
    # os.makedirs("") raises, and the current directory already exists.
    model_dir = os.path.dirname(BLLA_MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    downloaded = _download_blla()

    # Copy to a sibling temp file, then rename into place. An interrupted
    # start then never leaves a half-written model at BLLA_MODEL_PATH.
    tmp_path = BLLA_MODEL_PATH + ".tmp"
    shutil.copy2(downloaded, tmp_path)
    os.replace(tmp_path, BLLA_MODEL_PATH)

    # Fail loudly (non-zero exit) rather than let the service start with a
    # replacement model that still cannot be loaded.
    if not _model_is_loadable(BLLA_MODEL_PATH):
        log.error("Downloaded blla model at %s is not loadable", BLLA_MODEL_PATH)
        sys.exit(1)

    log.info("Installed blla model at %s", BLLA_MODEL_PATH)


if __name__ == "__main__":
    # main() returns None, so a normal return exits with status 0.
    # Failures exit non-zero via sys.exit(1) inside main()/_download_blla().
    sys.exit(main())