fix(ocr-service): fix ketos 7 segtrain compatibility and prevent OOM

Three issues fixed:

1. --resize both was removed in ketos 7; replaced with --resize union
   which extends the model's class mapping to include training data classes.

2. ketos ignores -s when -i is present, so the 1800px blla model caused
   7+ GB peak RAM and OOM-killed the host (no swap, 5 GB free).
   Now checks the loaded model's input height: only uses the base model
   when it was already fine-tuned at 800px; otherwise trains from scratch
   at 800px (~200 MB peak). After the first run the trained 800px model
   becomes the base for all subsequent fine-tuning runs.

3. segtrain now computes and returns cer = 1 - accuracy, matching the
   recognition training path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-14 13:06:44 +02:00
committed by marcel
parent e8375d6c72
commit a5979c4069

View File

@@ -472,16 +472,35 @@ async def segtrain_model(
"-q", "fixed",
"-N", "10",
]
# Train at 800px height. The default blla model uses 1800px, which peaks at
# ~7+ GB on CPU and kills the host (ketos ignores -s when -i is present, so
# we cannot override the height of an existing model).
# Strategy: only use the base model if it is already at 800px (i.e. was
# produced by a previous fine-tuning run here). Otherwise train from scratch —
# the first run bootstraps a 800px model; all subsequent runs fine-tune it.
seg_spec = (
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]"
)
use_base_model = False
if os.path.exists(blla_model_path):
cmd += ["-i", blla_model_path, "--resize", "both"]
try:
from kraken.lib import vgsl as _vgsl
_m = _vgsl.TorchVGSLModel.load_model(blla_model_path)
use_base_model = _m.input[2] == 800 # input is (batch, channels, H, W)
if not use_base_model:
log.info(
"Base model height is %dpx — skipping -i to avoid OOM; "
"will train from scratch at 800px",
_m.input[2],
)
except Exception:
pass
if use_base_model:
cmd += ["-i", blla_model_path, "--resize", "union", "-s", seg_spec]
else:
# No pretrained model — train from scratch with reduced height (800px)
# to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+)
cmd += [
"-s",
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]",
]
cmd += ["-s", seg_spec]
cmd += xml_files
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
@@ -493,7 +512,8 @@ async def segtrain_model(
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
log.info("Segmentation training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
@@ -508,7 +528,7 @@ async def segtrain_model(
shutil.copy2(best_model, blla_model_path)
log.info("Replaced blla model at %s", blla_model_path)
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_segtrain)
return result