diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java index 1253b317..4ada18e4 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -120,6 +120,14 @@ public class OcrController { return ocrTrainingService.triggerTraining(userId); } + /** Starts a segmentation training run: resolves the caller's user id from the Authentication and delegates to OcrTrainingService#triggerSegTraining. Requires the ADMIN permission and responds with 201 Created and the resulting run. NOTE(review): the service records failures on the run instead of throwing, so 201 is also returned for a run that ended FAILED; confirm this is intended. */ @PostMapping("/api/ocr/segtrain") + @ResponseStatus(HttpStatus.CREATED) + @RequirePermission(Permission.ADMIN) + public OcrTrainingRun triggerSegTraining(Authentication authentication) { + UUID userId = resolveUserId(authentication); + return ocrTrainingService.triggerSegTraining(userId); + } + @GetMapping("/api/ocr/training-info") @RequirePermission(Permission.ADMIN) public OcrTrainingService.TrainingInfoResponse getTrainingInfo() { diff --git a/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java index add2a2a7..e5db5231 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/model/OcrTrainingRun.java @@ -41,6 +41,18 @@ public class OcrTrainingRun { @Schema(requiredMode = Schema.RequiredMode.REQUIRED) private String modelName; + /** Character error rate reported by the OCR service for this run; null when none was reported. */ @Column(name = "cer") + private Double cer; + + /** Loss reported by the OCR service for this run; null when none was reported. */ @Column(name = "loss") + private Double loss; + + /** Accuracy reported by the OCR service for this run; null when none was reported. */ @Column(name = "accuracy") + private Double accuracy; + + /** Number of epochs reported by the OCR service for this run; null when none was reported. */ @Column(name = "epochs") + private Integer epochs; + @Column(name = "error_message") private String errorMessage; diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java index 384eefe7..45c2f021 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java +++ 
b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrClient.java @@ -26,7 +26,7 @@ public interface OcrClient { */ TrainingResult trainModel(byte[] trainingDataZip); - record TrainingResult(Double loss, Double accuracy, Integer epochs) {} + /** Metrics returned by a training call. Every component is a nullable boxed value; the REST client returns all-null when it receives no response body. */ record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {} /** * Send a segmentation training ZIP to the OCR service for fine-tuning the blla model. diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index 04844fd0..4b688119 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -80,12 +80,16 @@ public class OcrTrainingService { byte[] zipBytes = baos.toByteArray(); log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length); - ocrClient.trainModel(zipBytes); + OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes); run.setStatus(TrainingStatus.DONE); run.setCompletedAt(Instant.now()); + /* store the metrics returned by the OCR service on the run before it is saved */ run.setCer(result.cer()); + run.setLoss(result.loss()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); run = trainingRunRepository.save(run); - log.info("[trainingRun={}] Training completed successfully", runId); + log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs()); } catch (Exception e) { run.setStatus(TrainingStatus.FAILED); run.setErrorMessage(e.getMessage()); @@ -99,6 +103,66 @@ public class OcrTrainingService { return run; } + /** Runs a segmentation ("blla") training in the calling thread and returns the saved run in its final state. Steps: reject with DomainException.conflict(TRAINING_ALREADY_RUNNING) if a run with status RUNNING exists; reject with DomainException.badRequest(VALIDATION_ERROR) if fewer than 5 segmentation blocks are available; save a RUNNING run; export the training ZIP into memory; send it to the OCR service; save the run as DONE with loss/accuracy/epochs, or as FAILED with the exception message. @param triggeredBy value stored on the run as triggeredBy (the controller passes the caller's user id). @return the saved run, with status DONE or FAILED. NOTE(review): the remote training call runs inside this @Transactional method; depending on isolation, the RUNNING row may not be visible to a concurrent request until commit, so two requests could both pass the guard. Confirm. */ @Transactional + public OcrTrainingRun triggerSegTraining(UUID triggeredBy) { + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, + "A training run is already in progress"); + } + + var segBlocks = 
segmentationTrainingExportService.querySegmentationBlocks(); + if (segBlocks.size() < 5) { + throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, + "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")"); + } + + /* number of distinct document ids among the segments; stored on the run as documentCount */ long documentCount = segBlocks.stream() + .map(b -> b.getDocumentId()) + .distinct() + .count(); + + OcrTrainingRun run = OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .blockCount(segBlocks.size()) + .documentCount((int) documentCount) + .modelName("blla") + .triggeredBy(triggeredBy) + .build(); + run = trainingRunRepository.save(run); + + String runId = run.getId().toString(); + /* put the run id into the logging context; it is removed again in the finally block */ MDC.put("trainingRunId", runId); + log.info("Started segmentation training run {} with {} segments from {} documents", + runId, segBlocks.size(), documentCount); + + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + segmentationTrainingExportService.exportToZip().writeTo(baos); + byte[] zipBytes = baos.toByteArray(); + + log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length); + OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes); + + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + /* cer is not copied for segmentation runs; only loss, accuracy and epochs are stored */ run.setLoss(result.loss()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + run = trainingRunRepository.save(run); + log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs()); + } catch (Exception e) { + /* any failure is recorded on the run and logged, not rethrown: the caller receives the FAILED run */ run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + run = trainingRunRepository.save(run); + log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e); + } finally { + MDC.remove("trainingRunId"); + } + + return run; + } + public TrainingInfoResponse getTrainingInfo() { var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); int availableDocuments = (int) 
eligibleBlocks.stream() diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java index e6087ab5..ddec08a8 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/RestClientOcrClient.java @@ -127,8 +127,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { .retrieve() .body(TrainingResultJson.class); - if (result == null) return new OcrClient.TrainingResult(null, null, null); - return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs()); + /* no body was deserialized: return a result whose metrics are all null */ if (result == null) return new OcrClient.TrainingResult(null, null, null, null); + return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs()); } @Override @@ -156,8 +156,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { .retrieve() .body(TrainingResultJson.class); - if (result == null) return new OcrClient.TrainingResult(null, null, null); - return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs()); + /* no body was deserialized: return a result whose metrics are all null */ if (result == null) return new OcrClient.TrainingResult(null, null, null, null); + return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs()); } @Override @@ -255,7 +255,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient { } } - record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {} + /** Deserialization target for the OCR service's training response; its components are copied one-to-one into OcrClient.TrainingResult. */ record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {} record OcrBlockJson( @JsonProperty("pageNumber") int pageNumber, diff --git a/backend/src/main/resources/db/migration/V32__add_training_metrics.sql b/backend/src/main/resources/db/migration/V32__add_training_metrics.sql new file mode 100644 index 00000000..d510a079 --- /dev/null +++ 
b/backend/src/main/resources/db/migration/V32__add_training_metrics.sql @@ -0,0 +1,5 @@ +/* Adds the training metric columns to ocr_training_runs. None is declared NOT NULL and none has a default, so rows that already exist hold NULL in all four. */ ALTER TABLE ocr_training_runs + ADD COLUMN cer DOUBLE PRECISION, + ADD COLUMN loss DOUBLE PRECISION, + ADD COLUMN accuracy DOUBLE PRECISION, + ADD COLUMN epochs INT; diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index 5dcf1fcb..f94e6e8f 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -93,7 +93,7 @@ class OcrTrainingServiceTest { ); when(exportService.queryEligibleBlocks()).thenReturn(blocks); when(exportService.exportToZip()).thenReturn(out -> {}); - when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3)); + /* TrainingResult arguments in order: loss, accuracy, cer, epochs */ when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3)); OcrTrainingRun saved = OcrTrainingRun.builder() .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) diff --git a/frontend/messages/de.json b/frontend/messages/de.json index 1a7912f1..52a82271 100644 --- a/frontend/messages/de.json +++ b/frontend/messages/de.json @@ -550,6 +550,7 @@ "training_history_col_status": "Status", "training_history_col_blocks": "Blöcke", "training_history_col_docs": "Dokumente", + "training_history_col_cer": "Fehlerrate", "training_status_done": "Fertig", "training_status_failed": "Fehler", "training_status_running": "Läuft…", diff --git a/frontend/messages/en.json b/frontend/messages/en.json index 58633166..b8670f73 100644 --- a/frontend/messages/en.json +++ b/frontend/messages/en.json @@ -550,6 +550,7 @@ "training_history_col_status": "Status", "training_history_col_blocks": "Blocks", "training_history_col_docs": "Documents", + "training_history_col_cer": "Error Rate", "training_status_done": "Done", "training_status_failed": "Failed", 
"training_status_running": "Running…", diff --git a/frontend/messages/es.json b/frontend/messages/es.json index 74051eaa..f6611cef 100644 --- a/frontend/messages/es.json +++ b/frontend/messages/es.json @@ -550,6 +550,7 @@ "training_history_col_status": "Estado", "training_history_col_blocks": "Bloques", "training_history_col_docs": "Documentos", + "training_history_col_cer": "Tasa de error", "training_status_done": "Listo", "training_status_failed": "Error", "training_status_running": "Ejecutando…", diff --git a/frontend/src/lib/components/TrainingHistory.svelte b/frontend/src/lib/components/TrainingHistory.svelte index 1260c15d..f8b23f49 100644 --- a/frontend/src/lib/components/TrainingHistory.svelte +++ b/frontend/src/lib/components/TrainingHistory.svelte @@ -11,6 +11,7 @@ interface Run { triggeredBy?: string; createdAt: string; completedAt?: string; + /** Character error rate of the run; formatCer multiplies it by 100 for display. Optional: absent when the run has none. */ cer?: number; } interface Props { @@ -28,6 +29,11 @@ const dateFormatter = new Intl.DateTimeFormat('de-DE', { function formatDate(iso: string): string { return dateFormatter.format(new Date(iso)); } + +/** Formats a character error rate as a percentage with one decimal place, e.g. 0.05 becomes "5.0 %". Returns "—" when the value is null or undefined. */ function formatCer(cer: number | undefined | null): string { + if (cer == null) return '—'; + return (cer * 100).toFixed(1) + ' %'; +}
| {m.training_history_col_status()} | {m.training_history_col_blocks()} | {m.training_history_col_docs()} | +{m.training_history_col_cer()} | {#if runs.length === 0}|||||
|---|---|---|---|---|---|---|---|---|
| + | {m.training_history_empty()} | {run.blockCount} | {run.documentCount} | +{run.status === 'DONE' && run.cer != null ? formatCer(run.cer) : '—'} | {/each} {/if} diff --git a/ocr-service/main.py b/ocr-service/main.py index 5e820401..0484e2b1 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -333,8 +333,10 @@ async def train_model( epochs = getattr(result, "epochs", None) or 0 loss = getattr(result, "best_loss", None) accuracy = getattr(result, "best_accuracy", None) + # CER = 1 - character accuracy (Kraken reports character-level accuracy), rounded to 4 decimals; None when no accuracy was reported. NOTE(review): assumes best_accuracy is a fraction in [0, 1], not a percentage; confirm. + cer = round(1.0 - accuracy, 4) if accuracy is not None else None - log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy) + log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer) # Backup existing model and replace if os.path.exists(KRAKEN_MODEL_PATH): @@ -351,7 +353,7 @@ async def train_model( kraken_engine.load_models() log.info("Reloaded Kraken model in-process") - return {"loss": loss, "accuracy": accuracy, "epochs": epochs} + return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs} result = await asyncio.to_thread(_run_training) return result @@ -409,7 +411,7 @@ async def segtrain_model( loss = getattr(result, "best_loss", None) accuracy = getattr(result, "best_accuracy", None) - log.info("Segmentation training complete — epochs=%s loss=%s", epochs, loss) + log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy) if os.path.exists(blla_model_path): timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") @@ -420,7 +422,7 @@ async def segtrain_model( shutil.move(output_model_path, blla_model_path) log.info("Replaced blla model at %s", blla_model_path) - return {"loss": loss, "accuracy": accuracy, "epochs": epochs} + return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs} result = await asyncio.to_thread(_run_segtrain) return result||||