From 669f2f8b98e88df9df391993a011c302b3530bf2 Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 13 Apr 2026 21:57:42 +0200 Subject: [PATCH] fix(training): output CoreML format and fix best-model finder ketos 7 defaults to safetensors output, but kraken's load_any() only handles CoreML (.mlmodel). Adding --weights-format coreml ensures the hot-swap after training produces a file that load_any() can parse. Also fixed _find_best_model to look for best_.mlmodel (produced by --weights-format coreml) in addition to the previous checkpoint_* pattern. Co-Authored-By: Claude Sonnet 4.6 --- ocr-service/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ocr-service/main.py b/ocr-service/main.py index 16b2d955..754ed75b 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -305,20 +305,25 @@ def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]: def _find_best_model(checkpoint_dir: str) -> str | None: - """Return the checkpoint file with the highest validation metric, or any model file.""" - pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$") + """Return the best final model file produced by ketos train. + + With --weights-format coreml, ketos writes ``best_.mlmodel``. + Falls back to any .mlmodel in the directory. + """ + # Prefer the named best file (e.g. best_0.8256.mlmodel or best_0.8256.safetensors) + best_pattern = re.compile(r"best_([0-9.]+)\.(mlmodel|safetensors)$") best_acc: float | None = None best_path: str | None = None for fname in os.listdir(checkpoint_dir): - m = pattern.match(fname) + m = best_pattern.match(fname) if m: - acc = float(m.group(2)) + acc = float(m.group(1)) if best_acc is None or acc > best_acc: best_acc = acc best_path = os.path.join(checkpoint_dir, fname) if best_path: return best_path - # Fallback: any .mlmodel file in the directory + # Fallback: any .mlmodel file for fname in os.listdir(checkpoint_dir): if fname.endswith(".mlmodel"): return os.path.join(checkpoint_dir, fname) @@ -369,6 +374,7 @@ async def train_model( "ketos", "--workers", "0", "--device", "cpu", "--threads", "2", "train", "-f", "page", + "--weights-format", "coreml", "-o", checkpoint_dir, "-q", "fixed", "-N", "10",