feat(training): track and display CER per training run

After each training run, the Character Error Rate (CER = 1 - accuracy),
loss, accuracy, and epoch count are now stored on the OcrTrainingRun
record, and the CER is shown in the training history table.

Also adds the missing POST /api/ocr/segtrain endpoint and the
triggerSegTraining service method so the segmentation training card
can actually trigger training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 19:01:10 +02:00
parent a99afef319
commit 22954f348a
12 changed files with 118 additions and 14 deletions

View File

@@ -120,6 +120,14 @@ public class OcrController {
return ocrTrainingService.triggerTraining(userId); return ocrTrainingService.triggerTraining(userId);
} }
/**
 * Starts a segmentation training run on behalf of the authenticated caller.
 *
 * <p>The caller's user id is resolved from the supplied authentication and handed to
 * {@code ocrTrainingService.triggerSegTraining}; whatever run that service returns is
 * sent back with HTTP status 201.
 *
 * @param authentication the current request's authentication
 * @return the training run produced by the service
 */
@PostMapping("/api/ocr/segtrain")
@ResponseStatus(HttpStatus.CREATED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerSegTraining(Authentication authentication) {
    return ocrTrainingService.triggerSegTraining(resolveUserId(authentication));
}
@GetMapping("/api/ocr/training-info") @GetMapping("/api/ocr/training-info")
@RequirePermission(Permission.ADMIN) @RequirePermission(Permission.ADMIN)
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() { public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {

View File

@@ -41,6 +41,18 @@ public class OcrTrainingRun {
@Schema(requiredMode = Schema.RequiredMode.REQUIRED) @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private String modelName; private String modelName;
// Training metrics copied from the OCR client's TrainingResult once a run completes.
// All four are boxed types, so they remain null when no value was stored.

/** Character error rate of the run; the OCR service derives it as 1 - accuracy. */
@Column(name = "cer")
private Double cer;
/** Loss value reported for the run. */
@Column(name = "loss")
private Double loss;
/** Accuracy value reported for the run. */
@Column(name = "accuracy")
private Double accuracy;
/** Number of epochs reported for the run. */
@Column(name = "epochs")
private Integer epochs;
@Column(name = "error_message") @Column(name = "error_message")
private String errorMessage; private String errorMessage;

View File

@@ -26,7 +26,7 @@ public interface OcrClient {
*/ */
TrainingResult trainModel(byte[] trainingDataZip); TrainingResult trainModel(byte[] trainingDataZip);
record TrainingResult(Double loss, Double accuracy, Integer epochs) {} record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {}
/** /**
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model. * Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.

View File

@@ -80,12 +80,16 @@ public class OcrTrainingService {
byte[] zipBytes = baos.toByteArray(); byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length); log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
ocrClient.trainModel(zipBytes); OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
run.setStatus(TrainingStatus.DONE); run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now()); run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
run = trainingRunRepository.save(run); run = trainingRunRepository.save(run);
log.info("[trainingRun={}] Training completed successfully", runId); log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
} catch (Exception e) { } catch (Exception e) {
run.setStatus(TrainingStatus.FAILED); run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage()); run.setErrorMessage(e.getMessage());
@@ -99,6 +103,66 @@ public class OcrTrainingService {
return run; return run;
} }
/**
 * Starts a segmentation ("blla") training run and records its outcome.
 *
 * <p>The request is rejected when a run with status {@code RUNNING} already exists, or when
 * fewer than five segmentation blocks are available. Otherwise a {@code RUNNING} run is
 * persisted, the segmentation export ZIP is sent to the OCR client, and the run is updated
 * to {@code DONE} with the returned loss, accuracy and epoch count. No CER is stored by
 * this method.
 *
 * <p>Any exception raised while exporting or training is logged and stored on the run as
 * {@code FAILED} (with the exception message) instead of being rethrown, so callers have to
 * inspect the status of the returned run.
 *
 * @param triggeredBy id of the user who requested the run
 * @return the saved run after the attempt, either {@code DONE} or {@code FAILED}
 */
@Transactional
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
    boolean anotherRunActive =
            trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent();
    if (anotherRunActive) {
        throw DomainException.conflict(
                ErrorCode.TRAINING_ALREADY_RUNNING, "A training run is already in progress");
    }

    var segments = segmentationTrainingExportService.querySegmentationBlocks();
    int segmentCount = segments.size();
    if (segmentCount < 5) {
        String message = "At least 5 eligible segments are required to start training (found "
                + segmentCount + ")";
        throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, message);
    }

    long distinctDocuments = segments.stream()
            .map(segment -> segment.getDocumentId())
            .distinct()
            .count();

    // Persist the run as RUNNING first so it has an id to log against.
    OcrTrainingRun run = trainingRunRepository.save(
            OcrTrainingRun.builder()
                    .status(TrainingStatus.RUNNING)
                    .blockCount(segmentCount)
                    .documentCount((int) distinctDocuments)
                    .modelName("blla")
                    .triggeredBy(triggeredBy)
                    .build());

    String runId = run.getId().toString();
    MDC.put("trainingRunId", runId);
    log.info("Started segmentation training run {} with {} segments from {} documents",
            runId, segmentCount, distinctDocuments);

    try {
        ByteArrayOutputStream zipBuffer = new ByteArrayOutputStream();
        segmentationTrainingExportService.exportToZip().writeTo(zipBuffer);
        byte[] payload = zipBuffer.toByteArray();
        log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, payload.length);

        OcrClient.TrainingResult outcome = ocrClient.segtrainModel(payload);
        run.setStatus(TrainingStatus.DONE);
        run.setCompletedAt(Instant.now());
        run.setLoss(outcome.loss());
        run.setAccuracy(outcome.accuracy());
        run.setEpochs(outcome.epochs());
        run = trainingRunRepository.save(run);
        log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, outcome.epochs());
    } catch (Exception failure) {
        // Failures are recorded on the run rather than propagated to the caller.
        run.setStatus(TrainingStatus.FAILED);
        run.setErrorMessage(failure.getMessage());
        run.setCompletedAt(Instant.now());
        run = trainingRunRepository.save(run);
        log.error("[trainingRun={}] Segmentation training failed: {}", runId, failure.getMessage(), failure);
    } finally {
        MDC.remove("trainingRunId");
    }
    return run;
}
public TrainingInfoResponse getTrainingInfo() { public TrainingInfoResponse getTrainingInfo() {
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
int availableDocuments = (int) eligibleBlocks.stream() int availableDocuments = (int) eligibleBlocks.stream()

View File

@@ -127,8 +127,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.retrieve() .retrieve()
.body(TrainingResultJson.class); .body(TrainingResultJson.class);
if (result == null) return new OcrClient.TrainingResult(null, null, null); if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs()); return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
} }
@Override @Override
@@ -156,8 +156,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.retrieve() .retrieve()
.body(TrainingResultJson.class); .body(TrainingResultJson.class);
if (result == null) return new OcrClient.TrainingResult(null, null, null); if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs()); return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
} }
@Override @Override
@@ -255,7 +255,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
} }
} }
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {} record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {}
record OcrBlockJson( record OcrBlockJson(
@JsonProperty("pageNumber") int pageNumber, @JsonProperty("pageNumber") int pageNumber,

View File

@@ -0,0 +1,5 @@
-- Adds per-run training metric columns to ocr_training_runs.
-- No NOT NULL constraint or default is declared, so every new column is nullable.
ALTER TABLE ocr_training_runs
-- character error rate of the run
ADD COLUMN cer DOUBLE PRECISION,
-- loss value reported for the run
ADD COLUMN loss DOUBLE PRECISION,
-- accuracy value reported for the run
ADD COLUMN accuracy DOUBLE PRECISION,
-- number of epochs reported for the run
ADD COLUMN epochs INT;

View File

@@ -93,7 +93,7 @@ class OcrTrainingServiceTest {
); );
when(exportService.queryEligibleBlocks()).thenReturn(blocks); when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {}); when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3)); when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
OcrTrainingRun saved = OcrTrainingRun.builder() OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING) .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)

