feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232
@@ -13,11 +13,13 @@ import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@@ -31,6 +33,7 @@ public class OcrTrainingService {
|
||||
private final OcrClient ocrClient;
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final TransactionTemplate txTemplate;
|
||||
|
||||
public record TrainingInfoResponse(
|
||||
int availableBlocks,
|
||||
@@ -46,37 +49,40 @@ public class OcrTrainingService {
|
||||
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
|
||||
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
|
||||
// from diverging on model state. Deploy as a single instance only. See ADR-001.
|
||||
@Transactional
|
||||
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
// Short transaction: guard check + create RUNNING row, then commit immediately.
|
||||
// The DB connection is released before the OCR HTTP call, which can take several minutes.
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
if (eligibleBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
|
||||
}
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
if (eligibleBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
|
||||
}
|
||||
|
||||
long documentCount = eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
long documentCount = eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(eligibleBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("german_kurrent")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
run = trainingRunRepository.save(run);
|
||||
OcrTrainingRun newRun = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(eligibleBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("german_kurrent")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
return trainingRunRepository.save(newRun);
|
||||
}));
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started training run {} with {} blocks from {} documents",
|
||||
runId, eligibleBlocks.size(), documentCount);
|
||||
runId, run.getBlockCount(), run.getDocumentCount());
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
@@ -86,58 +92,64 @@ public class OcrTrainingService {
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
OcrTrainingRun updated = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||
return updated;
|
||||
}));
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
OcrTrainingRun failed = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||
return failed;
|
||||
}));
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
// Same pattern as triggerTraining: narrow transactions around DB writes only.
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||
if (segBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
|
||||
}
|
||||
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||
if (segBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
|
||||
}
|
||||
|
||||
long documentCount = segBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
long documentCount = segBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(segBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("blla")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
run = trainingRunRepository.save(run);
|
||||
OcrTrainingRun newRun = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(segBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("blla")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
return trainingRunRepository.save(newRun);
|
||||
}));
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started segmentation training run {} with {} segments from {} documents",
|
||||
runId, segBlocks.size(), documentCount);
|
||||
runId, run.getBlockCount(), run.getDocumentCount());
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
@@ -147,24 +159,28 @@ public class OcrTrainingService {
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
OcrTrainingRun updated = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||
return updated;
|
||||
}));
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
OcrTrainingRun failed = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||
return failed;
|
||||
}));
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
public TrainingInfoResponse getTrainingInfo() {
|
||||
@@ -174,7 +190,7 @@ public class OcrTrainingService {
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
int totalOcrBlocks = blockRepository.findAll().size();
|
||||
int totalOcrBlocks = (int) blockRepository.count();
|
||||
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
|
||||
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||
|
||||
@@ -9,6 +9,8 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
@@ -29,6 +31,7 @@ class OcrTrainingServiceTest {
|
||||
OcrClient ocrClient;
|
||||
OcrHealthClient healthClient;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
TransactionTemplate txTemplate;
|
||||
OcrTrainingService service;
|
||||
|
||||
@BeforeEach
|
||||
@@ -39,10 +42,17 @@ class OcrTrainingServiceTest {
|
||||
ocrClient = mock(OcrClient.class);
|
||||
healthClient = mock(OcrHealthClient.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository);
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource
|
||||
when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||
TransactionCallback<?> callback = inv.getArgument(0);
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
when(blockRepository.findAll()).thenReturn(List.of());
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
|
||||
when(blockRepository.count()).thenReturn(0L);
|
||||
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user