fix(training): use ketos CLI subprocess instead of missing Python API

kraken.ketos has no .train or .segtrain attributes in Kraken 7 — both are
only exposed as CLI commands. Rewrites both training functions to invoke
`ketos train` / `ketos segtrain` via subprocess and parse the best
val_metric from checkpoint filenames.

Also fixes the OcrTrainingCard history so it only shows non-blla runs
(recognition model), matching SegmentationTrainingCard which already
filtered to blla-only.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 20:50:21 +02:00
parent 22954f348a
commit e33164c4aa
2 changed files with 98 additions and 31 deletions

View File

@@ -88,5 +88,5 @@ async function startTraining() {
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
{m.training_history_heading()}
</h3>
<TrainingHistory runs={trainingInfo?.runs ?? []} />
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
</div>

View File

@@ -6,7 +6,9 @@ import io
import json
import logging
import os
import re
import shutil
import subprocess
import tempfile
import zipfile
from contextlib import asynccontextmanager
@@ -282,6 +284,47 @@ def _rotate_backups(model_path: str, keep: int = 3) -> None:
logger.warning("Could not remove old backup: %s", old)
def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
"""Parse checkpoint filenames to find the best validation metric and epoch count.
Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
Returns (best_accuracy, epoch_count).
"""
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
best_acc: float | None = None
max_epoch = 0
for fname in os.listdir(checkpoint_dir):
m = pattern.match(fname)
if m:
epoch = int(m.group(1))
acc = float(m.group(2))
max_epoch = max(max_epoch, epoch)
if best_acc is None or acc > best_acc:
best_acc = acc
return best_acc, max_epoch
def _find_best_model(checkpoint_dir: str) -> str | None:
"""Return the checkpoint file with the highest validation metric, or any model file."""
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
best_acc: float | None = None
best_path: str | None = None
for fname in os.listdir(checkpoint_dir):
m = pattern.match(fname)
if m:
acc = float(m.group(2))
if best_acc is None or acc > best_acc:
best_acc = acc
best_path = os.path.join(checkpoint_dir, fname)
if best_path:
return best_path
# Fallback: any .mlmodel file in the directory
for fname in os.listdir(checkpoint_dir):
if fname.endswith(".mlmodel"):
return os.path.join(checkpoint_dir, fname)
return None
@app.post("/train")
async def train_model(
file: UploadFile,
@@ -314,29 +357,41 @@ async def train_model(
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
# Run ketos train (transfer learning from existing model)
from kraken import ketos
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
if not ground_truth:
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
log.info("Training on %d ground-truth pairs", len(ground_truth))
output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
result = ketos.train(
ground_truth=ground_truth,
load=KRAKEN_MODEL_PATH,
output=output_model_path,
format_type="path",
)
cmd = [
"ketos", "train",
"-f", "path",
"-o", checkpoint_dir,
"-q", "fixed",
"-N", "50",
]
if os.path.exists(KRAKEN_MODEL_PATH):
cmd += ["-i", KRAKEN_MODEL_PATH]
cmd += ground_truth
epochs = getattr(result, "epochs", None) or 0
loss = getattr(result, "best_loss", None)
accuracy = getattr(result, "best_accuracy", None)
# CER = 1 - character accuracy (Kraken reports character-level accuracy)
log.info("Running: %s", " ".join(cmd[:6]) + " ...")
proc = subprocess.run(cmd, capture_output=True, text=True)
log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
if proc.stderr:
log.info("ketos train stderr: %s", proc.stderr[-2000:])
if proc.returncode != 0:
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)
# Find the best model file produced by training
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
raise RuntimeError("Training produced no model file")
# Backup existing model and replace
if os.path.exists(KRAKEN_MODEL_PATH):
@@ -346,14 +401,14 @@ async def train_model(
log.info("Backed up model to %s", backup_path)
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
shutil.move(output_model_path, KRAKEN_MODEL_PATH)
shutil.copy2(best_model, KRAKEN_MODEL_PATH)
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
# Reload model in-process
kraken_engine.load_models()
log.info("Reloaded Kraken model in-process")
return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_training)
return result
@@ -397,21 +452,33 @@ async def segtrain_model(
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
log.info("Training on %d PAGE XML files", len(xml_files))
output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
from kraken import ketos
result = ketos.segtrain(
ground_truth=xml_files,
load=blla_model_path if os.path.exists(blla_model_path) else None,
output=output_model_path,
format_type="path",
)
cmd = [
"ketos", "segtrain",
"-o", checkpoint_dir,
"-q", "fixed",
"-N", "50",
]
if os.path.exists(blla_model_path):
cmd += ["-i", blla_model_path]
cmd += xml_files
epochs = getattr(result, "epochs", None) or 0
loss = getattr(result, "best_loss", None)
accuracy = getattr(result, "best_accuracy", None)
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
proc = subprocess.run(cmd, capture_output=True, text=True)
log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
if proc.stderr:
log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
if proc.returncode != 0:
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
raise RuntimeError("Segmentation training produced no model file")
if os.path.exists(blla_model_path):
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
@@ -419,10 +486,10 @@ async def segtrain_model(
shutil.copy2(blla_model_path, backup_path)
_rotate_backups(blla_model_path, keep=3)
shutil.move(output_model_path, blla_model_path)
shutil.copy2(best_model, blla_model_path)
log.info("Replaced blla model at %s", blla_model_path)
return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
result = await asyncio.to_thread(_run_segtrain)
return result