View File

@@ -550,6 +550,7 @@
"training_history_col_status": "Status", "training_history_col_status": "Status",
"training_history_col_blocks": "Blöcke", "training_history_col_blocks": "Blöcke",
"training_history_col_docs": "Dokumente", "training_history_col_docs": "Dokumente",
"training_history_col_cer": "Fehlerrate",
"training_status_done": "Fertig", "training_status_done": "Fertig",
"training_status_failed": "Fehler", "training_status_failed": "Fehler",
"training_status_running": "Läuft…", "training_status_running": "Läuft…",

View File

@@ -550,6 +550,7 @@
"training_history_col_status": "Status", "training_history_col_status": "Status",
"training_history_col_blocks": "Blocks", "training_history_col_blocks": "Blocks",
"training_history_col_docs": "Documents", "training_history_col_docs": "Documents",
"training_history_col_cer": "Error Rate",
"training_status_done": "Done", "training_status_done": "Done",
"training_status_failed": "Failed", "training_status_failed": "Failed",
"training_status_running": "Running…", "training_status_running": "Running…",

View File

@@ -550,6 +550,7 @@
"training_history_col_status": "Estado", "training_history_col_status": "Estado",
"training_history_col_blocks": "Bloques", "training_history_col_blocks": "Bloques",
"training_history_col_docs": "Documentos", "training_history_col_docs": "Documentos",
"training_history_col_cer": "Tasa de error",
"training_status_done": "Listo", "training_status_done": "Listo",
"training_status_failed": "Error", "training_status_failed": "Error",
"training_status_running": "Ejecutando…", "training_status_running": "Ejecutando…",

