fix(ocr-service): fix ketos 7 segtrain compatibility and prevent OOM
Three issues fixed: 1. --resize both was removed in ketos 7; replaced with --resize union which extends the model's class mapping to include training data classes. 2. ketos ignores -s when -i is present, so the 1800px blla model caused 7+ GB peak RAM and OOM-killed the host (no swap, 5 GB free). Now checks the loaded model's input height: only uses the base model when it was already fine-tuned at 800px; otherwise trains from scratch at 800px (~200 MB peak). After the first run the trained 800px model becomes the base for all subsequent fine-tuning runs. 3. segtrain now computes and returns cer = 1 - accuracy, matching the recognition training path. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -472,16 +472,35 @@ async def segtrain_model(
|
||||
"-q", "fixed",
|
||||
"-N", "10",
|
||||
]
|
||||
# Train at 800px height. The default blla model uses 1800px, which peaks at
|
||||
# ~7+ GB on CPU and kills the host (ketos ignores -s when -i is present, so
|
||||
# we cannot override the height of an existing model).
|
||||
# Strategy: only use the base model if it is already at 800px (i.e. was
|
||||
# produced by a previous fine-tuning run here). Otherwise train from scratch —
|
||||
# the first run bootstraps a 800px model; all subsequent runs fine-tune it.
|
||||
seg_spec = (
|
||||
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
|
||||
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]"
|
||||
)
|
||||
use_base_model = False
|
||||
if os.path.exists(blla_model_path):
|
||||
cmd += ["-i", blla_model_path, "--resize", "both"]
|
||||
try:
|
||||
from kraken.lib import vgsl as _vgsl
|
||||
_m = _vgsl.TorchVGSLModel.load_model(blla_model_path)
|
||||
use_base_model = _m.input[2] == 800 # input is (batch, channels, H, W)
|
||||
if not use_base_model:
|
||||
log.info(
|
||||
"Base model height is %dpx — skipping -i to avoid OOM; "
|
||||
"will train from scratch at 800px",
|
||||
_m.input[2],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if use_base_model:
|
||||
cmd += ["-i", blla_model_path, "--resize", "union", "-s", seg_spec]
|
||||
else:
|
||||
# No pretrained model — train from scratch with reduced height (800px)
|
||||
# to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+)
|
||||
cmd += [
|
||||
"-s",
|
||||
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
|
||||
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]",
|
||||
]
|
||||
cmd += ["-s", seg_spec]
|
||||
cmd += xml_files
|
||||
|
||||
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
|
||||
@@ -493,7 +512,8 @@ async def segtrain_model(
|
||||
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||
|
||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
|
||||
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||
log.info("Segmentation training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
|
||||
|
||||
best_model = _find_best_model(checkpoint_dir)
|
||||
if best_model is None:
|
||||
@@ -508,7 +528,7 @@ async def segtrain_model(
|
||||
shutil.copy2(best_model, blla_model_path)
|
||||
log.info("Replaced blla model at %s", blla_model_path)
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user