From dc283ba271186144039ef2b016ab793e11193fbb Mon Sep 17 00:00:00 2001 From: Marcel Date: Tue, 14 Apr 2026 09:59:12 +0200 Subject: [PATCH] fix(training): remove @Transactional from triggerTraining to avoid holding DB connection during OCR HTTP call OcrTrainingService.triggerTraining() and triggerSegTraining() held a DB connection open for the entire ketos training run (potentially minutes), risking connection pool exhaustion. Replaced class-level @Transactional with TransactionTemplate for narrow DB writes: guard+create and result-record each run in their own short transaction; the HTTP call to the OCR service runs between them with no open connection. Also replaces blockRepository.findAll().size() with blockRepository.count() in getTrainingInfo() to avoid loading every block into heap on each poll. Co-Authored-By: Claude Sonnet 4.6 --- .../service/OcrTrainingService.java | 168 ++++++++++-------- .../service/OcrTrainingServiceTest.java | 14 +- 2 files changed, 104 insertions(+), 78 deletions(-) diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index e241aaa3..9c8f4e5d 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -13,11 +13,13 @@ import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.context.event.EventListener; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.TransactionTemplate; import java.io.ByteArrayOutputStream; import java.time.Instant; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; @Service @@ -31,6 +33,7 @@ public class OcrTrainingService { private final OcrClient ocrClient; private final OcrHealthClient ocrHealthClient; private final TranscriptionBlockRepository blockRepository; + private final TransactionTemplate txTemplate; public record TrainingInfoResponse( int availableBlocks, @@ -46,37 +49,40 @@ public class OcrTrainingService { // Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique // index) prevents concurrent training API calls, but cannot prevent two OCR service replicas // from diverging on model state. Deploy as a single instance only. See ADR-001. - @Transactional public OcrTrainingRun triggerTraining(UUID triggeredBy) { - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, - "A training run is already in progress"); - } + // Short transaction: guard check + create RUNNING row, then commit immediately. + // The DB connection is released before the OCR HTTP call, which can take several minutes. + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, + "A training run is already in progress"); + } - var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); - if (eligibleBlocks.size() < 5) { - throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, - "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")"); - } + var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); + if (eligibleBlocks.size() < 5) { + throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, + "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")"); + } - long documentCount = eligibleBlocks.stream() - .map(b -> b.getDocumentId()) - .distinct() - .count(); + long documentCount = eligibleBlocks.stream() + .map(b -> b.getDocumentId()) + .distinct() + .count(); - OcrTrainingRun run = OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .blockCount(eligibleBlocks.size()) - .documentCount((int) documentCount) - .modelName("german_kurrent") - .triggeredBy(triggeredBy) - .build(); - run = trainingRunRepository.save(run); + OcrTrainingRun newRun = OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .blockCount(eligibleBlocks.size()) + .documentCount((int) documentCount) + .modelName("german_kurrent") + .triggeredBy(triggeredBy) + .build(); + return trainingRunRepository.save(newRun); + })); String runId = run.getId().toString(); MDC.put("trainingRunId", runId); log.info("Started training run {} with {} blocks from {} documents", - runId, eligibleBlocks.size(), documentCount); + runId, run.getBlockCount(), run.getDocumentCount()); try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -86,58 +92,64 @@ public class OcrTrainingService { log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length); OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes); - run.setStatus(TrainingStatus.DONE); - run.setCompletedAt(Instant.now()); - run.setCer(result.cer()); - run.setLoss(result.loss()); - run.setAccuracy(result.accuracy()); - run.setEpochs(result.epochs()); - run = trainingRunRepository.save(run); - log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs()); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setCer(result.cer()); + run.setLoss(result.loss()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + OcrTrainingRun updated = trainingRunRepository.save(run); + log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs()); + return updated; + })); } catch (Exception e) { - run.setStatus(TrainingStatus.FAILED); - run.setErrorMessage(e.getMessage()); - run.setCompletedAt(Instant.now()); - run = trainingRunRepository.save(run); - log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + OcrTrainingRun failed = trainingRunRepository.save(run); + log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e); + return failed; + })); } finally { MDC.remove("trainingRunId"); } - - return run; } - @Transactional public OcrTrainingRun triggerSegTraining(UUID triggeredBy) { - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, - "A training run is already in progress"); - } + // Same pattern as triggerTraining: narrow transactions around DB writes only. + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, + "A training run is already in progress"); + } - var segBlocks = segmentationTrainingExportService.querySegmentationBlocks(); - if (segBlocks.size() < 5) { - throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, - "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")"); - } + var segBlocks = segmentationTrainingExportService.querySegmentationBlocks(); + if (segBlocks.size() < 5) { + throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, + "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")"); + } - long documentCount = segBlocks.stream() - .map(b -> b.getDocumentId()) - .distinct() - .count(); + long documentCount = segBlocks.stream() + .map(b -> b.getDocumentId()) + .distinct() + .count(); - OcrTrainingRun run = OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .blockCount(segBlocks.size()) - .documentCount((int) documentCount) - .modelName("blla") - .triggeredBy(triggeredBy) - .build(); - run = trainingRunRepository.save(run); + OcrTrainingRun newRun = OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .blockCount(segBlocks.size()) + .documentCount((int) documentCount) + .modelName("blla") + .triggeredBy(triggeredBy) + .build(); + return trainingRunRepository.save(newRun); + })); String runId = run.getId().toString(); MDC.put("trainingRunId", runId); log.info("Started segmentation training run {} with {} segments from {} documents", - runId, segBlocks.size(), documentCount); + runId, run.getBlockCount(), run.getDocumentCount()); try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -147,24 +159,28 @@ public class OcrTrainingService { log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length); OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes); - run.setStatus(TrainingStatus.DONE); - run.setCompletedAt(Instant.now()); - run.setLoss(result.loss()); - run.setAccuracy(result.accuracy()); - run.setEpochs(result.epochs()); - run = trainingRunRepository.save(run); - log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs()); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setLoss(result.loss()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + OcrTrainingRun updated = trainingRunRepository.save(run); + log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs()); + return updated; + })); } catch (Exception e) { - run.setStatus(TrainingStatus.FAILED); - run.setErrorMessage(e.getMessage()); - run.setCompletedAt(Instant.now()); - run = trainingRunRepository.save(run); - log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + OcrTrainingRun failed = trainingRunRepository.save(run); + log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e); + return failed; + })); } finally { MDC.remove("trainingRunId"); } - - return run; } public TrainingInfoResponse getTrainingInfo() { @@ -174,7 +190,7 @@ public class OcrTrainingService { .distinct() .count(); - int totalOcrBlocks = blockRepository.findAll().size(); + int totalOcrBlocks = (int) blockRepository.count(); int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size(); List recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc(); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index f94e6e8f..7d83188f 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -9,6 +9,8 @@ import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionTemplate; import java.time.Instant; import java.util.List; @@ -29,6 +31,7 @@ class OcrTrainingServiceTest { OcrClient ocrClient; OcrHealthClient healthClient; TranscriptionBlockRepository blockRepository; + TransactionTemplate txTemplate; OcrTrainingService service; @BeforeEach @@ -39,10 +42,17 @@ class OcrTrainingServiceTest { ocrClient = mock(OcrClient.class); healthClient = mock(OcrHealthClient.class); blockRepository = mock(TranscriptionBlockRepository.class); + txTemplate = mock(TransactionTemplate.class); - service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository); + // Execute transaction callbacks inline so unit tests run without a real DataSource + when(txTemplate.execute(any())).thenAnswer(inv -> { + TransactionCallback callback = inv.getArgument(0); + return callback.doInTransaction(null); + }); - when(blockRepository.findAll()).thenReturn(List.of()); + service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + + when(blockRepository.count()).thenReturn(0L); when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); }