fix(ocr-service): fix ketos 7 segtrain compatibility and prevent OOM
Three issues fixed: 1. --resize both was removed in ketos 7; replaced with --resize union which extends the model's class mapping to include training data classes. 2. ketos ignores -s when -i is present, so the 1800px blla model caused 7+ GB peak RAM and OOM-killed the host (no swap, 5 GB free). Now checks the loaded model's input height: only uses the base model when it was already fine-tuned at 800px; otherwise trains from scratch at 800px (~200 MB peak). After the first run the trained 800px model becomes the base for all subsequent fine-tuning runs. 3. segtrain now computes and returns cer = 1 - accuracy, matching the recognition training path. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -472,16 +472,35 @@ async def segtrain_model(
|
|||||||
"-q", "fixed",
|
"-q", "fixed",
|
||||||
"-N", "10",
|
"-N", "10",
|
||||||
]
|
]
|
||||||
|
# Train at 800px height. The default blla model uses 1800px, which peaks at
|
||||||
|
# ~7+ GB on CPU and kills the host (ketos ignores -s when -i is present, so
|
||||||
|
# we cannot override the height of an existing model).
|
||||||
|
# Strategy: only use the base model if it is already at 800px (i.e. was
|
||||||
|
# produced by a previous fine-tuning run here). Otherwise train from scratch —
|
||||||
|
# the first run bootstraps a 800px model; all subsequent runs fine-tune it.
|
||||||
|
seg_spec = (
|
||||||
|
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
|
||||||
|
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]"
|
||||||
|
)
|
||||||
|
use_base_model = False
|
||||||
if os.path.exists(blla_model_path):
|
if os.path.exists(blla_model_path):
|
||||||
cmd += ["-i", blla_model_path, "--resize", "both"]
|
try:
|
||||||
|
from kraken.lib import vgsl as _vgsl
|
||||||
|
_m = _vgsl.TorchVGSLModel.load_model(blla_model_path)
|
||||||
|
use_base_model = _m.input[2] == 800 # input is (batch, channels, H, W)
|
||||||
|
if not use_base_model:
|
||||||
|
log.info(
|
||||||
|
"Base model height is %dpx — skipping -i to avoid OOM; "
|
||||||
|
"will train from scratch at 800px",
|
||||||
|
_m.input[2],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if use_base_model:
|
||||||
|
cmd += ["-i", blla_model_path, "--resize", "union", "-s", seg_spec]
|
||||||
else:
|
else:
|
||||||
# No pretrained model — train from scratch with reduced height (800px)
|
cmd += ["-s", seg_spec]
|
||||||
# to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+)
|
|
||||||
cmd += [
|
|
||||||
"-s",
|
|
||||||
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
|
|
||||||
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]",
|
|
||||||
]
|
|
||||||
cmd += xml_files
|
cmd += xml_files
|
||||||
|
|
||||||
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
|
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
|
||||||
@@ -493,7 +512,8 @@ async def segtrain_model(
|
|||||||
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||||
|
|
||||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||||
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
|
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||||
|
log.info("Segmentation training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
|
||||||
|
|
||||||
best_model = _find_best_model(checkpoint_dir)
|
best_model = _find_best_model(checkpoint_dir)
|
||||||
if best_model is None:
|
if best_model is None:
|
||||||
@@ -508,7 +528,7 @@ async def segtrain_model(
|
|||||||
shutil.copy2(best_model, blla_model_path)
|
shutil.copy2(best_model, blla_model_path)
|
||||||
log.info("Replaced blla model at %s", blla_model_path)
|
log.info("Replaced blla model at %s", blla_model_path)
|
||||||
|
|
||||||
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||||
|
|
||||||
result = await asyncio.to_thread(_run_segtrain)
|
result = await asyncio.to_thread(_run_segtrain)
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user