fix(training): remove @Transactional from triggerTraining to avoid holding DB connection during OCR HTTP call

OcrTrainingService.triggerTraining() and triggerSegTraining() held a DB
connection open for the entire ketos training run (potentially minutes),
risking connection pool exhaustion. Replaced the method-level @Transactional
annotations with TransactionTemplate for narrow DB writes: the guard check
plus creation of the RUNNING row, and the final result record (DONE or
FAILED), each run in their own short transaction; the HTTP call to the OCR
service runs between them with no open connection.

Also replaces blockRepository.findAll().size() with blockRepository.count()
in getTrainingInfo() to avoid loading every block into heap on each poll.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-14 09:59:12 +02:00
parent 62be895b9e
commit dc283ba271
2 changed files with 104 additions and 78 deletions

View File

@@ -13,11 +13,13 @@ import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.event.EventListener; import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.UUID; import java.util.UUID;
@Service @Service
@@ -31,6 +33,7 @@ public class OcrTrainingService {
private final OcrClient ocrClient; private final OcrClient ocrClient;
private final OcrHealthClient ocrHealthClient; private final OcrHealthClient ocrHealthClient;
private final TranscriptionBlockRepository blockRepository; private final TranscriptionBlockRepository blockRepository;
private final TransactionTemplate txTemplate;
public record TrainingInfoResponse( public record TrainingInfoResponse(
int availableBlocks, int availableBlocks,
@@ -46,37 +49,40 @@ public class OcrTrainingService {
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique // Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas // index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
// from diverging on model state. Deploy as a single instance only. See ADR-001. // from diverging on model state. Deploy as a single instance only. See ADR-001.
@Transactional
public OcrTrainingRun triggerTraining(UUID triggeredBy) { public OcrTrainingRun triggerTraining(UUID triggeredBy) {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { // Short transaction: guard check + create RUNNING row, then commit immediately.
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, // The DB connection is released before the OCR HTTP call, which can take several minutes.
"A training run is already in progress"); OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
} if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
if (eligibleBlocks.size() < 5) { if (eligibleBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")"); "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
} }
long documentCount = eligibleBlocks.stream() long documentCount = eligibleBlocks.stream()
.map(b -> b.getDocumentId()) .map(b -> b.getDocumentId())
.distinct() .distinct()
.count(); .count();
OcrTrainingRun run = OcrTrainingRun.builder() OcrTrainingRun newRun = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING) .status(TrainingStatus.RUNNING)
.blockCount(eligibleBlocks.size()) .blockCount(eligibleBlocks.size())
.documentCount((int) documentCount) .documentCount((int) documentCount)
.modelName("german_kurrent") .modelName("german_kurrent")
.triggeredBy(triggeredBy) .triggeredBy(triggeredBy)
.build(); .build();
run = trainingRunRepository.save(run); return trainingRunRepository.save(newRun);
}));
String runId = run.getId().toString(); String runId = run.getId().toString();
MDC.put("trainingRunId", runId); MDC.put("trainingRunId", runId);
log.info("Started training run {} with {} blocks from {} documents", log.info("Started training run {} with {} blocks from {} documents",
runId, eligibleBlocks.size(), documentCount); runId, run.getBlockCount(), run.getDocumentCount());
try { try {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
@@ -86,58 +92,64 @@ public class OcrTrainingService {
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length); log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes); OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
run.setStatus(TrainingStatus.DONE); return Objects.requireNonNull(txTemplate.execute(status -> {
run.setCompletedAt(Instant.now()); run.setStatus(TrainingStatus.DONE);
run.setCer(result.cer()); run.setCompletedAt(Instant.now());
run.setLoss(result.loss()); run.setCer(result.cer());
run.setAccuracy(result.accuracy()); run.setLoss(result.loss());
run.setEpochs(result.epochs()); run.setAccuracy(result.accuracy());
run = trainingRunRepository.save(run); run.setEpochs(result.epochs());
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs()); OcrTrainingRun updated = trainingRunRepository.save(run);
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
return updated;
}));
} catch (Exception e) { } catch (Exception e) {
run.setStatus(TrainingStatus.FAILED); return Objects.requireNonNull(txTemplate.execute(status -> {
run.setErrorMessage(e.getMessage()); run.setStatus(TrainingStatus.FAILED);
run.setCompletedAt(Instant.now()); run.setErrorMessage(e.getMessage());
run = trainingRunRepository.save(run); run.setCompletedAt(Instant.now());
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e); OcrTrainingRun failed = trainingRunRepository.save(run);
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
return failed;
}));
} finally { } finally {
MDC.remove("trainingRunId"); MDC.remove("trainingRunId");
} }
return run;
} }
@Transactional
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) { public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { // Same pattern as triggerTraining: narrow transactions around DB writes only.
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
"A training run is already in progress"); if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
} throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
"A training run is already in progress");
}
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks(); var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
if (segBlocks.size() < 5) { if (segBlocks.size() < 5) {
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")"); "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
} }
long documentCount = segBlocks.stream() long documentCount = segBlocks.stream()
.map(b -> b.getDocumentId()) .map(b -> b.getDocumentId())
.distinct() .distinct()
.count(); .count();
OcrTrainingRun run = OcrTrainingRun.builder() OcrTrainingRun newRun = OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING) .status(TrainingStatus.RUNNING)
.blockCount(segBlocks.size()) .blockCount(segBlocks.size())
.documentCount((int) documentCount) .documentCount((int) documentCount)
.modelName("blla") .modelName("blla")
.triggeredBy(triggeredBy) .triggeredBy(triggeredBy)
.build(); .build();
run = trainingRunRepository.save(run); return trainingRunRepository.save(newRun);
}));
String runId = run.getId().toString(); String runId = run.getId().toString();
MDC.put("trainingRunId", runId); MDC.put("trainingRunId", runId);
log.info("Started segmentation training run {} with {} segments from {} documents", log.info("Started segmentation training run {} with {} segments from {} documents",
runId, segBlocks.size(), documentCount); runId, run.getBlockCount(), run.getDocumentCount());
try { try {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
@@ -147,24 +159,28 @@ public class OcrTrainingService {
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length); log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes); OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
run.setStatus(TrainingStatus.DONE); return Objects.requireNonNull(txTemplate.execute(status -> {
run.setCompletedAt(Instant.now()); run.setStatus(TrainingStatus.DONE);
run.setLoss(result.loss()); run.setCompletedAt(Instant.now());
run.setAccuracy(result.accuracy()); run.setLoss(result.loss());
run.setEpochs(result.epochs()); run.setAccuracy(result.accuracy());
run = trainingRunRepository.save(run); run.setEpochs(result.epochs());
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs()); OcrTrainingRun updated = trainingRunRepository.save(run);
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
return updated;
}));
} catch (Exception e) { } catch (Exception e) {
run.setStatus(TrainingStatus.FAILED); return Objects.requireNonNull(txTemplate.execute(status -> {
run.setErrorMessage(e.getMessage()); run.setStatus(TrainingStatus.FAILED);
run.setCompletedAt(Instant.now()); run.setErrorMessage(e.getMessage());
run = trainingRunRepository.save(run); run.setCompletedAt(Instant.now());
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e); OcrTrainingRun failed = trainingRunRepository.save(run);
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
return failed;
}));
} finally { } finally {
MDC.remove("trainingRunId"); MDC.remove("trainingRunId");
} }
return run;
} }
public TrainingInfoResponse getTrainingInfo() { public TrainingInfoResponse getTrainingInfo() {
@@ -174,7 +190,7 @@ public class OcrTrainingService {
.distinct() .distinct()
.count(); .count();
int totalOcrBlocks = blockRepository.findAll().size(); int totalOcrBlocks = (int) blockRepository.count();
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size(); int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc(); List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();

View File

@@ -9,6 +9,8 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
@@ -29,6 +31,7 @@ class OcrTrainingServiceTest {
OcrClient ocrClient; OcrClient ocrClient;
OcrHealthClient healthClient; OcrHealthClient healthClient;
TranscriptionBlockRepository blockRepository; TranscriptionBlockRepository blockRepository;
TransactionTemplate txTemplate;
OcrTrainingService service; OcrTrainingService service;
@BeforeEach @BeforeEach
@@ -39,10 +42,17 @@ class OcrTrainingServiceTest {
ocrClient = mock(OcrClient.class); ocrClient = mock(OcrClient.class);
healthClient = mock(OcrHealthClient.class); healthClient = mock(OcrHealthClient.class);
blockRepository = mock(TranscriptionBlockRepository.class); blockRepository = mock(TranscriptionBlockRepository.class);
txTemplate = mock(TransactionTemplate.class);
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository); // Execute transaction callbacks inline so unit tests run without a real DataSource
when(txTemplate.execute(any())).thenAnswer(inv -> {
TransactionCallback<?> callback = inv.getArgument(0);
return callback.doInTransaction(null);
});
when(blockRepository.findAll()).thenReturn(List.of()); service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
} }