fix(training): output CoreML format and fix best-model finder
ketos 7 defaults to safetensors output, but kraken's load_any() only handles CoreML (.mlmodel). Adding --weights-format coreml ensures the hot-swap after training produces a file that load_any() can parse. Also fixed _find_best_model to look for best_<score>.mlmodel (produced by --weights-format coreml) in addition to the previous checkpoint_* pattern. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -305,20 +305,25 @@ def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
|
|||||||
|
|
||||||
|
|
||||||
def _find_best_model(checkpoint_dir: str) -> str | None:
|
def _find_best_model(checkpoint_dir: str) -> str | None:
|
||||||
"""Return the checkpoint file with the highest validation metric, or any model file."""
|
"""Return the best final model file produced by ketos train.
|
||||||
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
|
||||||
|
With --weights-format coreml, ketos writes ``best_<score>.mlmodel``.
|
||||||
|
Falls back to any .mlmodel in the directory.
|
||||||
|
"""
|
||||||
|
# Prefer the named best file (e.g. best_0.8256.mlmodel or best_0.8256.safetensors)
|
||||||
|
best_pattern = re.compile(r"best_([0-9.]+)\.(mlmodel|safetensors)$")
|
||||||
best_acc: float | None = None
|
best_acc: float | None = None
|
||||||
best_path: str | None = None
|
best_path: str | None = None
|
||||||
for fname in os.listdir(checkpoint_dir):
|
for fname in os.listdir(checkpoint_dir):
|
||||||
m = pattern.match(fname)
|
m = best_pattern.match(fname)
|
||||||
if m:
|
if m:
|
||||||
acc = float(m.group(2))
|
acc = float(m.group(1))
|
||||||
if best_acc is None or acc > best_acc:
|
if best_acc is None or acc > best_acc:
|
||||||
best_acc = acc
|
best_acc = acc
|
||||||
best_path = os.path.join(checkpoint_dir, fname)
|
best_path = os.path.join(checkpoint_dir, fname)
|
||||||
if best_path:
|
if best_path:
|
||||||
return best_path
|
return best_path
|
||||||
# Fallback: any .mlmodel file in the directory
|
# Fallback: any .mlmodel file
|
||||||
for fname in os.listdir(checkpoint_dir):
|
for fname in os.listdir(checkpoint_dir):
|
||||||
if fname.endswith(".mlmodel"):
|
if fname.endswith(".mlmodel"):
|
||||||
return os.path.join(checkpoint_dir, fname)
|
return os.path.join(checkpoint_dir, fname)
|
||||||
@@ -369,6 +374,7 @@ async def train_model(
|
|||||||
"ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
|
"ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
|
||||||
"train",
|
"train",
|
||||||
"-f", "page",
|
"-f", "page",
|
||||||
|
"--weights-format", "coreml",
|
||||||
"-o", checkpoint_dir,
|
"-o", checkpoint_dir,
|
||||||
"-q", "fixed",
|
"-q", "fixed",
|
||||||
"-N", "10",
|
"-N", "10",
|
||||||
|
|||||||
Reference in New Issue
Block a user