fix(training): use ketos CLI subprocess instead of missing Python API
kraken.ketos has no .train or .segtrain attributes in Kraken 7 — both are only exposed as CLI commands. Rewrites both training functions to invoke `ketos train` / `ketos segtrain` via subprocess and parse the best val_metric from checkpoint filenames. Also fixes the OcrTrainingCard history so it only shows non-blla runs (recognition model), matching SegmentationTrainingCard which already filtered to blla-only. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -88,5 +88,5 @@ async function startTraining() {
|
|||||||
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
|
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||||
{m.training_history_heading()}
|
{m.training_history_heading()}
|
||||||
</h3>
|
</h3>
|
||||||
<TrainingHistory runs={trainingInfo?.runs ?? []} />
|
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import io
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import zipfile
|
import zipfile
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@@ -282,6 +284,47 @@ def _rotate_backups(model_path: str, keep: int = 3) -> None:
|
|||||||
logger.warning("Could not remove old backup: %s", old)
|
logger.warning("Could not remove old backup: %s", old)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
|
||||||
|
"""Parse checkpoint filenames to find the best validation metric and epoch count.
|
||||||
|
|
||||||
|
Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
|
||||||
|
Returns (best_accuracy, epoch_count).
|
||||||
|
"""
|
||||||
|
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
||||||
|
best_acc: float | None = None
|
||||||
|
max_epoch = 0
|
||||||
|
for fname in os.listdir(checkpoint_dir):
|
||||||
|
m = pattern.match(fname)
|
||||||
|
if m:
|
||||||
|
epoch = int(m.group(1))
|
||||||
|
acc = float(m.group(2))
|
||||||
|
max_epoch = max(max_epoch, epoch)
|
||||||
|
if best_acc is None or acc > best_acc:
|
||||||
|
best_acc = acc
|
||||||
|
return best_acc, max_epoch
|
||||||
|
|
||||||
|
|
||||||
|
def _find_best_model(checkpoint_dir: str) -> str | None:
|
||||||
|
"""Return the checkpoint file with the highest validation metric, or any model file."""
|
||||||
|
pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
|
||||||
|
best_acc: float | None = None
|
||||||
|
best_path: str | None = None
|
||||||
|
for fname in os.listdir(checkpoint_dir):
|
||||||
|
m = pattern.match(fname)
|
||||||
|
if m:
|
||||||
|
acc = float(m.group(2))
|
||||||
|
if best_acc is None or acc > best_acc:
|
||||||
|
best_acc = acc
|
||||||
|
best_path = os.path.join(checkpoint_dir, fname)
|
||||||
|
if best_path:
|
||||||
|
return best_path
|
||||||
|
# Fallback: any .mlmodel file in the directory
|
||||||
|
for fname in os.listdir(checkpoint_dir):
|
||||||
|
if fname.endswith(".mlmodel"):
|
||||||
|
return os.path.join(checkpoint_dir, fname)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@app.post("/train")
|
@app.post("/train")
|
||||||
async def train_model(
|
async def train_model(
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
@@ -314,29 +357,41 @@ async def train_model(
|
|||||||
|
|
||||||
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
|
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
|
||||||
|
|
||||||
# Run ketos train (transfer learning from existing model)
|
|
||||||
from kraken import ketos
|
|
||||||
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
|
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
|
||||||
if not ground_truth:
|
if not ground_truth:
|
||||||
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
|
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
|
||||||
|
|
||||||
log.info("Training on %d ground-truth pairs", len(ground_truth))
|
log.info("Training on %d ground-truth pairs", len(ground_truth))
|
||||||
output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")
|
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
result = ketos.train(
|
cmd = [
|
||||||
ground_truth=ground_truth,
|
"ketos", "train",
|
||||||
load=KRAKEN_MODEL_PATH,
|
"-f", "path",
|
||||||
output=output_model_path,
|
"-o", checkpoint_dir,
|
||||||
format_type="path",
|
"-q", "fixed",
|
||||||
)
|
"-N", "50",
|
||||||
|
]
|
||||||
|
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||||
|
cmd += ["-i", KRAKEN_MODEL_PATH]
|
||||||
|
cmd += ground_truth
|
||||||
|
|
||||||
epochs = getattr(result, "epochs", None) or 0
|
log.info("Running: %s", " ".join(cmd[:6]) + " ...")
|
||||||
loss = getattr(result, "best_loss", None)
|
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
accuracy = getattr(result, "best_accuracy", None)
|
log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
|
||||||
# CER = 1 - character accuracy (Kraken reports character-level accuracy)
|
if proc.stderr:
|
||||||
|
log.info("ketos train stderr: %s", proc.stderr[-2000:])
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||||
|
|
||||||
|
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||||
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||||
|
log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
|
||||||
|
|
||||||
log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)
|
# Find the best model file produced by training
|
||||||
|
best_model = _find_best_model(checkpoint_dir)
|
||||||
|
if best_model is None:
|
||||||
|
raise RuntimeError("Training produced no model file")
|
||||||
|
|
||||||
# Backup existing model and replace
|
# Backup existing model and replace
|
||||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||||
@@ -346,14 +401,14 @@ async def train_model(
|
|||||||
log.info("Backed up model to %s", backup_path)
|
log.info("Backed up model to %s", backup_path)
|
||||||
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
|
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
|
||||||
|
|
||||||
shutil.move(output_model_path, KRAKEN_MODEL_PATH)
|
shutil.copy2(best_model, KRAKEN_MODEL_PATH)
|
||||||
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
|
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
|
||||||
|
|
||||||
# Reload model in-process
|
# Reload model in-process
|
||||||
kraken_engine.load_models()
|
kraken_engine.load_models()
|
||||||
log.info("Reloaded Kraken model in-process")
|
log.info("Reloaded Kraken model in-process")
|
||||||
|
|
||||||
return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||||
|
|
||||||
result = await asyncio.to_thread(_run_training)
|
result = await asyncio.to_thread(_run_training)
|
||||||
return result
|
return result
|
||||||
@@ -397,21 +452,33 @@ async def segtrain_model(
|
|||||||
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
|
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
|
||||||
|
|
||||||
log.info("Training on %d PAGE XML files", len(xml_files))
|
log.info("Training on %d PAGE XML files", len(xml_files))
|
||||||
output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")
|
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
from kraken import ketos
|
cmd = [
|
||||||
result = ketos.segtrain(
|
"ketos", "segtrain",
|
||||||
ground_truth=xml_files,
|
"-o", checkpoint_dir,
|
||||||
load=blla_model_path if os.path.exists(blla_model_path) else None,
|
"-q", "fixed",
|
||||||
output=output_model_path,
|
"-N", "50",
|
||||||
format_type="path",
|
]
|
||||||
)
|
if os.path.exists(blla_model_path):
|
||||||
|
cmd += ["-i", blla_model_path]
|
||||||
|
cmd += xml_files
|
||||||
|
|
||||||
epochs = getattr(result, "epochs", None) or 0
|
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
|
||||||
loss = getattr(result, "best_loss", None)
|
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
accuracy = getattr(result, "best_accuracy", None)
|
log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
|
||||||
|
if proc.stderr:
|
||||||
|
log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||||
|
|
||||||
log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
|
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||||
|
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
|
||||||
|
|
||||||
|
best_model = _find_best_model(checkpoint_dir)
|
||||||
|
if best_model is None:
|
||||||
|
raise RuntimeError("Segmentation training produced no model file")
|
||||||
|
|
||||||
if os.path.exists(blla_model_path):
|
if os.path.exists(blla_model_path):
|
||||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||||
@@ -419,10 +486,10 @@ async def segtrain_model(
|
|||||||
shutil.copy2(blla_model_path, backup_path)
|
shutil.copy2(blla_model_path, backup_path)
|
||||||
_rotate_backups(blla_model_path, keep=3)
|
_rotate_backups(blla_model_path, keep=3)
|
||||||
|
|
||||||
shutil.move(output_model_path, blla_model_path)
|
shutil.copy2(best_model, blla_model_path)
|
||||||
log.info("Replaced blla model at %s", blla_model_path)
|
log.info("Replaced blla model at %s", blla_model_path)
|
||||||
|
|
||||||
return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||||
|
|
||||||
result = await asyncio.to_thread(_run_segtrain)
|
result = await asyncio.to_thread(_run_segtrain)
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user