fix(ocr): fix segmentation training for ketos 7 and low-memory hosts #234

Merged
marcel merged 13 commits from fix/ocr-segtrain-training into main 2026-04-14 21:17:54 +02:00
15 changed files with 385 additions and 35 deletions

View File

@@ -21,9 +21,10 @@ PORT_FRONTEND=5173
PORT_MAILPIT_UI=8100
PORT_MAILPIT_SMTP=1025
# OCR Training — set a secret token to protect the /train and /segtrain endpoints on the
# Python OCR microservice. Leave empty to disable token authentication (development only).
# OCR_TRAINING_TOKEN=change-me-in-production
# OCR Training — secret token required to call /train and /segtrain on the OCR service.
# Also set in the backend so it can pass the token through. Must not be empty in production.
# Generate with: python3 -c "import secrets; print(secrets.token_hex(32))"
OCR_TRAINING_TOKEN=change-me-in-production
# Production SMTP — uncomment and fill in to send real emails instead of catching them
# APP_BASE_URL=https://your-domain.example.com

View File

@@ -12,5 +12,5 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
List<OcrTrainingRun> findTop10ByOrderByCreatedAtDesc();
}

View File

@@ -45,6 +45,13 @@ public class OcrTrainingService {
List<OcrTrainingRun> runs
) {}
/**
 * Guards against starting a second training run while one is in progress.
 *
 * <p>Checks for any persisted run in {@code RUNNING} status and, if one exists,
 * rejects the request with a conflict-level {@code DomainException}
 * ({@code ErrorCode.TRAINING_ALREADY_RUNNING}). Callers invoke this inside a
 * short transaction so the check and the subsequent RUNNING-row insert commit
 * together.
 *
 * @throws DomainException (conflict) when a RUNNING training run already exists
 */
