From e33164c4aa4f2d8239d8602100769d4a0bab8751 Mon Sep 17 00:00:00 2001 From: Marcel Date: Mon, 13 Apr 2026 20:50:21 +0200 Subject: [PATCH] fix(training): use ketos CLI subprocess instead of missing Python API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit kraken.ketos has no .train or .segtrain attributes in Kraken 7 — both are only exposed as CLI commands. Rewrites both training functions to invoke `ketos train` / `ketos segtrain` via subprocess and parse the best val_metric from checkpoint filenames. Also fixes the OcrTrainingCard history so it only shows non-blla runs (recognition model), matching SegmentationTrainingCard which already filtered to blla-only. Co-Authored-By: Claude Sonnet 4.6 --- .../src/lib/components/OcrTrainingCard.svelte | 2 +- ocr-service/main.py | 127 +++++++++++++----- 2 files changed, 98 insertions(+), 31 deletions(-) diff --git a/frontend/src/lib/components/OcrTrainingCard.svelte b/frontend/src/lib/components/OcrTrainingCard.svelte index 64f96f2f..e34793af 100644 --- a/frontend/src/lib/components/OcrTrainingCard.svelte +++ b/frontend/src/lib/components/OcrTrainingCard.svelte @@ -88,5 +88,5 @@ async function startTraining() {

{m.training_history_heading()}

- + r.modelName !== 'blla')} /> diff --git a/ocr-service/main.py b/ocr-service/main.py index 0484e2b1..955f7fbd 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -6,7 +6,9 @@ import io import json import logging import os +import re import shutil +import subprocess import tempfile import zipfile from contextlib import asynccontextmanager @@ -282,6 +284,47 @@ def _rotate_backups(model_path: str, keep: int = 3) -> None: logger.warning("Could not remove old backup: %s", old) +def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]: + """Parse checkpoint filenames to find the best validation metric and epoch count. + + Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``. + Returns (best_accuracy, epoch_count). + """ + pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$") + best_acc: float | None = None + max_epoch = 0 + for fname in os.listdir(checkpoint_dir): + m = pattern.match(fname) + if m: + epoch = int(m.group(1)) + acc = float(m.group(2)) + max_epoch = max(max_epoch, epoch) + if best_acc is None or acc > best_acc: + best_acc = acc + return best_acc, max_epoch + + +def _find_best_model(checkpoint_dir: str) -> str | None: + """Return the checkpoint file with the highest validation metric, or any model file.""" + pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$") + best_acc: float | None = None + best_path: str | None = None + for fname in os.listdir(checkpoint_dir): + m = pattern.match(fname) + if m: + acc = float(m.group(2)) + if best_acc is None or acc > best_acc: + best_acc = acc + best_path = os.path.join(checkpoint_dir, fname) + if best_path: + return best_path + # Fallback: any .mlmodel file in the directory + for fname in os.listdir(checkpoint_dir): + if fname.endswith(".mlmodel"): + return os.path.join(checkpoint_dir, fname) + return None + + @app.post("/train") async def train_model( file: UploadFile, @@ -314,29 +357,41 @@ async def train_model( log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir) - # Run ketos train (transfer learning from existing model) - from kraken import ketos ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt")) if not ground_truth: raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP") log.info("Training on %d ground-truth pairs", len(ground_truth)) - output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel") + checkpoint_dir = os.path.join(tmp_dir, "checkpoints") + os.makedirs(checkpoint_dir, exist_ok=True) - result = ketos.train( - ground_truth=ground_truth, - load=KRAKEN_MODEL_PATH, - output=output_model_path, - format_type="path", - ) + cmd = [ + "ketos", "train", + "-f", "path", + "-o", checkpoint_dir, + "-q", "fixed", + "-N", "50", + ] + if os.path.exists(KRAKEN_MODEL_PATH): + cmd += ["-i", KRAKEN_MODEL_PATH] + cmd += ground_truth - epochs = getattr(result, "epochs", None) or 0 - loss = getattr(result, "best_loss", None) - accuracy = getattr(result, "best_accuracy", None) - # CER = 1 - character accuracy (Kraken reports character-level accuracy) + log.info("Running: %s", " ".join(cmd[:6]) + " ...") + proc = subprocess.run(cmd, capture_output=True, text=True) + log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "") + if proc.stderr: + log.info("ketos train stderr: %s", proc.stderr[-2000:]) + if proc.returncode != 0: + raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}") + + accuracy, epochs = _parse_best_checkpoint(checkpoint_dir) cer = round(1.0 - accuracy, 4) if accuracy is not None else None + log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer) - log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer) + # Find the best model file produced by training + best_model = _find_best_model(checkpoint_dir) + if best_model is None: + raise RuntimeError("Training produced no model file") # Backup existing model and replace if os.path.exists(KRAKEN_MODEL_PATH): @@ -346,14 +401,14 @@ async def train_model( log.info("Backed up model to %s", backup_path) _rotate_backups(KRAKEN_MODEL_PATH, keep=3) - shutil.move(output_model_path, KRAKEN_MODEL_PATH) + shutil.copy2(best_model, KRAKEN_MODEL_PATH) log.info("Replaced model at %s", KRAKEN_MODEL_PATH) # Reload model in-process kraken_engine.load_models() log.info("Reloaded Kraken model in-process") - return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs} + return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} result = await asyncio.to_thread(_run_training) return result @@ -397,21 +452,33 @@ async def segtrain_model( raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP") log.info("Training on %d PAGE XML files", len(xml_files)) - output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel") + checkpoint_dir = os.path.join(tmp_dir, "checkpoints") + os.makedirs(checkpoint_dir, exist_ok=True) - from kraken import ketos - result = ketos.segtrain( - ground_truth=xml_files, - load=blla_model_path if os.path.exists(blla_model_path) else None, - output=output_model_path, - format_type="path", - ) + cmd = [ + "ketos", "segtrain", + "-o", checkpoint_dir, + "-q", "fixed", + "-N", "50", + ] + if os.path.exists(blla_model_path): + cmd += ["-i", blla_model_path] + cmd += xml_files - epochs = getattr(result, "epochs", None) or 0 - loss = getattr(result, "best_loss", None) - accuracy = getattr(result, "best_accuracy", None) + log.info("Running: %s", " ".join(cmd[:5]) + " ...") + proc = subprocess.run(cmd, capture_output=True, text=True) + log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "") + if proc.stderr: + log.info("ketos segtrain stderr: %s", proc.stderr[-2000:]) + if proc.returncode != 0: + raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}") - log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy) + accuracy, epochs = _parse_best_checkpoint(checkpoint_dir) + log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy) + + best_model = _find_best_model(checkpoint_dir) + if best_model is None: + raise RuntimeError("Segmentation training produced no model file") if os.path.exists(blla_model_path): timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") @@ -419,10 +486,10 @@ async def segtrain_model( shutil.copy2(blla_model_path, backup_path) _rotate_backups(blla_model_path, keep=3) - shutil.move(output_model_path, blla_model_path) + shutil.copy2(best_model, blla_model_path) log.info("Replaced blla model at %s", blla_model_path) - return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs} + return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs} result = await asyncio.to_thread(_run_segtrain) return result