feat(training): track and display CER per training run
After each training run, the Character Error Rate (CER = 1 - accuracy), loss, accuracy, and epoch count are now stored on the OcrTrainingRun record and shown in the training history table.

Also adds the missing POST /api/ocr/segtrain endpoint and the triggerSegTraining service method so the segmentation training card can actually trigger training.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -120,6 +120,14 @@ public class OcrController {
|
||||
return ocrTrainingService.triggerTraining(userId);
|
||||
}
|
||||
|
||||
/**
 * Starts a segmentation training run on behalf of the authenticated caller.
 *
 * <p>Mapped to {@code POST /api/ocr/segtrain}, answers with {@code 201 Created}, and is
 * annotated to require {@code Permission.ADMIN}. The handler itself only resolves the
 * caller's user id and delegates to {@code ocrTrainingService.triggerSegTraining}.
 *
 * @param authentication the current authentication, passed to {@code resolveUserId}
 * @return the training run returned by the service
 */
@PostMapping("/api/ocr/segtrain")
|
||||
@ResponseStatus(HttpStatus.CREATED)
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingRun triggerSegTraining(Authentication authentication) {
|
||||
UUID userId = resolveUserId(authentication);
|
||||
return ocrTrainingService.triggerSegTraining(userId);
|
||||
}
|
||||
|
||||
@GetMapping("/api/ocr/training-info")
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() {
|
||||
|
||||
@@ -41,6 +41,18 @@ public class OcrTrainingRun {
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private String modelName;
|
||||
|
||||
@Column(name = "cer")
|
||||
private Double cer;
|
||||
|
||||
@Column(name = "loss")
|
||||
private Double loss;
|
||||
|
||||
@Column(name = "accuracy")
|
||||
private Double accuracy;
|
||||
|
||||
@Column(name = "epochs")
|
||||
private Integer epochs;
|
||||
|
||||
@Column(name = "error_message")
|
||||
private String errorMessage;
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ public interface OcrClient {
|
||||
*/
|
||||
TrainingResult trainModel(byte[] trainingDataZip);
|
||||
|
||||
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
|
||||
/**
 * Metrics handed back from a training call: loss, accuracy, character error rate (cer)
 * and epoch count. Every component is a boxed type and may be {@code null}; the REST
 * client implementation returns an all-{@code null} instance when it gets no response body.
 */
record TrainingResult(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||
|
||||
/**
|
||||
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.
|
||||
|
||||
@@ -80,12 +80,16 @@ public class OcrTrainingService {
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||
ocrClient.trainModel(zipBytes);
|
||||
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed successfully", runId);
|
||||
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
@@ -99,6 +103,66 @@ public class OcrTrainingService {
|
||||
return run;
|
||||
}
|
||||
|
||||
/**
 * Runs one segmentation training cycle and records its outcome as an {@link OcrTrainingRun}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Rejects the call (via {@code DomainException.conflict}, code
 *       {@code TRAINING_ALREADY_RUNNING}) if the repository finds a run with status RUNNING.</li>
 *   <li>Rejects the call (via {@code DomainException.badRequest}, code
 *       {@code VALIDATION_ERROR}) if fewer than 5 segmentation blocks are available.</li>
 *   <li>Saves a new run with status RUNNING, model name {@code "blla"}, the block count and
 *       the number of distinct document ids among the blocks.</li>
 *   <li>Exports the training ZIP into memory and sends it to {@code ocrClient.segtrainModel}.</li>
 *   <li>On success, marks the run DONE and stores loss, accuracy and epochs.</li>
 *   <li>On any {@link Exception}, marks the run FAILED with the exception message; the
 *       exception is logged but not rethrown, so the caller receives the FAILED run.</li>
 * </ol>
 *
 * <p>The MDC key {@code trainingRunId} is set for the duration of the export/training
 * step and removed in {@code finally}.
 *
 * <p>NOTE(review): {@code result.cer()} is not copied onto the run here — confirm that is
 * intentional for segmentation training.
 * <p>NOTE(review): the OCR call happens inside this {@code @Transactional} method — confirm
 * the RUNNING row is visible to concurrent callers early enough for the guard to be effective.
 *
 * @param triggeredBy id stored on the run as {@code triggeredBy}
 * @return the saved run, in status DONE or FAILED
 */
@Transactional
|
||||
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
||||
// Guard: refuse to start while any run is still marked RUNNING.
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||
// Minimum-data check: fewer than 5 blocks is reported as a validation error.
if (segBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
|
||||
}
|
||||
|
||||
// Number of distinct documents the blocks belong to.
long documentCount = segBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(segBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("blla")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
// Save before training; the id used for logging and MDC is read from the saved instance.
run = trainingRunRepository.save(run);
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started segmentation training run {} with {} segments from {} documents",
|
||||
runId, segBlocks.size(), documentCount);
|
||||
|
||||
try {
|
||||
// The whole export is buffered in memory before being sent.
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
segmentationTrainingExportService.exportToZip().writeTo(baos);
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||
} catch (Exception e) {
|
||||
// Failure is recorded on the run itself instead of being propagated to the caller.
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||
} finally {
|
||||
// Always clear the MDC entry set above, on success and on failure.
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
public TrainingInfoResponse getTrainingInfo() {
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
int availableDocuments = (int) eligibleBlocks.stream()
|
||||
|
||||
@@ -127,8 +127,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -156,8 +156,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -255,7 +255,7 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
}
|
||||
}
|
||||
|
||||
record TrainingResultJson(Double loss, Double accuracy, Integer epochs) {}
|
||||
/**
 * Type the OCR service's training response body is deserialized into (see the
 * {@code .body(TrainingResultJson.class)} calls); its components are copied one-to-one
 * into {@code OcrClient.TrainingResult}.
 */
record TrainingResultJson(Double loss, Double accuracy, Double cer, Integer epochs) {}
|
||||
|
||||
record OcrBlockJson(
|
||||
@JsonProperty("pageNumber") int pageNumber,
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Adds the per-run training metric columns (cer, loss, accuracy, epochs) to ocr_training_runs.
-- No NOT NULL constraint or default is declared on any of the new columns.
ALTER TABLE ocr_training_runs
|
||||
ADD COLUMN cer DOUBLE PRECISION,
|
||||
ADD COLUMN loss DOUBLE PRECISION,
|
||||
ADD COLUMN accuracy DOUBLE PRECISION,
|
||||
ADD COLUMN epochs INT;
|
||||
@@ -93,7 +93,7 @@ class OcrTrainingServiceTest {
|
||||
);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(blocks);
|
||||
when(exportService.exportToZip()).thenReturn(out -> {});
|
||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 3));
|
||||
when(ocrClient.trainModel(any())).thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
|
||||
|
||||
OcrTrainingRun saved = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
|
||||
Reference in New Issue
Block a user