View File

@@ -11,6 +11,7 @@ interface Run {
triggeredBy?: string; triggeredBy?: string;
createdAt: string; createdAt: string;
completedAt?: string; completedAt?: string;
cer?: number;
} }
interface Props { interface Props {
@@ -28,6 +29,11 @@ const dateFormatter = new Intl.DateTimeFormat('de-DE', {
function formatDate(iso: string): string { function formatDate(iso: string): string {
return dateFormatter.format(new Date(iso)); return dateFormatter.format(new Date(iso));
} }
/**
 * Renders a CER fraction as a percentage with one decimal place (e.g. 0.05 -> "5.0 %").
 * Returns an em dash when the value is null or undefined.
 */
function formatCer(cer: number | undefined | null): string {
	return cer == null ? '—' : `${(cer * 100).toFixed(1)} %`;
}
</script> </script>
<table class="w-full text-sm"> <table class="w-full text-sm">
@@ -37,12 +43,13 @@ function formatDate(iso: string): string {
<th class="pb-2 text-left">{m.training_history_col_status()}</th> <th class="pb-2 text-left">{m.training_history_col_status()}</th>
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th> <th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th> <th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
{#if runs.length === 0} {#if runs.length === 0}
<tr> <tr>
<td colspan="4" class="py-4 text-center text-sm text-ink-2"> <td colspan="5" class="py-4 text-center text-sm text-ink-2">
{m.training_history_empty()} {m.training_history_empty()}
</td> </td>
</tr> </tr>
@@ -71,6 +78,9 @@ function formatDate(iso: string): string {
</td> </td>
<td class="py-2 text-right text-ink-2">{run.blockCount}</td> <td class="py-2 text-right text-ink-2">{run.blockCount}</td>
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td> <td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
<td class="hidden py-2 text-right md:table-cell"
>{run.status === 'DONE' && run.cer != null ? formatCer(run.cer) : '—'}</td
>
</tr> </tr>
{/each} {/each}
{/if} {/if}

View File

@@ -333,8 +333,10 @@ async def train_model(
epochs = getattr(result, "epochs", None) or 0 epochs = getattr(result, "epochs", None) or 0
loss = getattr(result, "best_loss", None) loss = getattr(result, "best_loss", None)
accuracy = getattr(result, "best_accuracy", None) accuracy = getattr(result, "best_accuracy", None)
# CER = 1 - character accuracy (Kraken reports character-level accuracy)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy) log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)
# Backup existing model and replace # Backup existing model and replace
if os.path.exists(KRAKEN_MODEL_PATH): if os.path.exists(KRAKEN_MODEL_PATH):
@@ -351,7 +353,7 @@ async def train_model(
kraken_engine.load_models() kraken_engine.load_models()
log.info("Reloaded Kraken model in-process") log.info("Reloaded Kraken model in-process")
return {"loss": loss, "accuracy": accuracy, "epochs": epochs} return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_training) result = await asyncio.to_thread(_run_training)
return result return result
@@ -409,7 +411,7 @@ async def segtrain_model(
loss = getattr(result, "best_loss", None) loss = getattr(result, "best_loss", None)
accuracy = getattr(result, "best_accuracy", None) accuracy = getattr(result, "best_accuracy", None)
log.info("Segmentation training complete — epochs=%s loss=%s", epochs, loss) log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
if os.path.exists(blla_model_path): if os.path.exists(blla_model_path):
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
@@ -420,7 +422,7 @@ async def segtrain_model(
shutil.move(output_model_path, blla_model_path) shutil.move(output_model_path, blla_model_path)
log.info("Replaced blla model at %s", blla_model_path) log.info("Replaced blla model at %s", blla_model_path)
return {"loss": loss, "accuracy": accuracy, "epochs": epochs} return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}
result = await asyncio.to_thread(_run_segtrain) result = await asyncio.to_thread(_run_segtrain)
return result return result