fix(ocr): fix segmentation training for ketos 7 and low-memory hosts #234
@@ -21,9 +21,10 @@ PORT_FRONTEND=5173
|
||||
PORT_MAILPIT_UI=8100
|
||||
PORT_MAILPIT_SMTP=1025
|
||||
|
||||
# OCR Training — set a secret token to protect the /train and /segtrain endpoints on the
|
||||
# Python OCR microservice. Leave empty to disable token authentication (development only).
|
||||
# OCR_TRAINING_TOKEN=change-me-in-production
|
||||
# OCR Training — secret token required to call /train and /segtrain on the OCR service.
|
||||
# Also set in the backend so it can pass the token through. Must not be empty in production.
|
||||
# Generate with: python3 -c "import secrets; print(secrets.token_hex(32))"
|
||||
OCR_TRAINING_TOKEN=change-me-in-production
|
||||
|
||||
# Production SMTP — uncomment and fill in to send real emails instead of catching them
|
||||
# APP_BASE_URL=https://your-domain.example.com
|
||||
|
||||
@@ -12,5 +12,5 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
|
||||
|
||||
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
|
||||
|
||||
List<OcrTrainingRun> findTop5ByOrderByCreatedAtDesc();
|
||||
List<OcrTrainingRun> findTop10ByOrderByCreatedAtDesc();
|
||||
}
|
||||
|
||||
@@ -45,6 +45,13 @@ public class OcrTrainingService {
|
||||
List<OcrTrainingRun> runs
|
||||
) {}
|
||||
|
||||
private void assertNoRunningTraining() {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
}
|
||||
|
||||
// Not safe for horizontal scaling: training reloads the Kraken model in-process on the
|
||||
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
|
||||
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
|
||||
@@ -53,10 +60,7 @@ public class OcrTrainingService {
|
||||
// Short transaction: guard check + create RUNNING row, then commit immediately.
|
||||
// The DB connection is released before the OCR HTTP call, which can take several minutes.
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
assertNoRunningTraining();
|
||||
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
if (eligibleBlocks.size() < 5) {
|
||||
@@ -120,10 +124,7 @@ public class OcrTrainingService {
|
||||
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
||||
// Same pattern as triggerTraining: narrow transactions around DB writes only.
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
assertNoRunningTraining();
|
||||
|
||||
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||
if (segBlocks.size() < 5) {
|
||||
@@ -162,11 +163,12 @@ public class OcrTrainingService {
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
OcrTrainingRun updated = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||
log.info("[trainingRun={}] Segmentation training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||
return updated;
|
||||
}));
|
||||
} catch (Exception e) {
|
||||
@@ -193,7 +195,7 @@ public class OcrTrainingService {
|
||||
int totalOcrBlocks = (int) blockRepository.count();
|
||||
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
|
||||
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc();
|
||||
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
|
||||
|
||||
return new TrainingInfoResponse(
|
||||
|
||||
@@ -53,7 +53,7 @@ class OcrTrainingServiceTest {
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
|
||||
when(blockRepository.count()).thenReturn(0L);
|
||||
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
|
||||
}
|
||||
|
||||
@@ -146,6 +146,90 @@ class OcrTrainingServiceTest {
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── triggerSegTraining ───────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerSegTraining_throws409_whenRunningRunExists() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING))
|
||||
.thenReturn(Optional.of(OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("blla").build()));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerSegTraining(null))
|
||||
.isInstanceOf(DomainException.class)
|
||||
.extracting("status")
|
||||
.satisfies(s -> assertThat(s.toString()).contains("409"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void triggerSegTraining_throws422_whenFewerThan5Segments() {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(UUID.randomUUID()).build()
|
||||
));
|
||||
|
||||
assertThatThrownBy(() -> service.triggerSegTraining(null))
|
||||
.isInstanceOf(DomainException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void triggerSegTraining_createsRunWithBlla_andMarksDoneWithCer() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
UUID docB = UUID.randomUUID();
|
||||
List<TranscriptionBlock> segs = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docB).build()
|
||||
);
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(segs);
|
||||
when(segExportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.segtrainModel(any())).thenReturn(new OcrClient.TrainingResult(null, 0.92, 0.08, 5));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(2).modelName("blla").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerSegTraining(null);
|
||||
|
||||
verify(runRepository, atLeastOnce()).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.DONE
|
||||
&& "blla".equals(run.getModelName())
|
||||
&& run.getCer() != null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void triggerSegTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
UUID docA = UUID.randomUUID();
|
||||
List<TranscriptionBlock> segs = List.of(
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build(),
|
||||
TranscriptionBlock.builder().id(UUID.randomUUID()).documentId(docA).build()
|
||||
);
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(segs);
|
||||
when(segExportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.segtrainModel(any())).thenThrow(new RuntimeException("seg timeout"));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("blla").build();
|
||||
when(runRepository.save(any())).thenReturn(saved);
|
||||
|
||||
service.triggerSegTraining(null);
|
||||
|
||||
verify(runRepository, atLeastOnce()).save(argThat(run ->
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
|
||||
@@ -83,11 +83,11 @@ services:
|
||||
restart: unless-stopped
|
||||
expose:
|
||||
- "8000"
|
||||
mem_limit: 8g
|
||||
memswap_limit: 8g
|
||||
mem_limit: 12g
|
||||
memswap_limit: 12g
|
||||
volumes:
|
||||
- ocr_models:/app/models
|
||||
- ocr_cache:/root/.cache
|
||||
- ocr_cache:/root/.cache # Hugging Face / ketos model download cache — prevents re-downloads on container recreate
|
||||
environment:
|
||||
KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel
|
||||
TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
|
||||
@@ -102,7 +102,7 @@ services:
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 12
|
||||
start_period: 60s
|
||||
start_period: 120s
|
||||
|
||||
# --- Backend: Spring Boot ---
|
||||
backend:
|
||||
@@ -145,6 +145,7 @@ services:
|
||||
SPRING_MAIL_PROPERTIES_MAIL_SMTP_AUTH: ${MAIL_SMTP_AUTH:-false}
|
||||
SPRING_MAIL_PROPERTIES_MAIL_SMTP_STARTTLS_ENABLE: ${MAIL_STARTTLS_ENABLE:-false}
|
||||
APP_OCR_BASE_URL: http://ocr-service:8000
|
||||
APP_OCR_TRAINING_TOKEN: "${OCR_TRAINING_TOKEN:-}"
|
||||
ports:
|
||||
- "${PORT_BACKEND}:8080"
|
||||
networks:
|
||||
|
||||
@@ -554,6 +554,7 @@
|
||||
"training_history_col_cer": "Fehlerrate",
|
||||
"training_status_done": "Fertig",
|
||||
"training_status_failed": "Fehler",
|
||||
"training_error_detail_label": "Fehlerdetails",
|
||||
"training_status_running": "Läuft…",
|
||||
"training_seg_heading": "Segmentierung trainieren",
|
||||
"training_seg_description": "Starte ein neues Training mit annotierten Segmentierungsbereichen, um die Texterkennung zu verbessern.",
|
||||
|
||||
@@ -554,6 +554,7 @@
|
||||
"training_history_col_cer": "Error Rate",
|
||||
"training_status_done": "Done",
|
||||
"training_status_failed": "Failed",
|
||||
"training_error_detail_label": "Error details",
|
||||
"training_status_running": "Running…",
|
||||
"training_seg_heading": "Train segmentation",
|
||||
"training_seg_description": "Start a new training run using annotated segmentation regions to improve text detection.",
|
||||
|
||||
@@ -554,6 +554,7 @@
|
||||
"training_history_col_cer": "Tasa de error",
|
||||
"training_status_done": "Listo",
|
||||
"training_status_failed": "Error",
|
||||
"training_error_detail_label": "Detalles del error",
|
||||
"training_status_running": "Ejecutando…",
|
||||
"training_seg_heading": "Entrenar segmentación",
|
||||
"training_seg_description": "Inicia un nuevo entrenamiento con regiones de segmentación anotadas para mejorar la detección de texto.",
|
||||
|
||||
@@ -20,6 +20,12 @@ interface Props {
|
||||
|
||||
let { runs }: Props = $props();
|
||||
|
||||
const COLLAPSED_COUNT = 3;
|
||||
let expanded = $state(false);
|
||||
|
||||
const visibleRuns = $derived(expanded ? runs : runs.slice(0, COLLAPSED_COUNT));
|
||||
const hasMore = $derived(runs.length > COLLAPSED_COUNT);
|
||||
|
||||
const dateFormatter = new Intl.DateTimeFormat('de-DE', {
|
||||
day: 'numeric',
|
||||
month: 'short',
|
||||
@@ -46,7 +52,7 @@ function formatCer(cer: number | undefined | null): string {
|
||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tbody id="training-history-rows">
|
||||
{#if runs.length === 0}
|
||||
<tr>
|
||||
<td colspan="5" class="py-4 text-center text-sm text-ink-2">
|
||||
@@ -54,7 +60,7 @@ function formatCer(cer: number | undefined | null): string {
|
||||
</td>
|
||||
</tr>
|
||||
{:else}
|
||||
{#each runs as run (run.id)}
|
||||
{#each visibleRuns as run (run.id)}
|
||||
<tr class="border-b border-line/50 last:border-0">
|
||||
<td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td>
|
||||
<td class="py-2">
|
||||
@@ -79,7 +85,6 @@ function formatCer(cer: number | undefined | null): string {
|
||||
{:else if run.status === 'FAILED'}
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-red-100 px-1.5 py-0.5 text-xs font-medium text-red-700"
|
||||
title={run.errorMessage}
|
||||
>
|
||||
<svg
|
||||
aria-hidden="true"
|
||||
@@ -95,13 +100,21 @@ function formatCer(cer: number | undefined | null): string {
|
||||
</svg>
|
||||
{m.training_status_failed()}
|
||||
</span>
|
||||
{#if run.errorMessage}
|
||||
<details class="mt-0.5">
|
||||
<summary class="cursor-pointer text-xs text-red-700 underline">
|
||||
{m.training_error_detail_label()}
|
||||
</summary>
|
||||
<p class="mt-1 text-xs text-red-600">{run.errorMessage}</p>
|
||||
</details>
|
||||
{/if}
|
||||
{:else}
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-yellow-100 px-1.5 py-0.5 text-xs font-medium text-yellow-700"
|
||||
>
|
||||
<span
|
||||
aria-hidden="true"
|
||||
class="h-1.5 w-1.5 animate-pulse rounded-full bg-yellow-500"
|
||||
class="h-1.5 w-1.5 rounded-full bg-yellow-500 motion-safe:animate-pulse"
|
||||
></span>
|
||||
{m.training_status_running()}
|
||||
</span>
|
||||
@@ -117,3 +130,17 @@ function formatCer(cer: number | undefined | null): string {
|
||||
{/if}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{#if hasMore}
|
||||
<div class="mt-2 text-center">
|
||||
<button
|
||||
type="button"
|
||||
aria-expanded={expanded}
|
||||
aria-controls="training-history-rows"
|
||||
class="text-xs font-medium text-ink-3 transition-colors hover:text-ink"
|
||||
onclick={() => (expanded = !expanded)}
|
||||
>
|
||||
{expanded ? m.comp_expandable_show_less() : m.comp_expandable_show_more()}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
52
frontend/src/lib/components/TrainingHistory.svelte.spec.ts
Normal file
52
frontend/src/lib/components/TrainingHistory.svelte.spec.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import { afterEach, describe, expect, it } from 'vitest';
|
||||
import { cleanup, render } from 'vitest-browser-svelte';
|
||||
import { page } from 'vitest/browser';
|
||||
import TrainingHistory from './TrainingHistory.svelte';
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
function makeRun(i: number) {
|
||||
return {
|
||||
id: `run-${i}`,
|
||||
status: 'DONE' as const,
|
||||
blockCount: 10,
|
||||
documentCount: 2,
|
||||
modelName: 'german_kurrent',
|
||||
createdAt: `2026-01-0${i + 1}T12:00:00Z`,
|
||||
completedAt: `2026-01-0${i + 1}T12:05:00Z`
|
||||
};
|
||||
}
|
||||
|
||||
const fiveRuns = Array.from({ length: 5 }, (_, i) => makeRun(i));
|
||||
const twoRuns = Array.from({ length: 2 }, (_, i) => makeRun(i));
|
||||
|
||||
describe('TrainingHistory — expand/collapse', () => {
|
||||
it('shows only 3 runs initially when more than 3 exist', async () => {
|
||||
render(TrainingHistory, { runs: fiveRuns });
|
||||
|
||||
const rows = page.getByRole('row');
|
||||
// 1 header row + 3 data rows = 4 total
|
||||
await expect.element(rows.nth(3)).toBeInTheDocument();
|
||||
await expect.element(rows.nth(4)).not.toBeInTheDocument();
|
||||
|
||||
await expect.element(page.getByRole('button', { name: /Mehr anzeigen/i })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows all runs after clicking the expand button', async () => {
|
||||
render(TrainingHistory, { runs: fiveRuns });
|
||||
|
||||
await page.getByRole('button', { name: /Mehr anzeigen/i }).click();
|
||||
|
||||
const rows = page.getByRole('row');
|
||||
// 1 header row + 5 data rows = 6 total
|
||||
await expect.element(rows.nth(5)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('hides the toggle button when 3 or fewer runs exist', async () => {
|
||||
render(TrainingHistory, { runs: twoRuns });
|
||||
|
||||
await expect
|
||||
.element(page.getByRole('button', { name: /Mehr anzeigen/i }))
|
||||
.not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11-slim
|
||||
FROM python:3.11.9-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -21,6 +21,8 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN chmod +x /app/entrypoint.sh
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
CMD ["/app/entrypoint.sh"]
|
||||
|
||||
80
ocr-service/ensure_blla_model.py
Normal file
80
ocr-service/ensure_blla_model.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Validates the blla segmentation base model and downloads it if needed.
|
||||
|
||||
Run at container startup before uvicorn. ketos 7 requires the model in
|
||||
CoreML protobuf or safetensors format — legacy PyTorch ZIP archives
|
||||
(torch.save output from kraken <4) are not loadable and will be replaced.
|
||||
|
||||
Exits non-zero on failure so Docker marks the container unhealthy rather
|
||||
than silently starting with a broken model.
|
||||
"""
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(levelname)s:ensure_blla_model:%(message)s",
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
BLLA_MODEL_PATH = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")
|
||||
# DOI for "General segmentation model for print and handwriting" — ketos 7 compatible.
|
||||
BLLA_MODEL_DOI = "10.5281/zenodo.14602569"
|
||||
HTRMOPO_DIR = os.path.expanduser("~/.local/share/htrmopo")
|
||||
|
||||
|
||||
def _model_is_loadable(path: str) -> bool:
|
||||
try:
|
||||
from kraken.lib import vgsl
|
||||
|
||||
vgsl.TorchVGSLModel.load_model(path)
|
||||
return True
|
||||
except (RuntimeError, OSError, ValueError) as e:
|
||||
log.warning("Model at %s failed to load: %s", path, e)
|
||||
return False
|
||||
except Exception:
|
||||
log.debug("Unexpected error loading model at %s", path, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def _download_blla() -> str:
|
||||
log.info("Downloading blla model (DOI %s) ...", BLLA_MODEL_DOI)
|
||||
result = subprocess.run(
|
||||
["kraken", "get", BLLA_MODEL_DOI],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
log.error("kraken get failed: %s", result.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
candidates = sorted(glob.glob(os.path.join(HTRMOPO_DIR, "*/blla.mlmodel")))
|
||||
if not candidates:
|
||||
log.error("Downloaded blla.mlmodel not found under %s", HTRMOPO_DIR)
|
||||
sys.exit(1)
|
||||
|
||||
return candidates[-1]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if os.path.exists(BLLA_MODEL_PATH):
|
||||
if _model_is_loadable(BLLA_MODEL_PATH):
|
||||
log.info("blla model OK: %s", BLLA_MODEL_PATH)
|
||||
return
|
||||
log.warning(
|
||||
"blla model at %s is in an incompatible format — replacing", BLLA_MODEL_PATH
|
||||
)
|
||||
os.rename(BLLA_MODEL_PATH, BLLA_MODEL_PATH + ".incompatible")
|
||||
|
||||
os.makedirs(os.path.dirname(BLLA_MODEL_PATH), exist_ok=True)
|
||||
downloaded = _download_blla()
|
||||
shutil.copy2(downloaded, BLLA_MODEL_PATH)
|
||||
log.info("Installed blla model at %s", BLLA_MODEL_PATH)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
9
ocr-service/entrypoint.sh
Normal file
9
ocr-service/entrypoint.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Validate the blla segmentation base model and download it if missing or
|
||||
# incompatible. ketos 7 dropped support for legacy PyTorch ZIP archives —
|
||||
# this ensures the volume always holds a loadable CoreML protobuf model.
|
||||
python3 /app/ensure_blla_model.py
|
||||
|
||||
exec uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1
|
||||
@@ -472,16 +472,35 @@ async def segtrain_model(
|
||||
"-q", "fixed",
|
||||
"-N", "10",
|
||||
]
|
||||
# Train at 800px height. The default blla model uses 1800px, which peaks at
|
||||
# ~7+ GB on CPU and kills the host (ketos ignores -s when -i is present, so
|
||||
# we cannot override the height of an existing model).
|
||||
# Strategy: only use the base model if it is already at 800px (i.e. was
|
||||
# produced by a previous fine-tuning run here). Otherwise train from scratch —
|
||||
# the first run bootstraps a 800px model; all subsequent runs fine-tune it.
|
||||
seg_spec = (
|
||||
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
|
||||
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]"
|
||||
)
|
||||
use_base_model = False
|
||||
if os.path.exists(blla_model_path):
|
||||
cmd += ["-i", blla_model_path, "--resize", "both"]
|
||||
try:
|
||||
from kraken.lib import vgsl as _vgsl
|
||||
_m = _vgsl.TorchVGSLModel.load_model(blla_model_path)
|
||||
use_base_model = _m.input[2] == 800 # input is (batch, channels, H, W)
|
||||
if not use_base_model:
|
||||
log.info(
|
||||
"Base model height is %dpx — skipping -i to avoid OOM; "
|
||||
"will train from scratch at 800px",
|
||||
_m.input[2],
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("Could not inspect base model height, training from scratch: %s", exc)
|
||||
|
||||
if use_base_model:
|
||||
cmd += ["-i", blla_model_path, "--resize", "union", "-s", seg_spec]
|
||||
else:
|
||||
# No pretrained model — train from scratch with reduced height (800px)
|
||||
# to keep peak RAM under ~200 MB on CPU (default 1800px uses ~500 MB+)
|
||||
cmd += [
|
||||
"-s",
|
||||
"[1,800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 "
|
||||
"Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]",
|
||||
]
|
||||
cmd += ["-s", seg_spec]
|
||||
cmd += xml_files
|
||||
|
||||
log.info("Running: %s", " ".join(cmd[:5]) + " ...")
|
||||
@@ -493,7 +512,8 @@ async def segtrain_model(
|
||||
raise RuntimeError(f"ketos segtrain failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||
|
||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||
log.info("Segmentation training complete — epochs=%s accuracy=%s", epochs, accuracy)
|
||||
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||
log.info("Segmentation training complete — epochs=%s accuracy=%s cer=%s", epochs, accuracy, cer)
|
||||
|
||||
best_model = _find_best_model(checkpoint_dir)
|
||||
if best_model is None:
|
||||
@@ -508,7 +528,7 @@ async def segtrain_model(
|
||||
shutil.copy2(best_model, blla_model_path)
|
||||
log.info("Replaced blla model at %s", blla_model_path)
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
return result
|
||||
|
||||
69
ocr-service/test_ensure_blla_model.py
Normal file
69
ocr-service/test_ensure_blla_model.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Unit tests for ensure_blla_model.main()."""
|
||||
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import ensure_blla_model
|
||||
|
||||
|
||||
# ─── Model already loadable ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_main_returns_early_when_model_is_loadable():
|
||||
"""When the model exists and loads cleanly, no download or rename occurs."""
|
||||
with (
|
||||
patch("os.path.exists", return_value=True),
|
||||
patch.object(ensure_blla_model, "_model_is_loadable", return_value=True),
|
||||
patch.object(ensure_blla_model, "_download_blla") as mock_download,
|
||||
patch("os.rename") as mock_rename,
|
||||
):
|
||||
ensure_blla_model.main()
|
||||
|
||||
mock_download.assert_not_called()
|
||||
mock_rename.assert_not_called()
|
||||
|
||||
|
||||
# ─── Model exists but is incompatible ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_main_replaces_incompatible_model():
|
||||
"""An incompatible model is renamed and replaced with a fresh download."""
|
||||
fake_path = "/app/models/blla.mlmodel"
|
||||
downloaded_path = "/tmp/downloaded.mlmodel"
|
||||
|
||||
with (
|
||||
patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path),
|
||||
patch("os.path.exists", return_value=True),
|
||||
patch.object(ensure_blla_model, "_model_is_loadable", return_value=False),
|
||||
patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path),
|
||||
patch("os.rename") as mock_rename,
|
||||
patch("shutil.copy2") as mock_copy,
|
||||
patch("os.makedirs"),
|
||||
):
|
||||
ensure_blla_model.main()
|
||||
|
||||
mock_rename.assert_called_once_with(fake_path, fake_path + ".incompatible")
|
||||
mock_copy.assert_called_once_with(downloaded_path, fake_path)
|
||||
|
||||
|
||||
# ─── Model missing ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_main_downloads_when_model_missing():
|
||||
"""When the model file doesn't exist at all, it is downloaded without rename."""
|
||||
fake_path = "/app/models/blla.mlmodel"
|
||||
downloaded_path = "/tmp/downloaded.mlmodel"
|
||||
|
||||
with (
|
||||
patch.object(ensure_blla_model, "BLLA_MODEL_PATH", fake_path),
|
||||
patch("os.path.exists", return_value=False),
|
||||
patch.object(ensure_blla_model, "_model_is_loadable") as mock_loadable,
|
||||
patch.object(ensure_blla_model, "_download_blla", return_value=downloaded_path),
|
||||
patch("os.rename") as mock_rename,
|
||||
patch("shutil.copy2") as mock_copy,
|
||||
patch("os.makedirs"),
|
||||
):
|
||||
ensure_blla_model.main()
|
||||
|
||||
mock_loadable.assert_not_called()
|
||||
mock_rename.assert_not_called()
|
||||
mock_copy.assert_called_once_with(downloaded_path, fake_path)
|
||||
Reference in New Issue
Block a user