diff --git a/frontend/src/lib/components/OcrTrainingCard.svelte b/frontend/src/lib/components/OcrTrainingCard.svelte
index 64f96f2f..e34793af 100644
--- a/frontend/src/lib/components/OcrTrainingCard.svelte
+++ b/frontend/src/lib/components/OcrTrainingCard.svelte
@@ -88,5 +88,5 @@ async function startTraining() {
{m.training_history_heading()}
-
+ r.modelName !== 'blla')} />
diff --git a/ocr-service/main.py b/ocr-service/main.py
index 0484e2b1..955f7fbd 100644
--- a/ocr-service/main.py
+++ b/ocr-service/main.py
@@ -6,7 +6,9 @@ import io
import json
import logging
import os
+import re
import shutil
+import subprocess
import tempfile
import zipfile
from contextlib import asynccontextmanager
@@ -282,6 +284,47 @@ def _rotate_backups(model_path: str, keep: int = 3) -> None:
logger.warning("Could not remove old backup: %s", old)
+def _parse_best_checkpoint(checkpoint_dir: str) -> tuple[float | None, int]:
+ """Parse checkpoint filenames to find the best validation metric and epoch count.
+
+ Kraken saves checkpoints as e.g. ``checkpoint_03-0.9500.ckpt``.
+ Returns (best_accuracy, epoch_count).
+ """
+ pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
+ best_acc: float | None = None
+ max_epoch = 0
+ for fname in os.listdir(checkpoint_dir):
+ m = pattern.match(fname)
+ if m:
+ epoch = int(m.group(1))
+ acc = float(m.group(2))
+ max_epoch = max(max_epoch, epoch)
+ if best_acc is None or acc > best_acc:
+ best_acc = acc
+ return best_acc, max_epoch
+
+
+def _find_best_model(checkpoint_dir: str) -> str | None:
+ """Return the checkpoint file with the highest validation metric, or any model file."""
+ pattern = re.compile(r"checkpoint_(\d+)-([0-9.]+)\.(ckpt|mlmodel)$")
+ best_acc: float | None = None
+ best_path: str | None = None
+ for fname in os.listdir(checkpoint_dir):
+ m = pattern.match(fname)
+ if m:
+ acc = float(m.group(2))
+ if best_acc is None or acc > best_acc:
+ best_acc = acc
+ best_path = os.path.join(checkpoint_dir, fname)
+ if best_path:
+ return best_path
+ # Fallback: any .mlmodel file in the directory
+ for fname in os.listdir(checkpoint_dir):
+ if fname.endswith(".mlmodel"):
+ return os.path.join(checkpoint_dir, fname)
+ return None
+
+
@app.post("/train")
async def train_model(
file: UploadFile,
@@ -314,29 +357,41 @@ async def train_model(
log.info("Extracted %d ZIP entries to %s", len(os.listdir(tmp_dir)), tmp_dir)
- # Run ketos train (transfer learning from existing model)
- from kraken import ketos
ground_truth = glob.glob(os.path.join(tmp_dir, "*.gt.txt"))
if not ground_truth:
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
log.info("Training on %d ground-truth pairs", len(ground_truth))
- output_model_path = os.path.join(tmp_dir, "fine_tuned.mlmodel")
+ checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
+ os.makedirs(checkpoint_dir, exist_ok=True)
- result = ketos.train(
- ground_truth=ground_truth,
- load=KRAKEN_MODEL_PATH,
- output=output_model_path,
- format_type="path",
- )
+ cmd = [
+ "ketos", "train",
+ "-f", "path",
+ "-o", checkpoint_dir,
+ "-q", "fixed",
+ "-N", "50",
+ ]
+ if os.path.exists(KRAKEN_MODEL_PATH):
+ cmd += ["-i", KRAKEN_MODEL_PATH]
+ cmd += ground_truth
- epochs = getattr(result, "epochs", None) or 0
- loss = getattr(result, "best_loss", None)
- accuracy = getattr(result, "best_accuracy", None)
- # CER = 1 - character accuracy (Kraken reports character-level accuracy)
+ log.info("Running: %s", " ".join(cmd[:6]) + " ...")
+ proc = subprocess.run(cmd, capture_output=True, text=True)
+ log.info("ketos train stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
+ if proc.stderr:
+ log.info("ketos train stderr: %s", proc.stderr[-2000:])
+ if proc.returncode != 0:
+ raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
+
+ accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
+ log.info("Training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
- log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)
+ # Find the best model file produced by training
+ best_model = _find_best_model(checkpoint_dir)
+ if best_model is None:
+ raise RuntimeError("Training produced no model file")
# Backup existing model and replace
if os.path.exists(KRAKEN_MODEL_PATH):
@@ -346,14 +401,14 @@ async def train_model(
log.info("Backed up model to %s", backup_path)
_rotate_backups(KRAKEN_MODEL_PATH, keep=3)
- shutil.move(output_model_path, KRAKEN_MODEL_PATH)
+ shutil.copy2(best_model, KRAKEN_MODEL_PATH)
log.info("Replaced model at %s", KRAKEN_MODEL_PATH)
# Reload model in-process
kraken_engine.load_models()
log.info("Reloaded Kraken model in-process")
- return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}
+ return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_training)
return result
@@ -397,21 +452,33 @@ async def segtrain_model(
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
log.info("Training on %d PAGE XML files", len(xml_files))
- output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")
+ checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
+ os.makedirs(checkpoint_dir, exist_ok=True)
- from kraken import ketos
- result = ketos.segtrain(
- ground_truth=xml_files,
- load=blla_model_path if os.path.exists(blla_model_path) else None,
- output=output_model_path,
- format_type="path",
- )
+ cmd = [
+ "ketos", "segtrain",
+ "-o", checkpoint_dir,
+ "-q", "fixed",
+ "-N", "50",
+ ]
+ if os.path.exists(blla_model_path):
+ cmd += ["-i", blla_model_path]
+ cmd += xml_files
- epochs = getattr(result, "epochs", None) or 0
- loss = getattr(result, "best_loss", None)
- accuracy = getattr(result, "best_accuracy", None)
+ log.info("Running: %s", " ".join(cmd[:5]) + " ...")
+ proc = subprocess.run(cmd, capture_output=True, text=True)
+ log.info("ketos segtrain stdout: %s", proc.stdout[-2000:] if proc.stdout else "")
+ if proc.stderr:
+ log.info("ketos segtrain stderr: %s", proc.stderr[-2000:])
+ if proc.returncode != 0:
+ raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
- log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
+ accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
+ log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
+
+ best_model = _find_best_model(checkpoint_dir)
+ if best_model is None:
+ raise RuntimeError("Segmentation training produced no model file")
if os.path.exists(blla_model_path):
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
@@ -419,10 +486,10 @@ async def segtrain_model(
shutil.copy2(blla_model_path, backup_path)
_rotate_backups(blla_model_path, keep=3)
- shutil.move(output_model_path, blla_model_path)
+ shutil.copy2(best_model, blla_model_path)
log.info("Replaced blla model at %s", blla_model_path)
- return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}
+ return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
result = await asyncio.to_thread(_run_segtrain)
return result