feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232

Merged
marcel merged 40 commits from feat/issue-226-227-ocr-pipeline-polygon into main 2026-04-14 10:31:35 +02:00
2 changed files with 104 additions and 78 deletions
Showing only changes of commit dc283ba271 - Show all commits

View File

@@ -13,11 +13,13 @@ import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
@Service
@@ -31,6 +33,7 @@ public class OcrTrainingService {
private final OcrClient ocrClient;
private final OcrHealthClient ocrHealthClient;
private final TranscriptionBlockRepository blockRepository;
private final TransactionTemplate txTemplate;
public record TrainingInfoResponse(
int availableBlocks,
@@ -46,37 +49,40 @@ public class OcrTrainingService {
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
// from diverging on model state. Deploy as a single instance only. See ADR-001.
@Transactional
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
// Short transaction: guard check + create RUNNING row, then commit immediately.
// The DB connection is released before the OCR HTTP call, which can take several minutes.
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
if (eligibleBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
}
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
if (eligibleBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
}
long documentCount = eligibleBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
long documentCount = eligibleBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
OcrTrainingRun run = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.blockCount(eligibleBlocks.size())
.documentCount((int) documentCount)
.modelName("german_kurrent")
.triggeredBy(triggeredBy)
.build();
run = trainingRunRepository.save(run);
OcrTrainingRun newRun = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.blockCount(eligibleBlocks.size())
.documentCount((int) documentCount)
.modelName("german_kurrent")
.triggeredBy(triggeredBy)
.build();
return trainingRunRepository.save(newRun);
}));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started training run {} with {} blocks from {} documents",
runId, eligibleBlocks.size(), documentCount);
runId, run.getBlockCount(), run.getDocumentCount());
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
@@ -86,58 +92,64 @@ public class OcrTrainingService {
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
run = trainingRunRepository.save(run);
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
return Objects.requireNonNull(txTemplate.execute(status -> {
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
OcrTrainingRun updated = trainingRunRepository.save(run);
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
return updated;
}));
} catch (Exception e) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
run = trainingRunRepository.save(run);
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
return Objects.requireNonNull(txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
OcrTrainingRun failed = trainingRunRepository.save(run);
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
return failed;
}));
} finally {
MDC.remove("trainingRunId");
}
return run;
}
@Transactional
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
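// Same pattern as triggerTraining: short transactions wrap only the DB work (guard check,
// eligibility query, status writes), so no DB connection is held during the long OCR HTTP call.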
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
if (segBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
}
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
if (segBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
}
long documentCount = segBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
long documentCount = segBlocks.stream()
.map(b -> b.getDocumentId())
.distinct()
.count();
OcrTrainingRun run = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.blockCount(segBlocks.size())
.documentCount((int) documentCount)
.modelName("blla")
.triggeredBy(triggeredBy)
.build();
run = trainingRunRepository.save(run);
OcrTrainingRun newRun = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.blockCount(segBlocks.size())
.documentCount((int) documentCount)
.modelName("blla")
.triggeredBy(triggeredBy)
.build();
return trainingRunRepository.save(newRun);
}));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started segmentation training run {} with {} segments from {} documents",
runId, segBlocks.size(), documentCount);
runId, run.getBlockCount(), run.getDocumentCount());
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
@@ -147,24 +159,28 @@ public class OcrTrainingService {
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
run = trainingRunRepository.save(run);
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
return Objects.requireNonNull(txTemplate.execute(status -> {
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setLoss(result.loss());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
OcrTrainingRun updated = trainingRunRepository.save(run);
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
return updated;
}));
} catch (Exception e) {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
run = trainingRunRepository.save(run);
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
return Objects.requireNonNull(txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
OcrTrainingRun failed = trainingRunRepository.save(run);
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
return failed;
}));
} finally {
MDC.remove("trainingRunId");
}
return run;
}
public TrainingInfoResponse getTrainingInfo() {
@@ -174,7 +190,7 @@ public class OcrTrainingService {
.distinct()
.count();
int totalOcrBlocks = blockRepository.findAll().size();
int totalOcrBlocks = (int) blockRepository.count();
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();

View File

@@ -9,6 +9,8 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.time.Instant;
import java.util.List;
@@ -29,6 +31,7 @@ class OcrTrainingServiceTest {
OcrClient ocrClient;
OcrHealthClient healthClient;
TranscriptionBlockRepository blockRepository;
TransactionTemplate txTemplate;
OcrTrainingService service;
@BeforeEach
@@ -39,10 +42,17 @@ class OcrTrainingServiceTest {
ocrClient = mock(OcrClient.class);
healthClient = mock(OcrHealthClient.class);
blockRepository = mock(TranscriptionBlockRepository.class);
txTemplate = mock(TransactionTemplate.class);
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository);
// Execute transaction callbacks inline so unit tests run without a real DataSource
when(txTemplate.execute(any())).thenAnswer(inv -> {
TransactionCallback<?> callback = inv.getArgument(0);
return callback.doInTransaction(null);
});
when(blockRepository.findAll()).thenReturn(List.of());
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
}