fix(training): use ketos CLI subprocess instead of missing Python API

kraken.ketos has no .train or .segtrain attributes in Kraken 7 — both are
only exposed as CLI commands. Rewrites both training functions to invoke
`ketos train` / `ketos segtrain` via subprocess and parse the best
val_metric from checkpoint filenames.

Also fixes the OcrTrainingCard history so it only shows non-blla runs
(recognition model), matching SegmentationTrainingCard which already
filtered to blla-only.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 20:50:21 +02:00
parent 22954f348a
commit e33164c4aa
2 changed files with 98 additions and 31 deletions

View File

@@ -88,5 +88,5 @@ async function startTraining() {
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase"> <h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
{m.training_history_heading()} {m.training_history_heading()}
</h3> </h3>
<TrainingHistory runs={trainingInfo?.runs ?? []} /> <TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
</div> </div>

View File

@@ -6,7 +6,9 @@ import io
import json import json
import logging import logging
import os import os
import re
import shutil import shutil
import subprocess
import tempfile import tempfile
import zipfile import zipfile
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -282,6 +284,47 @@ def _rotate_backups(model_path: str, keep: int = 3) -> None:
logger.warning("Could not remove old backup: %s", old) logger.warning("Could not remove old backup: %s", old)
def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
"""Parse checkpoint filenames to find the best validation metric and epoch count.
Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
Returns (best_accuracy, epoch_count).
"""
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
best_acc: float | None = None
max_epoch = 0
for fname in os.listdir(checkpoint_dir):
m = pattern.match(fname)
if m:
epoch = int(m.group(1))
acc = float(m.group(2))
max_epoch = max(max_epoch, epoch)
if best_acc is None or acc > best_acc:
best_acc = acc
return best_acc, max_epoch
def _find_best_model(checkpoint_dir: str) -> str | None:
"""Return the checkpoint file with the highest validation metric, or any model file."""
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
best_acc: float | None = None
best_path: str | None = None
for fname in os.listdir(checkpoint_dir):
m = pattern.match(fname)
if m:
acc = float(m.group(2))
if best_acc is None or acc > best_acc:
best_acc = acc
best_path = os.path.join(checkpoint_dir, fname)
if best_path:
return best_path
# Fallback: any .mlmodel file in the directory
for fname in os.listdir(checkpoint_dir):
if fname.endswith(".mlmodel"):
return os.path.join(checkpoint_dir, fname)
return None
@app.post("/train") @app.post("/train")
async def train_model( async def train_model(
file: UploadFile, file: UploadFile,
@@ -314,29 +357,41 @@ async def train_model(
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir) log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
# Run ketos train (transfer learning from existing model)
from kraken import ketos
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt")) ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
if not ground_truth: if not ground_truth:
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP") raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
log.info("Training on %d ground-truth pairs", len(ground_truth)) log.info("Training on %d ground-truth pairs", len(ground_truth))
output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel") checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
result = ketos.train( cmd = [
ground_truth=ground_truth, "ketos", "train",
load=KRAKEN_MODEL_PATH, "-f", "path",
output=output_model_path, "-o", checkpoint_dir,
format_type="path", "-q", "fixed",
) "-N", "50",
]
if os.path.exists(KRAKEN_MODEL_PATH):
cmd += ["-i", KRAKEN_MODEL_PATH]
cmd += ground_truth
epochs = getattr(result, "epochs", None) or 0 log.info("Running: %s", " ".join(cmd[:6]) + " ...")
loss = getattr(result, "best_loss", None) proc = subprocess.run(cmd, capture_output=True, text=True)
accuracy = getattr(result, "best_accuracy", None) log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
# CER = 1 - character accuracy (Kraken reports character-level accuracy) if proc.stderr:
log.info("ketos train stderr: %s", proc.stderr[-2000:])
if proc.returncode != 0:
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None cer = round(1.0 - accuracy, 4) if accuracy is not None else None
log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer) # Find the best model file produced by training
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
raise RuntimeError("Training produced no model file")
# Backup existing model and replace # Backup existing model and replace
if os.path.exists(KRAKEN_MODEL_PATH): if os.path.exists(KRAKEN_MODEL_PATH):
@@ -346,14 +401,14 @@ async def train_model(
log.info("Backed up model to %s", backup_path) log.info("Backed up model to %s", backup_path)
_rotate_backups(KRAKEN_MODEL_PATH, keep=3) _rotate_backups(KRAKEN_MODEL_PATH, keep=3)
shutil.move(output_model_path, KRAKEN_MODEL_PATH) shutil.copy2(best_model, KRAKEN_MODEL_PATH)
log.info("Replaced model at %s", KRAKEN_MODEL_PATH) log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
# Reload model in-process # Reload model in-process
kraken_engine.load_models() kraken_engine.load_models()
log.info("Reloaded Kraken model in-process") log.info("Reloaded Kraken model in-process")
return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs} return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_training) result = await asyncio.to_thread(_run_training)
return result return result
@@ -397,21 +452,33 @@ async def segtrain_model(
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP") raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
log.info("Training on %d PAGE XML files", len(xml_files)) log.info("Training on %d PAGE XML files", len(xml_files))
output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel") checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
from kraken import ketos cmd = [
result = ketos.segtrain( "ketos", "segtrain",
ground_truth=xml_files, "-o", checkpoint_dir,
load=blla_model_path if os.path.exists(blla_model_path) else None, "-q", "fixed",
output=output_model_path, "-N", "50",
format_type="path", ]
) if os.path.exists(blla_model_path):
cmd += ["-i", blla_model_path]
cmd += xml_files
epochs = getattr(result, "epochs", None) or 0 log.info("Running: %s", " ".join(cmd[:5]) + " ...")
loss = getattr(result, "best_loss", None) proc = subprocess.run(cmd, capture_output=True, text=True)
accuracy = getattr(result, "best_accuracy", None) log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
if proc.stderr:
log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
if proc.returncode != 0:
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy) accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
raise RuntimeError("Segmentation training produced no model file")
if os.path.exists(blla_model_path): if os.path.exists(blla_model_path):
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
@@ -419,10 +486,10 @@ async def segtrain_model(
shutil.copy2(blla_model_path, backup_path) shutil.copy2(blla_model_path, backup_path)
_rotate_backups(blla_model_path, keep=3) _rotate_backups(blla_model_path, keep=3)
shutil.move(output_model_path, blla_model_path) shutil.copy2(best_model, blla_model_path)
log.info("Replaced blla model at %s", blla_model_path) log.info("Replaced blla model at %s", blla_model_path)
return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs} return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
result = await asyncio.to_thread(_run_segtrain) result = await asyncio.to_thread(_run_segtrain)
return result return result