feat(training): track and display CER per training run

After each training run, the Character Error Rate (CER = 1 - accuracy),
loss, accuracy, and epoch count are now stored on the OcrTrainingRun
record and shown in the training history table.

Also adds the missing POST /api/ocr/segtrain endpoint and the
triggerSegTraining service method so the segmentation training card
can actually trigger training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-13 19:01:10 +02:00
parent a99afef319
commit 22954f348a
12 changed files with 118 additions and 14 deletions

View File

@@ -120,6 +120,14 @@ public class OcrController {
return ocrTrainingService.triggerTraining(userId);
}
@PostMapping("/api/ocr/segtrain")
@ResponseStatus(HttpStatus.CREATED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerSegTraining(Authentication authentication) {
UUID userId = resolveUserId(authentication);
return ocrTrainingService.triggerSegTraining(userId);
}
@GetMapping("/api/ocr/training-info")
@RequirePermission(Permission.ADMIN)
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {

View File

@@ -41,6 +41,18 @@ public class OcrTrainingRun {
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private String modelName;
@Column(name = "cer")
private Double cer;
@Column(name = "loss")
private Double loss;
@Column(name = "accuracy")
private Double accuracy;
@Column(name = "epochs")
private Integer epochs;
@Column(name = "error_message")
private String errorMessage;

View File

@@ -26,7 +26,7 @@ public interface OcrClient {
*/
TrainingResult trainModel(byte[] trainingDataZip);
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {}
/**
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.

View File

@@ -80,12 +80,16 @@ public class OcrTrainingService {
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
ocrClient.trainModel(zipBytes);
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
run = trainingRunRepository.save(run);
log.info("[trainingRun={}] Training completed successfully", runId);
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
} catch (Exception e) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
@@ -99,6 +103,66 @@ public class OcrTrainingService {
return run;
}
@Transactional
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
if (segBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
}
long documentCount = segBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
OcrTrainingRun run = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.blockCount(segBlocks.size())
.documentCount((int) documentCount)
.modelName("blla")
.triggeredBy(triggeredBy)
.build();
run = trainingRunRepository.save(run);
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started segmentation training run {} with {} segments from {} documents",
runId, segBlocks.size(), documentCount);
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
segmentationTrainingExportService.exportToZip().writeTo(baos);
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
run = trainingRunRepository.save(run);
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
} catch (Exception e) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
run = trainingRunRepository.save(run);
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
} finally {
MDC.remove("trainingRunId");
}
return run;
}
public TrainingInfoResponse getTrainingInfo() {
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
int availableDocuments = (int) eligibleBlocks.stream()

View File

@@ -127,8 +127,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.retrieve()
.body(TrainingResultJson.class);
if (result == null) return new OcrClient.TrainingResult(null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
}
@Override
@@ -156,8 +156,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.retrieve()
.body(TrainingResultJson.class);
if (result == null) return new OcrClient.TrainingResult(null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
}
@Override
@@ -255,7 +255,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
}
}
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {}
record OcrBlockJson(
@JsonProperty("pageNumber") int pageNumber,

View File

@@ -0,0 +1,5 @@
ALTER TABLE ocr_training_runs
ADD COLUMN cer DOUBLE PRECISION,
ADD COLUMN loss DOUBLE PRECISION,
ADD COLUMN accuracy DOUBLE PRECISION,
ADD COLUMN epochs INT;

View File

@@ -93,7 +93,7 @@ class OcrTrainingServiceTest {
);
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
when(exportService.exportToZip()).thenReturn(out -> {});
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
OcrTrainingRun saved = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)