feat(ocr): full OCR pipeline with polygon annotations, training, and guided mode #232
@@ -13,11 +13,13 @@ import org.springframework.boot.context.event.ApplicationReadyEvent;
|
|||||||
import org.springframework.context.event.EventListener;
|
import org.springframework.context.event.EventListener;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
import org.springframework.transaction.support.TransactionTemplate;
|
||||||
|
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@@ -31,6 +33,7 @@ public class OcrTrainingService {
|
|||||||
private final OcrClient ocrClient;
|
private final OcrClient ocrClient;
|
||||||
private final OcrHealthClient ocrHealthClient;
|
private final OcrHealthClient ocrHealthClient;
|
||||||
private final TranscriptionBlockRepository blockRepository;
|
private final TranscriptionBlockRepository blockRepository;
|
||||||
|
private final TransactionTemplate txTemplate;
|
||||||
|
|
||||||
public record TrainingInfoResponse(
|
public record TrainingInfoResponse(
|
||||||
int availableBlocks,
|
int availableBlocks,
|
||||||
@@ -46,8 +49,10 @@ public class OcrTrainingService {
|
|||||||
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
|
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
|
||||||
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
|
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
|
||||||
// from diverging on model state. Deploy as a single instance only. See ADR-001.
|
// from diverging on model state. Deploy as a single instance only. See ADR-001.
|
||||||
@Transactional
|
|
||||||
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
|
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
|
||||||
|
// Short transaction: guard check + create RUNNING row, then commit immediately.
|
||||||
|
// The DB connection is released before the OCR HTTP call, which can take several minutes.
|
||||||
|
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||||
"A training run is already in progress");
|
"A training run is already in progress");
|
||||||
@@ -64,19 +69,20 @@ public class OcrTrainingService {
|
|||||||
.distinct()
|
.distinct()
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
OcrTrainingRun newRun = OcrTrainingRun.builder()
|
||||||
.status(TrainingStatus.RUNNING)
|
.status(TrainingStatus.RUNNING)
|
||||||
.blockCount(eligibleBlocks.size())
|
.blockCount(eligibleBlocks.size())
|
||||||
.documentCount((int) documentCount)
|
.documentCount((int) documentCount)
|
||||||
.modelName("german_kurrent")
|
.modelName("german_kurrent")
|
||||||
.triggeredBy(triggeredBy)
|
.triggeredBy(triggeredBy)
|
||||||
.build();
|
.build();
|
||||||
run = trainingRunRepository.save(run);
|
return trainingRunRepository.save(newRun);
|
||||||
|
}));
|
||||||
|
|
||||||
String runId = run.getId().toString();
|
String runId = run.getId().toString();
|
||||||
MDC.put("trainingRunId", runId);
|
MDC.put("trainingRunId", runId);
|
||||||
log.info("Started training run {} with {} blocks from {} documents",
|
log.info("Started training run {} with {} blocks from {} documents",
|
||||||
runId, eligibleBlocks.size(), documentCount);
|
runId, run.getBlockCount(), run.getDocumentCount());
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
@@ -86,29 +92,34 @@ public class OcrTrainingService {
|
|||||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||||
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
|
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
|
||||||
|
|
||||||
|
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||||
run.setStatus(TrainingStatus.DONE);
|
run.setStatus(TrainingStatus.DONE);
|
||||||
run.setCompletedAt(Instant.now());
|
run.setCompletedAt(Instant.now());
|
||||||
run.setCer(result.cer());
|
run.setCer(result.cer());
|
||||||
run.setLoss(result.loss());
|
run.setLoss(result.loss());
|
||||||
run.setAccuracy(result.accuracy());
|
run.setAccuracy(result.accuracy());
|
||||||
run.setEpochs(result.epochs());
|
run.setEpochs(result.epochs());
|
||||||
run = trainingRunRepository.save(run);
|
OcrTrainingRun updated = trainingRunRepository.save(run);
|
||||||
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||||
|
return updated;
|
||||||
|
}));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||||
run.setStatus(TrainingStatus.FAILED);
|
run.setStatus(TrainingStatus.FAILED);
|
||||||
run.setErrorMessage(e.getMessage());
|
run.setErrorMessage(e.getMessage());
|
||||||
run.setCompletedAt(Instant.now());
|
run.setCompletedAt(Instant.now());
|
||||||
run = trainingRunRepository.save(run);
|
OcrTrainingRun failed = trainingRunRepository.save(run);
|
||||||
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||||
|
return failed;
|
||||||
|
}));
|
||||||
} finally {
|
} finally {
|
||||||
MDC.remove("trainingRunId");
|
MDC.remove("trainingRunId");
|
||||||
}
|
}
|
||||||
|
|
||||||
return run;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
|
||||||
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
||||||
|
// Same pattern as triggerTraining: narrow transactions around DB writes only.
|
||||||
|
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||||
"A training run is already in progress");
|
"A training run is already in progress");
|
||||||
@@ -125,19 +136,20 @@ public class OcrTrainingService {
|
|||||||
.distinct()
|
.distinct()
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
OcrTrainingRun newRun = OcrTrainingRun.builder()
|
||||||
.status(TrainingStatus.RUNNING)
|
.status(TrainingStatus.RUNNING)
|
||||||
.blockCount(segBlocks.size())
|
.blockCount(segBlocks.size())
|
||||||
.documentCount((int) documentCount)
|
.documentCount((int) documentCount)
|
||||||
.modelName("blla")
|
.modelName("blla")
|
||||||
.triggeredBy(triggeredBy)
|
.triggeredBy(triggeredBy)
|
||||||
.build();
|
.build();
|
||||||
run = trainingRunRepository.save(run);
|
return trainingRunRepository.save(newRun);
|
||||||
|
}));
|
||||||
|
|
||||||
String runId = run.getId().toString();
|
String runId = run.getId().toString();
|
||||||
MDC.put("trainingRunId", runId);
|
MDC.put("trainingRunId", runId);
|
||||||
log.info("Started segmentation training run {} with {} segments from {} documents",
|
log.info("Started segmentation training run {} with {} segments from {} documents",
|
||||||
runId, segBlocks.size(), documentCount);
|
runId, run.getBlockCount(), run.getDocumentCount());
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
@@ -147,24 +159,28 @@ public class OcrTrainingService {
|
|||||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
|
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
|
||||||
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
|
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
|
||||||
|
|
||||||
|
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||||
run.setStatus(TrainingStatus.DONE);
|
run.setStatus(TrainingStatus.DONE);
|
||||||
run.setCompletedAt(Instant.now());
|
run.setCompletedAt(Instant.now());
|
||||||
run.setLoss(result.loss());
|
run.setLoss(result.loss());
|
||||||
run.setAccuracy(result.accuracy());
|
run.setAccuracy(result.accuracy());
|
||||||
run.setEpochs(result.epochs());
|
run.setEpochs(result.epochs());
|
||||||
run = trainingRunRepository.save(run);
|
OcrTrainingRun updated = trainingRunRepository.save(run);
|
||||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||||
|
return updated;
|
||||||
|
}));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||||
run.setStatus(TrainingStatus.FAILED);
|
run.setStatus(TrainingStatus.FAILED);
|
||||||
run.setErrorMessage(e.getMessage());
|
run.setErrorMessage(e.getMessage());
|
||||||
run.setCompletedAt(Instant.now());
|
run.setCompletedAt(Instant.now());
|
||||||
run = trainingRunRepository.save(run);
|
OcrTrainingRun failed = trainingRunRepository.save(run);
|
||||||
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||||
|
return failed;
|
||||||
|
}));
|
||||||
} finally {
|
} finally {
|
||||||
MDC.remove("trainingRunId");
|
MDC.remove("trainingRunId");
|
||||||
}
|
}
|
||||||
|
|
||||||
return run;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainingInfoResponse getTrainingInfo() {
|
public TrainingInfoResponse getTrainingInfo() {
|
||||||
@@ -174,7 +190,7 @@ public class OcrTrainingService {
|
|||||||
.distinct()
|
.distinct()
|
||||||
.count();
|
.count();
|
||||||
|
|
||||||
int totalOcrBlocks = blockRepository.findAll().size();
|
int totalOcrBlocks = (int) blockRepository.count();
|
||||||
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
|
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
|
||||||
|
|
||||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
|
|||||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||||
|
import org.springframework.transaction.support.TransactionCallback;
|
||||||
|
import org.springframework.transaction.support.TransactionTemplate;
|
||||||
|
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -29,6 +31,7 @@ class OcrTrainingServiceTest {
|
|||||||
OcrClient ocrClient;
|
OcrClient ocrClient;
|
||||||
OcrHealthClient healthClient;
|
OcrHealthClient healthClient;
|
||||||
TranscriptionBlockRepository blockRepository;
|
TranscriptionBlockRepository blockRepository;
|
||||||
|
TransactionTemplate txTemplate;
|
||||||
OcrTrainingService service;
|
OcrTrainingService service;
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
@@ -39,10 +42,17 @@ class OcrTrainingServiceTest {
|
|||||||
ocrClient = mock(OcrClient.class);
|
ocrClient = mock(OcrClient.class);
|
||||||
healthClient = mock(OcrHealthClient.class);
|
healthClient = mock(OcrHealthClient.class);
|
||||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||||
|
txTemplate = mock(TransactionTemplate.class);
|
||||||
|
|
||||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository);
|
// Execute transaction callbacks inline so unit tests run without a real DataSource
|
||||||
|
when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||||
|
TransactionCallback<?> callback = inv.getArgument(0);
|
||||||
|
return callback.doInTransaction(null);
|
||||||
|
});
|
||||||
|
|
||||||
when(blockRepository.findAll()).thenReturn(List.of());
|
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||||
|
|
||||||
|
when(blockRepository.count()).thenReturn(0L);
|
||||||
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||||
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
|
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user