private void assertNoRunningTraining() {
    if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
        throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
            "A training run is already in progress");
    }
}
// Not safe for horizontal scaling: training reloads the Kraken model in-process on the
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
@@ -53,10 +60,7 @@ public class OcrTrainingService {
// Short transaction: guard check + create RUNNING row, then commit immediately.
// The DB connection is released before the OCR HTTP call, which can take several minutes.
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
assertNoRunningTraining();
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
if (eligibleBlocks.size() < 5) {
@@ -120,10 +124,7 @@ public class OcrTrainingService {
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
// Same pattern as triggerTraining: narrow transactions around DB writes only.
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
assertNoRunningTraining();
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
if (segBlocks.size() < 5) {
@@ -162,11 +163,12 @@ public class OcrTrainingService {
return Objects.requireNonNull(txTemplate.execute(status -> {
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
OcrTrainingRun updated = trainingRunRepository.save(run);
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
log.info("[trainingRun={}] Segmentation training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
return updated;
}));
} catch (Exception e) {
@@ -193,7 +195,7 @@ public class OcrTrainingService {
int totalOcrBlocks = (int) blockRepository.count();
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc();
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
return new TrainingInfoResponse(

View File

@@ -53,7 +53,7 @@ class OcrTrainingServiceTest {
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
}
@@ -146,6 +146,90 @@ class OcrTrainingServiceTest {
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── triggerSegTraining ───────────────────────────────────────────────────
@Test
void triggerSegTraining_throws409_whenRunningRunExists() {
    // Simulate an in-flight run so the concurrent-training guard trips.
    when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
        .thenReturn(Optional.of(OcrTrainingRun.builder()
            .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
            .blockCount(5).documentCount(2).modelName("blla").build()));
    // Expect a DomainException whose status renders as a 409 conflict.
    assertThatThrownBy(() -> service.triggerSegTraining(null))
        .isInstanceOf(DomainException.class)
        .extracting("status")
        .satisfies(s -> assertThat(s.toString()).contains("409"));
}
@Test
void triggerSegTraining_throws422_whenFewerThan5Segments() {
    // No run in progress, but only 2 of the 5 required segmentation blocks exist.
    when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
    when(segExportService.querySegmentationBlocks()).thenReturn(List.of(
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
    ));
    // Pin the status like the 409 test does: the method name promises a 422,
    // so a regression mapping this case to another status must fail the test.
    assertThatThrownBy(() -> service.triggerSegTraining(null))
        .isInstanceOf(DomainException.class)
        .extracting("status")
        .satisfies(s -> assertThat(s.toString()).contains("422"));
}
@Test
void triggerSegTraining_createsRunWithBlla_andMarksDoneWithCer() throws Exception {
    // Preconditions pass: no run in progress, five annotated segments spread
    // across two documents.
    when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
    UUID docA = UUID.randomUUID();
    UUID docB = UUID.randomUUID();
    List<TranscriptionBlock> segs = List.of(
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
    );
    when(segExportService.querySegmentationBlocks()).thenReturn(segs);
    when(segExportService.exportToZip()).thenReturn(out -> {});
    // OCR client stub reports success: null, 0.92, 0.08, 5 — presumably
    // loss/accuracy/cer/epochs (0.08 = 1 - 0.92); confirm against the
    // OcrClient.TrainingResult component order.
    when(ocrClient.segtrainModel(any())).thenReturn(new OcrClient.TrainingResult(null, 0.92, 0.08, 5));
    OcrTrainingRun saved = OcrTrainingRun.builder()
        .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
        .blockCount(5).documentCount(2).modelName("blla").build();
    when(runRepository.save(any())).thenReturn(saved);
    service.triggerSegTraining(null);
    // The final save must mark the run DONE, keep the "blla" model name,
    // and persist a non-null CER from the training result.
    verify(runRepository, atLeastOnce()).save(argThat(run ->
        run.getStatus() == TrainingStatus.DONE
        && "blla".equals(run.getModelName())
        && run.getCer() != null));
}
@Test
void triggerSegTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
    // Preconditions pass: no running run, five eligible segments (one document).
    when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
    UUID docA = UUID.randomUUID();
    List<TranscriptionBlock> segs = List.of(
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
        TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
    );
    when(segExportService.querySegmentationBlocks()).thenReturn(segs);
    when(segExportService.exportToZip()).thenReturn(out -> {});
    // ... but the OCR HTTP call itself blows up mid-training.
    when(ocrClient.segtrainModel(any())).thenThrow(new RuntimeException("seg timeout"));
    OcrTrainingRun saved = OcrTrainingRun.builder()
        .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
        .blockCount(5).documentCount(1).modelName("blla").build();
    when(runRepository.save(any())).thenReturn(saved);
    service.triggerSegTraining(null);
    // The failure must be persisted: status FAILED with a non-null error message
    // (triggerSegTraining itself does not rethrow here).
    verify(runRepository, atLeastOnce()).save(argThat(run ->
        run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test

View File

@@ -83,11 +83,11 @@ services:
restart: unless-stopped
expose:
- "8000"
mem_limit: 8g
memswap_limit: 8g
mem_limit: 12g
memswap_limit: 12g
volumes:
- ocr_models:/app/models
- ocr_cache:/root/.cache
- ocr_cache:/root/.cache # Hugging Face / ketos model download cache — prevents re-downloads on container recreate
environment:
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
@@ -102,7 +102,7 @@ services:
interval: 10s
timeout: 5s
retries: 12
start_period: 60s
start_period: 120s
# --- Backend: Spring Boot ---
backend:
@@ -145,6 +145,7 @@ services:
SPRING_MAIL_PROPERTIES_MAIL_SMTP_AUTH: ${MAIL_SMTP_AUTH:-false}
SPRING_MAIL_PROPERTIES_MAIL_SMTP_STARTTLS_ENABLE: ${MAIL_STARTTLS_ENABLE:-false}
APP_OCR_BASE_URL: http://ocr-service:8000
APP_OCR_TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
ports:
- "${PORT_BACKEND}:8080"
networks:

View File

@@ -554,6 +554,7 @@
"training_history_col_cer": "Fehlerrate",
"training_status_done": "Fertig",
"training_status_failed": "Fehler",
"training_error_detail_label": "Fehlerdetails",
"training_status_running": "Läuft…",
"training_seg_heading": "Segmentierung trainieren",
"training_seg_description": "Starte ein neues Training mit annotierten Segmentierungsbereichen, um die Texterkennung zu verbessern.",

View File

@@ -554,6 +554,7 @@
"training_history_col_cer": "Error Rate",
"training_status_done": "Done",
"training_status_failed": "Failed",
"training_error_detail_label": "Error details",
"training_status_running": "Running…",
"training_seg_heading": "Train segmentation",
"training_seg_description": "Start a new training run using annotated segmentation regions to improve text detection.",

View File

@@ -554,6 +554,7 @@
"training_history_col_cer": "Tasa de error",
"training_status_done": "Listo",
"training_status_failed": "Error",
"training_error_detail_label": "Detalles del error",
"training_status_running": "Ejecutando…",
"training_seg_heading": "Entrenar segmentación",
"training_seg_description": "Inicia un nuevo entrenamiento con regiones de segmentación anotadas para mejorar la detección de texto.",

View File

@@ -20,6 +20,12 @@ interface Props {
let { runs }: Props = $props();
const COLLAPSED_COUNT = 3;
let expanded = $state(false);
const visibleRuns = $derived(expanded ? runs : runs.slice(0, COLLAPSED_COUNT));
const hasMore = $derived(runs.length > COLLAPSED_COUNT);
const dateFormatter = new Intl.DateTimeFormat('de-DE', {
day: 'numeric',
month: 'short',
@@ -46,7 +52,7 @@ function formatCer(cer: number | undefined | null): string {
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
</tr>
</thead>
<tbody>
<tbody id="training-history-rows">
{#if runs.length === 0}
<tr>
<td colspan="5" class="py-4 text-center text-sm text-ink-2">
@@ -54,7 +60,7 @@ function formatCer(cer: number | undefined | null): string {
</td>
</tr>
{:else}
{#each runs as run (run.id)}
{#each visibleRuns as run (run.id)}
<tr class="border-b border-line/50 last:border-0">
<td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td>
<td class="py-2">
@@ -79,7 +85,6 @@ function formatCer(cer: number | undefined | null): string {
{:else if run.status === 'FAILED'}
<span
class="inline-flex items-center gap-1 rounded-sm bg-red-100 px-1.5 py-0.5 text-xs font-medium text-red-700"
title={run.errorMessage}
>
<svg
aria-hidden="true"
@@ -95,13 +100,21 @@ function formatCer(cer: number | undefined | null): string {
</svg>
{m.training_status_failed()}
</span>
{#if run.errorMessage}
<details class="mt-0.5">
<summary class="cursor-pointer text-xs text-red-700 underline">
{m.training_error_detail_label()}
</summary>
<p class="mt-1 text-xs text-red-600">{run.errorMessage}</p>
</details>
{/if}
{:else}
<span
class="inline-flex items-center gap-1 rounded-sm bg-yellow-100 px-1.5 py-0.5 text-xs font-medium text-yellow-700"
>
<span
aria-hidden="true"
class="h-1.5 w-1.5 animate-pulse rounded-full bg-yellow-500"
class="h-1.5 w-1.5 rounded-full bg-yellow-500 motion-safe:animate-pulse"
></span>
{m.training_status_running()}
</span>
@@ -117,3 +130,17 @@ function formatCer(cer: number | undefined | null): string {
{/if}
</tbody>
</table>
{#if hasMore}
<div class="mt-2 text-center">
<button
type="button"
aria-expanded={expanded}
aria-controls="training-history-rows"
class="text-xs font-medium text-ink-3 transition-colors hover:text-ink"
onclick={() => (expanded = !expanded)}
>
{expanded ? m.comp_expandable_show_less() : m.comp_expandable_show_more()}
</button>
</div>
{/if}

View File

@@ -0,0 +1,52 @@
import { afterEach, describe, expect, it } from 'vitest';
import { cleanup, render } from 'vitest-browser-svelte';
import { page } from 'vitest/browser';
import TrainingHistory from './TrainingHistory.svelte';
afterEach(cleanup);
/**
 * Builds a completed (DONE) training-run fixture.
 * `i` varies the id and shifts the created/completed dates by one day each,
 * keeping rows distinguishable in keyed `{#each}` rendering.
 */
function makeRun(i: number) {
	const day = i + 1;
	return {
		id: `run-${i}`,
		status: 'DONE' as const,
		blockCount: 10,
		documentCount: 2,
		modelName: 'german_kurrent',
		createdAt: `2026-01-0${day}T12:00:00Z`,
		completedAt: `2026-01-0${day}T12:05:00Z`
	};
}
// Fixtures: five runs exceeds the collapsed limit of 3; two runs stays under it.
const fiveRuns = Array.from({ length: 5 }, (_, i) => makeRun(i));
const twoRuns = Array.from({ length: 2 }, (_, i) => makeRun(i));

describe('TrainingHistory — expand/collapse', () => {
	it('shows only 3 runs initially when more than 3 exist', async () => {
		render(TrainingHistory, { runs: fiveRuns });
		const rows = page.getByRole('row');
		// 1 header row + 3 data rows = 4 total
		await expect.element(rows.nth(3)).toBeInTheDocument();
		await expect.element(rows.nth(4)).not.toBeInTheDocument();
		// The toggle is rendered; its label comes from the German message bundle.
		await expect.element(page.getByRole('button', { name: /Mehr anzeigen/i })).toBeInTheDocument();
	});

	it('shows all runs after clicking the expand button', async () => {
		render(TrainingHistory, { runs: fiveRuns });
		await page.getByRole('button', { name: /Mehr anzeigen/i }).click();
		const rows = page.getByRole('row');
		// 1 header row + 5 data rows = 6 total
		await expect.element(rows.nth(5)).toBeInTheDocument();
	});

	it('hides the toggle button when 3 or fewer runs exist', async () => {
		render(TrainingHistory, { runs: twoRuns });
		// Below the collapse threshold the component renders no toggle at all.
		await expect
			.element(page.getByRole('button', { name: /Mehr anzeigen/i }))
			.not.toBeInTheDocument();
	});
});

View File

@@ -1,4 +1,4 @@
FROM python:3.11-slim
FROM python:3.11.9-slim
WORKDIR /app
@@ -21,6 +21,8 @@ RUN pip install --no-cache-dir -r requirements.txt
COPY . .
RUN chmod +x /app/entrypoint.sh
EXPOSE 8000
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
CMD ["/app/entrypoint.sh"]

View File

@@ -0,0 +1,80 @@
"""Validates the blla segmentation base model and downloads it if needed.
Run at container startup before uvicorn. ketos 7 requires the model in
CoreML protobuf or safetensors format — legacy PyTorch ZIP archives
(torch.save output from kraken <4) are not loadable and will be replaced.
Exits non-zero on failure so Docker marks the container unhealthy rather
than silently starting with a broken model.
"""
import glob
import logging
import os
import shutil
import subprocess
import sys
# Startup-script logging: a stable "ensure_blla_model" prefix so container logs
# attribute these messages to the model check rather than to uvicorn.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s:ensure_blla_model:%(message)s",
)
log = logging.getLogger(__name__)

# Target path of the segmentation model; overridable via env (used by tests/deployments).
BLLA_MODEL_PATH = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")
# DOI for "General segmentation model for print and handwriting" — ketos 7 compatible.
BLLA_MODEL_DOI = "10.5281/zenodo.14602569"
# Directory where `kraken get` places downloaded models (searched in _download_blla).
HTRMOPO_DIR = os.path.expanduser("~/.local/share/htrmopo")
def _model_is_loadable(path: str) -> bool:
    """Return True when kraken can load the model file at ``path``.

    Any failure reports False: known failure modes (incompatible or corrupt
    model files) are logged as warnings, while unexpected errors — including
    a failing kraken import, which the broad handler also absorbs — are
    logged at debug level only.
    """
    try:
        # Import inside the try so an unavailable kraken degrades to "not
        # loadable" instead of crashing the startup check.
        from kraken.lib import vgsl

        vgsl.TorchVGSLModel.load_model(path)
    except (RuntimeError, OSError, ValueError) as exc:
        log.warning("Model at %s failed to load: %s", path, exc)
        return False
    except Exception:
        log.debug("Unexpected error loading model at %s", path, exc_info=True)
        return False
    return True
def _download_blla() -> str:
    """Fetch the blla base model via ``kraken get`` and return its local path.

    Exits the process with status 1 when the download fails or no
    ``blla.mlmodel`` can be found under HTRMOPO_DIR afterwards, so the
    container start aborts instead of running with a broken model.
    """
    log.info("Downloading blla model (DOI %s) ...", BLLA_MODEL_DOI)
    proc = subprocess.run(
        ["kraken", "get", BLLA_MODEL_DOI],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        log.error("kraken get failed: %s", proc.stderr)
        sys.exit(1)
    matches = sorted(glob.glob(os.path.join(HTRMOPO_DIR, "*/blla.mlmodel")))
    if matches:
        # Lexically last match among the per-directory downloads.
        return matches[-1]
    log.error("Downloaded blla.mlmodel not found under %s", HTRMOPO_DIR)
    sys.exit(1)
def main() -> None:
    """Ensure a loadable blla model exists at BLLA_MODEL_PATH.

    A valid existing model is left untouched. An existing but unloadable
    file is set aside with an ``.incompatible`` suffix (kept for debugging,
    not deleted) and a freshly downloaded copy is installed in its place.
    A missing file is simply downloaded and installed.
    """
    existing = os.path.exists(BLLA_MODEL_PATH)
    if existing and _model_is_loadable(BLLA_MODEL_PATH):
        log.info("blla model OK: %s", BLLA_MODEL_PATH)
        return
    if existing:
        log.warning(
            "blla model at %s is in an incompatible format — replacing", BLLA_MODEL_PATH
        )
        os.rename(BLLA_MODEL_PATH, BLLA_MODEL_PATH + ".incompatible")
    os.makedirs(os.path.dirname(BLLA_MODEL_PATH), exist_ok=True)
    shutil.copy2(_download_blla(), BLLA_MODEL_PATH)
    log.info("Installed blla model at %s", BLLA_MODEL_PATH)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -euo pipefail
# Validate the blla segmentation base model and download it if missing or
# incompatible. ketos 7 dropped support for legacy PyTorch ZIP archives —
# this ensures the volume always holds a loadable CoreML protobuf model.
python3 /app/ensure_blla_model.py
exec uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1

View File

@@ -472,16 +472,35 @@ async def segtrain_model(
"-q", "fixed",
"-N", "10",
]
# Train at 800px height. The default blla model uses 1800px, which peaks at
# ~7+ GB on CPU and kills the host (ketos ignores -s when -i is present, so
# we cannot override the height of an existing model).
# Strategy: only use the base model if it is already at 800px (i.e. was
# produced by a previous fine-tuning run here). Otherwise train from scratch —
# the first run bootstraps a 800px model; all subsequent runs fine-tune it.
seg_spec = (
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]"
)
use_base_model = False
if os.path.exists(blla_model_path):
cmd += ["-i", blla_model_path, "--resize", "both"]
try:
from kraken.lib import vgsl as _vgsl
_m = _vgsl.TorchVGSLModel.load_model(blla_model_path)
use_base_model = _m.input[2] == 800 # input is (batch, channels, H, W)
if not use_base_model:
log.info(
"Base model height is %dpx — skipping -i to avoid OOM; "
"will train from scratch at 800px",
_m.input[2],
)
except Exception as exc:
log.warning("Could not inspect base model height, training from scratch: %s", exc)
if use_base_model:
cmd += ["-i", blla_model_path, "--resize", "union", "-s", seg_spec]
else:
# No pretrained model — train from scratch with reduced height (800px)
# to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+)
cmd += [
"-s",
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]",
]
cmd += ["-s", seg_spec]
cmd += xml_files
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
@@ -493,7 +512,8 @@ async def segtrain_model(
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
log.info("Segmentation training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
@@ -508,7 +528,7 @@ async def segtrain_model(
shutil.copy2(best_model, blla_model_path)
log.info("Replaced blla model at %s", blla_model_path)
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_segtrain)
return result

View File

@@ -0,0 +1,69 @@
"""Unit tests for ensure_blla_model.main()."""
from unittest.mock import MagicMock, call, patch
import ensure_blla_model
# ─── Model already loadable ───────────────────────────────────────────────────
def test_main_returns_early_when_model_is_loadable():
    """When the model exists and loads cleanly, no download or rename occurs."""
    with (
        patch("os.path.exists", return_value=True),
        # Happy path: the existing model validates successfully.
        patch.object(ensure_blla_model, "_model_is_loadable", return_value=True),
        patch.object(ensure_blla_model, "_download_blla") as mock_download,
        patch("os.rename") as mock_rename,
    ):
        ensure_blla_model.main()
        # main() must return early: no fresh download, nothing set aside.
        mock_download.assert_not_called()
        mock_rename.assert_not_called()
# ─── Model exists but is incompatible ─────────────────────────────────────────
def test_main_replaces_incompatible_model():
    """An incompatible model is renamed and replaced with a fresh download."""
    fake_path = "/app/models/blla.mlmodel"
    downloaded_path = "/tmp/downloaded.mlmodel"
    with (
        patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path),
        patch("os.path.exists", return_value=True),
        # The file is present but fails validation ...
        patch.object(ensure_blla_model, "_model_is_loadable", return_value=False),
        # ... so a download is triggered; stub where it lands.
        patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path),
        patch("os.rename") as mock_rename,
        patch("shutil.copy2") as mock_copy,
        patch("os.makedirs"),
    ):
        ensure_blla_model.main()
        # The broken model is set aside (not deleted), then the download installed.
        mock_rename.assert_called_once_with(fake_path, fake_path + ".incompatible")
        mock_copy.assert_called_once_with(downloaded_path, fake_path)
# ─── Model missing ────────────────────────────────────────────────────────────
def test_main_downloads_when_model_missing():
    """When the model file doesn't exist at all, it is downloaded without rename."""
    fake_path = "/app/models/blla.mlmodel"
    downloaded_path = "/tmp/downloaded.mlmodel"
    with (
        patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path),
        patch("os.path.exists", return_value=False),
        patch.object(ensure_blla_model, "_model_is_loadable") as mock_loadable,
        patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path),
        patch("os.rename") as mock_rename,
        patch("shutil.copy2") as mock_copy,
        patch("os.makedirs"),
    ):
        ensure_blla_model.main()
        # No file on disk: validation is skipped entirely (exists() short-circuits),
        # nothing is renamed, and the download is installed at the target path.
        mock_loadable.assert_not_called()
        mock_rename.assert_not_called()
        mock_copy.assert_called_once_with(downloaded_path, fake_path)