feat(training): track and display CER per training run
After each training run, the Character Error Rate (CER = 1 - accuracy), loss, accuracy, and epoch count are now stored on the OcrTrainingRun record and shown in the training history table. Also adds the missing POST /api/ocr/segtrain endpoint and the triggerSegTraining service method so the segmentation training card can actually trigger training. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -120,6 +120,14 @@ public class OcrController {
|
||||
return ocrTrainingService.triggerTraining(userId);
|
||||
}
|
||||
|
||||
@PostMapping("/api/ocr/segtrain")
|
||||
@ResponseStatus(HttpStatus.CREATED)
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingRun triggerSegTraining(Authentication authentication) {
|
||||
UUID userId = resolveUserId(authentication);
|
||||
return ocrTrainingService.triggerSegTraining(userId);
|
||||
}
|
||||
|
||||
@GetMapping("/api/ocr/training-info")
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
|
||||
|
||||
@@ -41,6 +41,18 @@ public class OcrTrainingRun {
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private String modelName;
|
||||
|
||||
// Character error rate of the run. The OCR service derives it as 1 - character accuracy;
// stays null when the service reports no value.
@Column(name = "cer")
|
||||
private Double cer;
|
||||
|
||||
// Loss value reported by the OCR service for this run (nullable).
@Column(name = "loss")
|
||||
private Double loss;
|
||||
|
||||
// Accuracy value reported by the OCR service for this run (nullable).
@Column(name = "accuracy")
|
||||
private Double accuracy;
|
||||
|
||||
// Number of epochs reported by the OCR service for this run (nullable).
@Column(name = "epochs")
|
||||
private Integer epochs;
|
||||
|
||||
@Column(name = "error_message")
|
||||
private String errorMessage;
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ public interface OcrClient {
|
||||
*/
|
||||
TrainingResult trainModel(byte[] trainingDataZip);
|
||||
|
||||
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
|
||||
/**
 * Metrics returned by a training call: loss, accuracy, character error rate (cer) and epoch
 * count. Every component may be {@code null} when the OCR service did not report it.
 */
record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||
|
||||
/**
|
||||
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.
|
||||
|
||||
@@ -80,12 +80,16 @@ public class OcrTrainingService {
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||
ocrClient.trainModel(zipBytes);
|
||||
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed successfully", runId);
|
||||
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
@@ -99,6 +103,66 @@ public class OcrTrainingService {
|
||||
return run;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||
if (segBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
|
||||
}
|
||||
|
||||
long documentCount = segBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(segBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("blla")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
run = trainingRunRepository.save(run);
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started segmentation training run {} with {} segments from {} documents",
|
||||
runId, segBlocks.size(), documentCount);
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
segmentationTrainingExportService.exportToZip().writeTo(baos);
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
public TrainingInfoResponse getTrainingInfo() {
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
int availableDocuments = (int) eligibleBlocks.stream()
|
||||
|
||||
@@ -127,8 +127,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -156,8 +156,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -255,7 +255,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
}
|
||||
}
|
||||
|
||||
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
|
||||
/**
 * Shape of the JSON response body read from the OCR service after a training call; it is
 * mapped field by field onto {@code OcrClient.TrainingResult}.
 */
record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||
|
||||
record OcrBlockJson(
|
||||
@JsonProperty("pageNumber") int pageNumber,
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Adds the per-run training metrics (cer, loss, accuracy, epochs) to ocr_training_runs.
-- All four columns are added without NOT NULL or a default, so existing rows keep NULL.
ALTER TABLE ocr_training_runs
|
||||
ADD COLUMN cer DOUBLE PRECISION,
|
||||
ADD COLUMN loss DOUBLE PRECISION,
|
||||
ADD COLUMN accuracy DOUBLE PRECISION,
|
||||
ADD COLUMN epochs INT;
|
||||
@@ -93,7 +93,7 @@ class OcrTrainingServiceTest {
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
|
||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
|
||||
@@ -550,6 +550,7 @@
|
||||
"training_history_col_status": "Status",
|
||||
"training_history_col_blocks": "Blöcke",
|
||||
"training_history_col_docs": "Dokumente",
|
||||
"training_history_col_cer": "Fehlerrate",
|
||||
"training_status_done": "Fertig",
|
||||
"training_status_failed": "Fehler",
|
||||
"training_status_running": "Läuft…",
|
||||
|
||||
@@ -550,6 +550,7 @@
|
||||
"training_history_col_status": "Status",
|
||||
"training_history_col_blocks": "Blocks",
|
||||
"training_history_col_docs": "Documents",
|
||||
"training_history_col_cer": "Error Rate",
|
||||
"training_status_done": "Done",
|
||||
"training_status_failed": "Failed",
|
||||
"training_status_running": "Running…",
|
||||
|
||||
@@ -550,6 +550,7 @@
|
||||
"training_history_col_status": "Estado",
|
||||
"training_history_col_blocks": "Bloques",
|
||||
"training_history_col_docs": "Documentos",
|
||||
"training_history_col_cer": "Tasa de error",
|
||||
"training_status_done": "Listo",
|
||||
"training_status_failed": "Error",
|
||||
"training_status_running": "Ejecutando…",
|
||||
|
||||
@@ -11,6 +11,7 @@ interface Run {
|
||||
triggeredBy?: string;
|
||||
createdAt: string;
|
||||
completedAt?: string;
|
||||
cer?: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
@@ -28,6 +29,11 @@ const dateFormatter = new Intl.DateTimeFormat('de-DE', {
|
||||
function formatDate(iso: string): string {
|
||||
return dateFormatter.format(new Date(iso));
|
||||
}
|
||||
|
||||
/**
 * Renders a character error rate given as a fraction (e.g. 0.05) as a percentage
 * with one decimal place, or an em dash when no value is available.
 */
function formatCer(cer: number | undefined | null): string {
	// Loose equality on purpose: matches both null and undefined.
	if (cer == null) {
		return '—';
	}
	const percent = cer * 100;
	return `${percent.toFixed(1)} %`;
}
|
||||
</script>
|
||||
|
||||
<table class="w-full text-sm">
|
||||
@@ -37,12 +43,13 @@ function formatDate(iso: string): string {
|
||||
<th class="pb-2 text-left">{m.training_history_col_status()}</th>
|
||||
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
|
||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
|
||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{#if runs.length === 0}
|
||||
<tr>
|
||||
<td colspan="4" class="py-4 text-center text-sm text-ink-2">
|
||||
<td colspan="5" class="py-4 text-center text-sm text-ink-2">
|
||||
{m.training_history_empty()}
|
||||
</td>
|
||||
</tr>
|
||||
@@ -71,6 +78,9 @@ function formatDate(iso: string): string {
|
||||
</td>
|
||||
<td class="py-2 text-right text-ink-2">{run.blockCount}</td>
|
||||
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
|
||||
<td class="hidden py-2 text-right md:table-cell"
|
||||
>{run.status === 'DONE' && run.cer != null ? formatCer(run.cer) : '—'}</td
|
||||
>
|
||||
</tr>
|
||||
{/each}
|
||||
{/if}
|
||||
|
||||
@@ -333,8 +333,10 @@ async def train_model(
|
||||
epochs = getattr(result, "epochs", None) or 0
|
||||
loss = getattr(result, "best_loss", None)
|
||||
accuracy = getattr(result, "best_accuracy", None)
|
||||
# CER = 1 - character accuracy (Kraken reports character-level accuracy)
|
||||
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||
|
||||
log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
|
||||
log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)
|
||||
|
||||
# Backup existing model and replace
|
||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||
@@ -351,7 +353,7 @@ async def train_model(
|
||||
kraken_engine.load_models()
|
||||
log.info("Reloaded Kraken model in-process")
|
||||
|
||||
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
|
||||
return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_training)
|
||||
return result
|
||||
@@ -409,7 +411,7 @@ async def segtrain_model(
|
||||
loss = getattr(result, "best_loss", None)
|
||||
accuracy = getattr(result, "best_accuracy", None)
|
||||
|
||||
log.info("Segmentation training complete — epochs=%s loss=%s", epochs, loss)
|
||||
log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
|
||||
|
||||
if os.path.exists(blla_model_path):
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
@@ -420,7 +422,7 @@ async def segtrain_model(
|
||||
shutil.move(output_model_path, blla_model_path)
|
||||
log.info("Replaced blla model at %s", blla_model_path)
|
||||
|
||||
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
|
||||
return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user