diff --git a/.env.example b/.env.example index 9011b4a5..5b928337 100644 --- a/.env.example +++ b/.env.example @@ -21,9 +21,10 @@ PORT_FRONTEND=5173 PORT_MAILPIT_UI=8100 PORT_MAILPIT_SMTP=1025 -# OCR Training — set a secret token to protect the /train and /segtrain endpoints on the -# Python OCR microservice. Leave empty to disable token authentication (development only). -# OCR_TRAINING_TOKEN=change-me-in-production +# OCR Training — secret token required to call /train and /segtrain on the OCR service. +# Also set in the backend so it can pass the token through. Must not be empty in production. +# Generate with: python3 -c "import secrets; print(secrets.token_hex(32))" +OCR_TRAINING_TOKEN=change-me-in-production # Production SMTP — uncomment and fill in to send real emails instead of catching them # APP_BASE_URL=https://your-domain.example.com diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java index 0bab0e99..fe7d61d1 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java @@ -12,5 +12,5 @@ public interface OcrTrainingRunRepository extends JpaRepository findFirstByStatus(TrainingStatus status); - List findTop5ByOrderByCreatedAtDesc(); + List findTop10ByOrderByCreatedAtDesc(); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index 9c8f4e5d..65828999 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -45,6 +45,13 @@ public class OcrTrainingService { List runs ) {} + private void assertNoRunningTraining() { + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, + "A training run is already in progress"); + } + } + // Not safe for horizontal scaling: training reloads the Kraken model in-process on the // Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique // index) prevents concurrent training API calls, but cannot prevent two OCR service replicas @@ -53,10 +60,7 @@ public class OcrTrainingService { // Short transaction: guard check + create RUNNING row, then commit immediately. // The DB connection is released before the OCR HTTP call, which can take several minutes. OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, - "A training run is already in progress"); - } + assertNoRunningTraining(); var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); if (eligibleBlocks.size() < 5) { @@ -120,10 +124,7 @@ public class OcrTrainingService { public OcrTrainingRun triggerSegTraining(UUID triggeredBy) { // Same pattern as triggerTraining: narrow transactions around DB writes only. OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, - "A training run is already in progress"); - } + assertNoRunningTraining(); var segBlocks = segmentationTrainingExportService.querySegmentationBlocks(); if (segBlocks.size() < 5) { @@ -162,11 +163,12 @@ public class OcrTrainingService { return Objects.requireNonNull(txTemplate.execute(status -> { run.setStatus(TrainingStatus.DONE); run.setCompletedAt(Instant.now()); + run.setCer(result.cer()); run.setLoss(result.loss()); run.setAccuracy(result.accuracy()); run.setEpochs(result.epochs()); OcrTrainingRun updated = trainingRunRepository.save(run); - log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs()); + log.info("[trainingRun={}] Segmentation training completed — cer={} epochs={}", runId, result.cer(), result.epochs()); return updated; })); } catch (Exception e) { @@ -193,7 +195,7 @@ public class OcrTrainingService { int totalOcrBlocks = (int) blockRepository.count(); int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size(); - List recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc(); + List recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc(); OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0); return new TrainingInfoResponse( diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index 7d83188f..62ef00f6 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -53,7 +53,7 @@ class OcrTrainingServiceTest { service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); when(blockRepository.count()).thenReturn(0L); - when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of()); + when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); } @@ -146,6 +146,90 @@ class OcrTrainingServiceTest { run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); } + // ─── triggerSegTraining ─────────────────────────────────────────────────── + + @Test + void triggerSegTraining_throws409_whenRunningRunExists() { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)) + .thenReturn(Optional.of(OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(2).modelName("blla").build())); + + assertThatThrownBy(() -> service.triggerSegTraining(null)) + .isInstanceOf(DomainException.class) + .extracting("status") + .satisfies(s -> assertThat(s.toString()).contains("409")); + } + + @Test + void triggerSegTraining_throws422_whenFewerThan5Segments() { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + when(segExportService.querySegmentationBlocks()).thenReturn(List.of( + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build() + )); + + assertThatThrownBy(() -> service.triggerSegTraining(null)) + .isInstanceOf(DomainException.class); + } + + @Test + void triggerSegTraining_createsRunWithBlla_andMarksDoneWithCer() throws Exception { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + UUID docA = UUID.randomUUID(); + UUID docB = UUID.randomUUID(); + List segs = List.of( + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build() + ); + when(segExportService.querySegmentationBlocks()).thenReturn(segs); + when(segExportService.exportToZip()).thenReturn(out -> {}); + when(ocrClient.segtrainModel(any())).thenReturn(new OcrClient.TrainingResult(null, 0.92, 0.08, 5)); + + OcrTrainingRun saved = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(2).modelName("blla").build(); + when(runRepository.save(any())).thenReturn(saved); + + service.triggerSegTraining(null); + + verify(runRepository, atLeastOnce()).save(argThat(run -> + run.getStatus() == TrainingStatus.DONE + && "blla".equals(run.getModelName()) + && run.getCer() != null)); + } + + @Test + void triggerSegTraining_marksRunFailed_whenOcrClientThrows() throws Exception { + when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + UUID docA = UUID.randomUUID(); + List segs = List.of( + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(), + TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build() + ); + when(segExportService.querySegmentationBlocks()).thenReturn(segs); + when(segExportService.exportToZip()).thenReturn(out -> {}); + when(ocrClient.segtrainModel(any())).thenThrow(new RuntimeException("seg timeout")); + + OcrTrainingRun saved = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(1).modelName("blla").build(); + when(runRepository.save(any())).thenReturn(saved); + + service.triggerSegTraining(null); + + verify(runRepository, atLeastOnce()).save(argThat(run -> + run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); + } + // ─── Orphan recovery ────────────────────────────────────────────────────── @Test diff --git a/docker-compose.yml b/docker-compose.yml index bf57501a..0637d417 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -83,11 +83,11 @@ services: restart: unless-stopped expose: - "8000" - mem_limit: 8g - memswap_limit: 8g + mem_limit: 12g + memswap_limit: 12g volumes: - ocr_models:/app/models - - ocr_cache:/root/.cache + - ocr_cache:/root/.cache # Hugging Face / ketos model download cache — prevents re-downloads on container recreate environment: KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}" @@ -102,7 +102,7 @@ services: interval: 10s timeout: 5s retries: 12 - start_period: 60s + start_period: 120s # --- Backend: Spring Boot --- backend: @@ -145,6 +145,7 @@ services: SPRING_MAIL_PROPERTIES_MAIL_SMTP_AUTH: ${MAIL_SMTP_AUTH:-false} SPRING_MAIL_PROPERTIES_MAIL_SMTP_STARTTLS_ENABLE: ${MAIL_STARTTLS_ENABLE:-false} APP_OCR_BASE_URL: http://ocr-service:8000 + APP_OCR_TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}" ports: - "${PORT_BACKEND}:8080" networks: diff --git a/frontend/messages/de.json b/frontend/messages/de.json index 48135368..cfb4578d 100644 --- a/frontend/messages/de.json +++ b/frontend/messages/de.json @@ -554,6 +554,7 @@ "training_history_col_cer": "Fehlerrate", "training_status_done": "Fertig", "training_status_failed": "Fehler", + "training_error_detail_label": "Fehlerdetails", "training_status_running": "Läuft…", "training_seg_heading": "Segmentierung trainieren", "training_seg_description": "Starte ein neues Training mit annotierten Segmentierungsbereichen, um die Texterkennung zu verbessern.", diff --git a/frontend/messages/en.json b/frontend/messages/en.json index 8f55b111..887b4091 100644 --- a/frontend/messages/en.json +++ b/frontend/messages/en.json @@ -554,6 +554,7 @@ "training_history_col_cer": "Error Rate", "training_status_done": "Done", "training_status_failed": "Failed", + "training_error_detail_label": "Error details", "training_status_running": "Running…", "training_seg_heading": "Train segmentation", "training_seg_description": "Start a new training run using annotated segmentation regions to improve text detection.", diff --git a/frontend/messages/es.json b/frontend/messages/es.json index b4b0ba65..ab357f76 100644 --- a/frontend/messages/es.json +++ b/frontend/messages/es.json @@ -554,6 +554,7 @@ "training_history_col_cer": "Tasa de error", "training_status_done": "Listo", "training_status_failed": "Error", + "training_error_detail_label": "Detalles del error", "training_status_running": "Ejecutando…", "training_seg_heading": "Entrenar segmentación", "training_seg_description": "Inicia un nuevo entrenamiento con regiones de segmentación anotadas para mejorar la detección de texto.", diff --git a/frontend/src/lib/components/TrainingHistory.svelte b/frontend/src/lib/components/TrainingHistory.svelte index ea194cd4..586fee32 100644 --- a/frontend/src/lib/components/TrainingHistory.svelte +++ b/frontend/src/lib/components/TrainingHistory.svelte @@ -20,6 +20,12 @@ interface Props { let { runs }: Props = $props(); +const COLLAPSED_COUNT = 3; +let expanded = $state(false); + +const visibleRuns = $derived(expanded ? runs : runs.slice(0, COLLAPSED_COUNT)); +const hasMore = $derived(runs.length > COLLAPSED_COUNT); + const dateFormatter = new Intl.DateTimeFormat('de-DE', { day: 'numeric', month: 'short', @@ -46,7 +52,7 @@ function formatCer(cer: number | undefined | null): string { {m.training_history_col_cer()} - + {#if runs.length === 0} @@ -54,7 +60,7 @@ function formatCer(cer: number | undefined | null): string { {:else} - {#each runs as run (run.id)} + {#each visibleRuns as run (run.id)} {formatDate(run.createdAt)} @@ -79,7 +85,6 @@ function formatCer(cer: number | undefined | null): string { {:else if run.status === 'FAILED'} + {#if run.errorMessage} +
+ + {m.training_error_detail_label()} + +

