feat(training): track and display CER per training run
After each training run, the Character Error Rate (CER = 1 - accuracy), loss, accuracy, and epoch count are now stored on the OcrTrainingRun record and shown in the training history table. Also adds the missing POST /api/ocr/segtrain endpoint and the triggerSegTraining service method so the segmentation training card can actually trigger training. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -120,6 +120,14 @@ public class OcrController {
|
|||||||
return ocrTrainingService.triggerTraining(userId);
|
return ocrTrainingService.triggerTraining(userId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@PostMapping("/api/ocr/segtrain")
|
||||||
|
@ResponseStatus(HttpStatus.CREATED)
|
||||||
|
@RequirePermission(Permission.ADMIN)
|
||||||
|
public OcrTrainingRun triggerSegTraining(Authentication authentication) {
|
||||||
|
UUID userId = resolveUserId(authentication);
|
||||||
|
return ocrTrainingService.triggerSegTraining(userId);
|
||||||
|
}
|
||||||
|
|
||||||
@GetMapping("/api/ocr/training-info")
|
@GetMapping("/api/ocr/training-info")
|
||||||
@RequirePermission(Permission.ADMIN)
|
@RequirePermission(Permission.ADMIN)
|
||||||
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
|
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
|
||||||
|
|||||||
@@ -41,6 +41,18 @@ public class OcrTrainingRun {
|
|||||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||||
private String modelName;
|
private String modelName;
|
||||||
|
|
||||||
|
/**
 * Character error rate reported for this run; stays {@code null} when the training result
 * carries no CER. Presumably a fraction in [0, 1] (the UI multiplies it by 100) — confirm.
 */
@Column(name = "cer")
|
||||||
|
private Double cer;
|
||||||
|
|
||||||
|
/** Loss value taken from the training result; {@code null} when none was reported. */
@Column(name = "loss")
|
||||||
|
private Double loss;
|
||||||
|
|
||||||
|
/** Accuracy value taken from the training result; {@code null} when none was reported. */
@Column(name = "accuracy")
|
||||||
|
private Double accuracy;
|
||||||
|
|
||||||
|
/** Epoch count taken from the training result; {@code null} when none was reported. */
@Column(name = "epochs")
|
||||||
|
private Integer epochs;
|
||||||
|
|
||||||
@Column(name = "error_message")
|
@Column(name = "error_message")
|
||||||
private String errorMessage;
|
private String errorMessage;
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ public interface OcrClient {
|
|||||||
*/
|
*/
|
||||||
TrainingResult trainModel(byte[] trainingDataZip);
|
TrainingResult trainModel(byte[] trainingDataZip);
|
||||||
|
|
||||||
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
|
record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.
|
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.
|
||||||
|
|||||||
@@ -80,12 +80,16 @@ public class OcrTrainingService {
|
|||||||
byte[] zipBytes = baos.toByteArray();
|
byte[] zipBytes = baos.toByteArray();
|
||||||
|
|
||||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||||
ocrClient.trainModel(zipBytes);
|
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
|
||||||
|
|
||||||
run.setStatus(TrainingStatus.DONE);
|
run.setStatus(TrainingStatus.DONE);
|
||||||
run.setCompletedAt(Instant.now());
|
run.setCompletedAt(Instant.now());
|
||||||
|
run.setCer(result.cer());
|
||||||
|
run.setLoss(result.loss());
|
||||||
|
run.setAccuracy(result.accuracy());
|
||||||
|
run.setEpochs(result.epochs());
|
||||||
run = trainingRunRepository.save(run);
|
run = trainingRunRepository.save(run);
|
||||||
log.info("[trainingRun={}] Training completed successfully", runId);
|
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
run.setStatus(TrainingStatus.FAILED);
|
run.setStatus(TrainingStatus.FAILED);
|
||||||
run.setErrorMessage(e.getMessage());
|
run.setErrorMessage(e.getMessage());
|
||||||
@@ -99,6 +103,66 @@ public class OcrTrainingService {
|
|||||||
return run;
|
return run;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
||||||
|
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||||
|
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||||
|
"A training run is already in progress");
|
||||||
|
}
|
||||||
|
|
||||||
|
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||||
|
if (segBlocks.size() < 5) {
|
||||||
|
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||||
|
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
long documentCount = segBlocks.stream()
|
||||||
|
.map(b -> b.getDocumentId())
|
||||||
|
.distinct()
|
||||||
|
.count();
|
||||||
|
|
||||||
|
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||||
|
.status(TrainingStatus.RUNNING)
|
||||||
|
.blockCount(segBlocks.size())
|
||||||
|
.documentCount((int) documentCount)
|
||||||
|
.modelName("blla")
|
||||||
|
.triggeredBy(triggeredBy)
|
||||||
|
.build();
|
||||||
|
run = trainingRunRepository.save(run);
|
||||||
|
|
||||||
|
String runId = run.getId().toString();
|
||||||
|
MDC.put("trainingRunId", runId);
|
||||||
|
log.info("Started segmentation training run {} with {} segments from {} documents",
|
||||||
|
runId, segBlocks.size(), documentCount);
|
||||||
|
|
||||||
|
try {
|
||||||
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
|
segmentationTrainingExportService.exportToZip().writeTo(baos);
|
||||||
|
byte[] zipBytes = baos.toByteArray();
|
||||||
|
|
||||||
|
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
|
||||||
|
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
|
||||||
|
|
||||||
|
run.setStatus(TrainingStatus.DONE);
|
||||||
|
run.setCompletedAt(Instant.now());
|
||||||
|
run.setLoss(result.loss());
|
||||||
|
run.setAccuracy(result.accuracy());
|
||||||
|
run.setEpochs(result.epochs());
|
||||||
|
run = trainingRunRepository.save(run);
|
||||||
|
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||||
|
} catch (Exception e) {
|
||||||
|
run.setStatus(TrainingStatus.FAILED);
|
||||||
|
run.setErrorMessage(e.getMessage());
|
||||||
|
run.setCompletedAt(Instant.now());
|
||||||
|
run = trainingRunRepository.save(run);
|
||||||
|
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||||
|
} finally {
|
||||||
|
MDC.remove("trainingRunId");
|
||||||
|
}
|
||||||
|
|
||||||
|
return run;
|
||||||
|
}
|
||||||
|
|
||||||
public TrainingInfoResponse getTrainingInfo() {
|
public TrainingInfoResponse getTrainingInfo() {
|
||||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||||
int availableDocuments = (int) eligibleBlocks.stream()
|
int availableDocuments = (int) eligibleBlocks.stream()
|
||||||
|
|||||||
@@ -127,8 +127,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
.retrieve()
|
.retrieve()
|
||||||
.body(TrainingResultJson.class);
|
.body(TrainingResultJson.class);
|
||||||
|
|
||||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -156,8 +156,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
.retrieve()
|
.retrieve()
|
||||||
.body(TrainingResultJson.class);
|
.body(TrainingResultJson.class);
|
||||||
|
|
||||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -255,7 +255,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
|
record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||||
|
|
||||||
record OcrBlockJson(
|
record OcrBlockJson(
|
||||||
@JsonProperty("pageNumber") int pageNumber,
|
@JsonProperty("pageNumber") int pageNumber,
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
-- Adds per-run training metric columns to ocr_training_runs.
-- No NOT NULL constraint or default is declared, so existing rows keep NULL in all four columns.
ALTER TABLE ocr_training_runs
|
||||||
|
ADD COLUMN cer DOUBLE PRECISION,
|
||||||
|
ADD COLUMN loss DOUBLE PRECISION,
|
||||||
|
ADD COLUMN accuracy DOUBLE PRECISION,
|
||||||
|
ADD COLUMN epochs INT;
|
||||||
@@ -93,7 +93,7 @@ class OcrTrainingServiceTest {
|
|||||||
);
|
);
|
||||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
|
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
|
||||||
|
|
||||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||||
|
|||||||
@@ -550,6 +550,7 @@
|
|||||||
"training_history_col_status": "Status",
|
"training_history_col_status": "Status",
|
||||||
"training_history_col_blocks": "Blöcke",
|
"training_history_col_blocks": "Blöcke",
|
||||||
"training_history_col_docs": "Dokumente",
|
"training_history_col_docs": "Dokumente",
|
||||||
|
"training_history_col_cer": "Fehlerrate",
|
||||||
"training_status_done": "Fertig",
|
"training_status_done": "Fertig",
|
||||||
"training_status_failed": "Fehler",
|
"training_status_failed": "Fehler",
|
||||||
"training_status_running": "Läuft…",
|
"training_status_running": "Läuft…",
|
||||||
|
|||||||
@@ -550,6 +550,7 @@
|
|||||||
"training_history_col_status": "Status",
|
"training_history_col_status": "Status",
|
||||||
"training_history_col_blocks": "Blocks",
|
"training_history_col_blocks": "Blocks",
|
||||||
"training_history_col_docs": "Documents",
|
"training_history_col_docs": "Documents",
|
||||||
|
"training_history_col_cer": "Error Rate",
|
||||||
"training_status_done": "Done",
|
"training_status_done": "Done",
|
||||||
"training_status_failed": "Failed",
|
"training_status_failed": "Failed",
|
||||||
"training_status_running": "Running…",
|
"training_status_running": "Running…",
|
||||||
|
|||||||
@@ -550,6 +550,7 @@
|
|||||||
"training_history_col_status": "Estado",
|
"training_history_col_status": "Estado",
|
||||||
"training_history_col_blocks": "Bloques",
|
"training_history_col_blocks": "Bloques",
|
||||||
"training_history_col_docs": "Documentos",
|
"training_history_col_docs": "Documentos",
|
||||||
|
"training_history_col_cer": "Tasa de error",
|
||||||
"training_status_done": "Listo",
|
"training_status_done": "Listo",
|
||||||
"training_status_failed": "Error",
|
"training_status_failed": "Error",
|
||||||
"training_status_running": "Ejecutando…",
|
"training_status_running": "Ejecutando…",
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ interface Run {
|
|||||||
triggeredBy?: string;
|
triggeredBy?: string;
|
||||||
createdAt: string;
|
createdAt: string;
|
||||||
completedAt?: string;
|
completedAt?: string;
|
||||||
|
cer?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
@@ -28,6 +29,11 @@ const dateFormatter = new Intl.DateTimeFormat('de-DE', {
|
|||||||
function formatDate(iso: string): string {
|
function formatDate(iso: string): string {
|
||||||
return dateFormatter.format(new Date(iso));
|
return dateFormatter.format(new Date(iso));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Renders a character error rate as a percentage string with one decimal place.
 *
 * @param cer error rate as a fraction (0.05 is shown as "5.0 %"), or null/undefined when unknown
 * @returns the value multiplied by 100 with one decimal and a " %" suffix, or an em dash when missing
 */
function formatCer(cer: number | undefined | null): string {
	return cer === null || cer === undefined ? '—' : `${(cer * 100).toFixed(1)} %`;
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<table class="w-full text-sm">
|
<table class="w-full text-sm">
|
||||||
@@ -37,12 +43,13 @@ function formatDate(iso: string): string {
|
|||||||
<th class="pb-2 text-left">{m.training_history_col_status()}</th>
|
<th class="pb-2 text-left">{m.training_history_col_status()}</th>
|
||||||
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
|
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
|
||||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
|
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
|
||||||
|
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
{#if runs.length === 0}
|
{#if runs.length === 0}
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan="4" class="py-4 text-center text-sm text-ink-2">
|
<td colspan="5" class="py-4 text-center text-sm text-ink-2">
|
||||||
{m.training_history_empty()}
|
{m.training_history_empty()}
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
@@ -71,6 +78,9 @@ function formatDate(iso: string): string {
|
|||||||
</td>
|
</td>
|
||||||
<td class="py-2 text-right text-ink-2">{run.blockCount}</td>
|
<td class="py-2 text-right text-ink-2">{run.blockCount}</td>
|
||||||
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
|
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
|
||||||
|
<td class="hidden py-2 text-right md:table-cell"
|
||||||
|
>{run.status === 'DONE' && run.cer != null ? formatCer(run.cer) : '—'}</td
|
||||||
|
>
|
||||||
</tr>
|
</tr>
|
||||||
{/each}
|
{/each}
|
||||||
{/if}
|
{/if}
|
||||||
|
|||||||
@@ -333,8 +333,10 @@ async def train_model(
|
|||||||
epochs = getattr(result, "epochs", None) or 0
|
epochs = getattr(result, "epochs", None) or 0
|
||||||
loss = getattr(result, "best_loss", None)
|
loss = getattr(result, "best_loss", None)
|
||||||
accuracy = getattr(result, "best_accuracy", None)
|
accuracy = getattr(result, "best_accuracy", None)
|
||||||
|
# CER = 1 - character accuracy (Kraken reports character-level accuracy)
|
||||||
|
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||||
|
|
||||||
log.info("Training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
|
log.info("Training complete — epochs=%s loss=%s accuracy=%s cer=%s", epochs, loss, accuracy, cer)
|
||||||
|
|
||||||
# Backup existing model and replace
|
# Backup existing model and replace
|
||||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||||
@@ -351,7 +353,7 @@ async def train_model(
|
|||||||
kraken_engine.load_models()
|
kraken_engine.load_models()
|
||||||
log.info("Reloaded Kraken model in-process")
|
log.info("Reloaded Kraken model in-process")
|
||||||
|
|
||||||
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
|
return {"loss": loss, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||||
|
|
||||||
result = await asyncio.to_thread(_run_training)
|
result = await asyncio.to_thread(_run_training)
|
||||||
return result
|
return result
|
||||||
@@ -409,7 +411,7 @@ async def segtrain_model(
|
|||||||
loss = getattr(result, "best_loss", None)
|
loss = getattr(result, "best_loss", None)
|
||||||
accuracy = getattr(result, "best_accuracy", None)
|
accuracy = getattr(result, "best_accuracy", None)
|
||||||
|
|
||||||
log.info("Segmentation training complete — epochs=%s loss=%s", epochs, loss)
|
log.info("Segmentation training complete — epochs=%s loss=%s accuracy=%s", epochs, loss, accuracy)
|
||||||
|
|
||||||
if os.path.exists(blla_model_path):
|
if os.path.exists(blla_model_path):
|
||||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||||
@@ -420,7 +422,7 @@ async def segtrain_model(
|
|||||||
shutil.move(output_model_path, blla_model_path)
|
shutil.move(output_model_path, blla_model_path)
|
||||||
log.info("Replaced blla model at %s", blla_model_path)
|
log.info("Replaced blla model at %s", blla_model_path)
|
||||||
|
|
||||||
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
|
return {"loss": loss, "accuracy": accuracy, "cer": None, "epochs": epochs}
|
||||||
|
|
||||||
result = await asyncio.to_thread(_run_segtrain)
|
result = await asyncio.to_thread(_run_segtrain)
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user