fix(training): use ketos CLI subprocess instead of missing Python API
kraken.ketos has no .train or .segtrain attributes in Kraken 7 — both are only exposed as CLI commands. Rewrites both training functions to invoke `ketos train` / `ketos segtrain` via subprocess and parse the best val_metric from checkpoint filenames. Also fixes the OcrTrainingCard history so it only shows non-blla runs (recognition model), matching SegmentationTrainingCard which already filtered to blla-only. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -88,5 +88,5 @@ async function startTraining() {
|
||||
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||
{m.training_history_heading()}
|
||||
</h3>
|
||||
<TrainingHistory runs={trainingInfo?.runs ?? []} />
|
||||
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
|
||||
</div>
|
||||
|
||||
@@ -6,7 +6,9 @@ import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import zipfile
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -282,6 +284,47 @@ def _rotate_backups(model_path: str, keep: int = 3) -> None:
|
||||
logger.warning("Could not remove old backup: %s", old)
|
||||
|
||||
|
||||
def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
|
||||
"""Parse checkpoint filenames to find the best validation metric and epoch count.
|
||||
|
||||
Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
|
||||
Returns (best_accuracy, epoch_count).
|
||||
"""
|
||||
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
||||
best_acc: float | None = None
|
||||
max_epoch = 0
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
m = pattern.match(fname)
|
||||
if m:
|
||||
epoch = int(m.group(1))
|
||||
acc = float(m.group(2))
|
||||
max_epoch = max(max_epoch, epoch)
|
||||
if best_acc is None or acc > best_acc:
|
||||
best_acc = acc
|
||||
return best_acc, max_epoch
|
||||
|
||||
|
||||
def _find_best_model(checkpoint_dir: str) -> str | None:
|
||||
"""Return the checkpoint file with the highest validation metric, or any model file."""
|
||||
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
||||
best_acc: float | None = None
|
||||
best_path: str | None = None
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
m = pattern.match(fname)
|
||||
if m:
|
||||
acc = float(m.group(2))
|
||||
if best_acc is None or acc > best_acc:
|
||||
best_acc = acc
|
||||
best_path = os.path.join(checkpoint_dir, fname)
|
||||
if best_path:
|
||||
return best_path
|
||||
# Fallback: any .mlmodel file in the directory
|
||||
for fname in os.listdir(checkpoint_dir):
|
||||
if fname.endswith(".mlmodel"):
|
||||
return os.path.join(checkpoint_dir, fname)
|
||||
return None
|
||||
|
||||
|
||||
@app.post("/train")
|
||||
async def train_model(
|
||||
file: UploadFile,
|
||||
@@ -314,29 +357,41 @@ async def train_model(
|
||||
|
||||
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
|
||||
|
||||
# Run ketos train (transfer learning from existing model)
|
||||
from kraken import ketos
|
||||
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
|
||||
if not ground_truth:
|
||||
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
|
||||
|
||||
log.info("Training on %d ground-truth pairs", len(ground_truth))
|
||||
output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")
|
||||
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
result = ketos.train(
|
||||
ground_truth=ground_truth,
|
||||
load=KRAKEN_MODEL_PATH,
|
||||
output=output_model_path,
|
||||
format_type="path",
|
||||
)
|
||||
cmd = [
|
||||
"ketos", "train",
|
||||
"-f", "path",
|
||||
"-o", checkpoint_dir,
|
||||
"-q", "fixed",
|
||||
"-N", "50",
|
||||
]
|
||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||
cmd += ["-i", KRAKEN_MODEL_PATH]
|
||||
cmd += ground_truth
|
||||
|
||||
epochs = getattr(result, "epochs", None) or 0
|
||||
loss = getattr(result, "best_loss", None)
|
||||
accuracy = getattr(result, "best_accuracy", None)
|
||||
# CER = 1 - character accuracy (Kraken reports character-level accuracy)
|
||||
log.info("Running: %s", " ".join(cmd[:6]) + " ...")
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
|
||||
if proc.stderr:
|
||||
log.info("ketos train stderr: %s", proc.stderr[-2000:])
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||
|
||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||
log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
|
||||
|
||||
log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)
|
||||
# Find the best model file produced by training
|
||||
best_model = _find_best_model(checkpoint_dir)
|
||||
if best_model is None:
|
||||
raise RuntimeError("Training produced no model file")
|
||||
|
||||
# Backup existing model and replace
|
||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||
@@ -346,14 +401,14 @@ async def train_model(
|
||||
log.info("Backed up model to %s", backup_path)
|
||||
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
|
||||
|
||||
shutil.move(output_model_path, KRAKEN_MODEL_PATH)
|
||||
shutil.copy2(best_model, KRAKEN_MODEL_PATH)
|
||||
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
|
||||
|
||||
# Reload model in-process
|
||||
kraken_engine.load_models()
|
||||
log.info("Reloaded Kraken model in-process")
|
||||
|
||||
return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_training)
|
||||
return result
|
||||
@@ -397,21 +452,33 @@ async def segtrain_model(
|
||||
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
|
||||
|
||||
log.info("Training on %d PAGE XML files", len(xml_files))
|
||||
output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")
|
||||
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
from kraken import ketos
|
||||
result = ketos.segtrain(
|
||||
ground_truth=xml_files,
|
||||
load=blla_model_path if os.path.exists(blla_model_path) else None,
|
||||
output=output_model_path,
|
||||
format_type="path",
|
||||
)
|
||||
cmd = [
|
||||
"ketos", "segtrain",
|
||||
"-o", checkpoint_dir,
|
||||
"-q", "fixed",
|
||||
"-N", "50",
|
||||
]
|
||||
if os.path.exists(blla_model_path):
|
||||
cmd += ["-i", blla_model_path]
|
||||
cmd += xml_files
|
||||
|
||||
epochs = getattr(result, "epochs", None) or 0
|
||||
loss = getattr(result, "best_loss", None)
|
||||
accuracy = getattr(result, "best_accuracy", None)
|
||||
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||
log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
|
||||
if proc.stderr:
|
||||
log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||
|
||||
log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
|
||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
|
||||
|
||||
best_model = _find_best_model(checkpoint_dir)
|
||||
if best_model is None:
|
||||
raise RuntimeError("Segmentation training produced no model file")
|
||||
|
||||
if os.path.exists(blla_model_path):
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
@@ -419,10 +486,10 @@ async def segtrain_model(
|
||||
shutil.copy2(blla_model_path, backup_path)
|
||||
_rotate_backups(blla_model_path, keep=3)
|
||||
|
||||
shutil.move(output_model_path, blla_model_path)
|
||||
shutil.copy2(best_model, blla_model_path)
|
||||
log.info("Replaced blla model at %s", blla_model_path)
|
||||
|
||||
return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user