{run.errorMessage}

+
+ {/if} {:else} {m.training_status_running()} @@ -117,3 +130,17 @@ function formatCer(cer: number | undefined | null): string { {/if} + +{#if hasMore} +
+ +
+{/if} diff --git a/frontend/src/lib/components/TrainingHistory.svelte.spec.ts b/frontend/src/lib/components/TrainingHistory.svelte.spec.ts new file mode 100644 index 00000000..728b167a --- /dev/null +++ b/frontend/src/lib/components/TrainingHistory.svelte.spec.ts @@ -0,0 +1,52 @@ +import { afterEach, describe, expect, it } from 'vitest'; +import { cleanup, render } from 'vitest-browser-svelte'; +import { page } from 'vitest/browser'; +import TrainingHistory from './TrainingHistory.svelte'; + +afterEach(cleanup); + +function makeRun(i: number) { + return { + id: `run-${i}`, + status: 'DONE' as const, + blockCount: 10, + documentCount: 2, + modelName: 'german_kurrent', + createdAt: `2026-01-0${i + 1}T12:00:00Z`, + completedAt: `2026-01-0${i + 1}T12:05:00Z` + }; +} + +const fiveRuns = Array.from({ length: 5 }, (_, i) => makeRun(i)); +const twoRuns = Array.from({ length: 2 }, (_, i) => makeRun(i)); + +describe('TrainingHistory — expand/collapse', () => { + it('shows only 3 runs initially when more than 3 exist', async () => { + render(TrainingHistory, { runs: fiveRuns }); + + const rows = page.getByRole('row'); + // 1 header row + 3 data rows = 4 total + await expect.element(rows.nth(3)).toBeInTheDocument(); + await expect.element(rows.nth(4)).not.toBeInTheDocument(); + + await expect.element(page.getByRole('button', { name: /Mehr anzeigen/i })).toBeInTheDocument(); + }); + + it('shows all runs after clicking the expand button', async () => { + render(TrainingHistory, { runs: fiveRuns }); + + await page.getByRole('button', { name: /Mehr anzeigen/i }).click(); + + const rows = page.getByRole('row'); + // 1 header row + 5 data rows = 6 total + await expect.element(rows.nth(5)).toBeInTheDocument(); + }); + + it('hides the toggle button when 3 or fewer runs exist', async () => { + render(TrainingHistory, { runs: twoRuns }); + + await expect + .element(page.getByRole('button', { name: /Mehr anzeigen/i })) + .not.toBeInTheDocument(); + }); +}); diff --git a/ocr-service/Dockerfile b/ocr-service/Dockerfile index 01b8ebfa..25d383a4 100644 --- a/ocr-service/Dockerfile +++ b/ocr-service/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-slim +FROM python:3.11.9-slim WORKDIR /app @@ -21,6 +21,8 @@ RUN pip install --no-cache-dir -r requirements.txt COPY . . +RUN chmod +x /app/entrypoint.sh + EXPOSE 8000 -CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"] +CMD ["/app/entrypoint.sh"] diff --git a/ocr-service/ensure_blla_model.py b/ocr-service/ensure_blla_model.py new file mode 100644 index 00000000..d0fa0305 --- /dev/null +++ b/ocr-service/ensure_blla_model.py @@ -0,0 +1,80 @@ +"""Validates the blla segmentation base model and downloads it if needed. + +Run at container startup before uvicorn. ketos 7 requires the model in +CoreML protobuf or safetensors format — legacy PyTorch ZIP archives +(torch.save output from kraken <4) are not loadable and will be replaced. + +Exits non-zero on failure so Docker marks the container unhealthy rather +than silently starting with a broken model. +""" + +import glob +import logging +import os +import shutil +import subprocess +import sys + +logging.basicConfig( + level=logging.INFO, + format="%(levelname)s:ensure_blla_model:%(message)s", +) +log = logging.getLogger(__name__) + +BLLA_MODEL_PATH = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel") +# DOI for "General segmentation model for print and handwriting" — ketos 7 compatible. +BLLA_MODEL_DOI = "10.5281/zenodo.14602569" +HTRMOPO_DIR = os.path.expanduser("~/.local/share/htrmopo") + + +def _model_is_loadable(path: str) -> bool: + try: + from kraken.lib import vgsl + + vgsl.TorchVGSLModel.load_model(path) + return True + except (RuntimeError, OSError, ValueError) as e: + log.warning("Model at %s failed to load: %s", path, e) + return False + except Exception: + log.debug("Unexpected error loading model at %s", path, exc_info=True) + return False + + +def _download_blla() -> str: + log.info("Downloading blla model (DOI %s) ...", BLLA_MODEL_DOI) + result = subprocess.run( + ["kraken", "get", BLLA_MODEL_DOI], + capture_output=True, + text=True, + ) + if result.returncode != 0: + log.error("kraken get failed: %s", result.stderr) + sys.exit(1) + + candidates = sorted(glob.glob(os.path.join(HTRMOPO_DIR, "*/blla.mlmodel"))) + if not candidates: + log.error("Downloaded blla.mlmodel not found under %s", HTRMOPO_DIR) + sys.exit(1) + + return candidates[-1] + + +def main() -> None: + if os.path.exists(BLLA_MODEL_PATH): + if _model_is_loadable(BLLA_MODEL_PATH): + log.info("blla model OK: %s", BLLA_MODEL_PATH) + return + log.warning( + "blla model at %s is in an incompatible format — replacing", BLLA_MODEL_PATH + ) + os.rename(BLLA_MODEL_PATH, BLLA_MODEL_PATH + ".incompatible") + + os.makedirs(os.path.dirname(BLLA_MODEL_PATH), exist_ok=True) + downloaded = _download_blla() + shutil.copy2(downloaded, BLLA_MODEL_PATH) + log.info("Installed blla model at %s", BLLA_MODEL_PATH) + + +if __name__ == "__main__": + main() diff --git a/ocr-service/entrypoint.sh b/ocr-service/entrypoint.sh new file mode 100644 index 00000000..ec6892a8 --- /dev/null +++ b/ocr-service/entrypoint.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -euo pipefail + +# Validate the blla segmentation base model and download it if missing or +# incompatible. ketos 7 dropped support for legacy PyTorch ZIP archives — +# this ensures the volume always holds a loadable CoreML protobuf model. +python3 /app/ensure_blla_model.py + +exec uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1 diff --git a/ocr-service/main.py b/ocr-service/main.py index 0f656aaa..3545f6a4 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -472,16 +472,35 @@ async def segtrain_model( "-q", "fixed", "-N", "10", ] + # Train at 800px height. The default blla model uses 1800px, which peaks at + # ~7+ GB on CPU and kills the host (ketos ignores -s when -i is present, so + # we cannot override the height of an existing model). + # Strategy: only use the base model if it is already at 800px (i.e. was + # produced by a previous fine-tuning run here). Otherwise train from scratch — + # the first run bootstraps a 800px model; all subsequent runs fine-tune it. + seg_spec = ( + "[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 " + "Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]" + ) + use_base_model = False if os.path.exists(blla_model_path): - cmd += ["-i", blla_model_path, "--resize", "both"] + try: + from kraken.lib import vgsl as _vgsl + _m = _vgsl.TorchVGSLModel.load_model(blla_model_path) + use_base_model = _m.input[2] == 800 # input is (batch, channels, H, W) + if not use_base_model: + log.info( + "Base model height is %dpx — skipping -i to avoid OOM; " + "will train from scratch at 800px", + _m.input[2], + ) + except Exception as exc: + log.warning("Could not inspect base model height, training from scratch: %s", exc) + + if use_base_model: + cmd += ["-i", blla_model_path, "--resize", "union", "-s", seg_spec] else: - # No pretrained model — train from scratch with reduced height (800px) - # to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+) - cmd += [ - "-s", - "[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 " - "Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]", - ] + cmd += ["-s", seg_spec] cmd += xml_files log.info("Running: %s", " ".join(cmd[:5]) + " ...") @@ -493,7 +512,8 @@ async def segtrain_model( raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}") accuracy, epochs = _parse_best_checkpoint(checkpoint_dir) - log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy) + cer = round(1.0 - accuracy, 4) if accuracy is not None else None + log.info("Segmentation training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer) best_model = _find_best_model(checkpoint_dir) if best_model is None: @@ -508,7 +528,7 @@ async def segtrain_model( shutil.copy2(best_model, blla_model_path) log.info("Replaced blla model at %s", blla_model_path) - return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs} + return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs} result = await asyncio.to_thread(_run_segtrain) return result diff --git a/ocr-service/test_ensure_blla_model.py b/ocr-service/test_ensure_blla_model.py new file mode 100644 index 00000000..cc171e92 --- /dev/null +++ b/ocr-service/test_ensure_blla_model.py @@ -0,0 +1,69 @@ +"""Unit tests for ensure_blla_model.main().""" + +from unittest.mock import MagicMock, call, patch + +import ensure_blla_model + + +# ─── Model already loadable ─────────────────────────────────────────────────── + + +def test_main_returns_early_when_model_is_loadable(): + """When the model exists and loads cleanly, no download or rename occurs.""" + with ( + patch("os.path.exists", return_value=True), + patch.object(ensure_blla_model, "_model_is_loadable", return_value=True), + patch.object(ensure_blla_model, "_download_blla") as mock_download, + patch("os.rename") as mock_rename, + ): + ensure_blla_model.main() + + mock_download.assert_not_called() + mock_rename.assert_not_called() + + +# ─── Model exists but is incompatible ───────────────────────────────────────── + + +def test_main_replaces_incompatible_model(): + """An incompatible model is renamed and replaced with a fresh download.""" + fake_path = "/app/models/blla.mlmodel" + downloaded_path = "/tmp/downloaded.mlmodel" + + with ( + patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path), + patch("os.path.exists", return_value=True), + patch.object(ensure_blla_model, "_model_is_loadable", return_value=False), + patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path), + patch("os.rename") as mock_rename, + patch("shutil.copy2") as mock_copy, + patch("os.makedirs"), + ): + ensure_blla_model.main() + + mock_rename.assert_called_once_with(fake_path, fake_path + ".incompatible") + mock_copy.assert_called_once_with(downloaded_path, fake_path) + + +# ─── Model missing ──────────────────────────────────────────────────────────── + + +def test_main_downloads_when_model_missing(): + """When the model file doesn't exist at all, it is downloaded without rename.""" + fake_path = "/app/models/blla.mlmodel" + downloaded_path = "/tmp/downloaded.mlmodel" + + with ( + patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path), + patch("os.path.exists", return_value=False), + patch.object(ensure_blla_model, "_model_is_loadable") as mock_loadable, + patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path), + patch("os.rename") as mock_rename, + patch("shutil.copy2") as mock_copy, + patch("os.makedirs"), + ): + ensure_blla_model.main() + + mock_loadable.assert_not_called() + mock_rename.assert_not_called() + mock_copy.assert_called_once_with(downloaded_path, fake_path)