diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index e241aaa3..9c8f4e5d 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -13,11 +13,13 @@ import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.context.event.EventListener; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.TransactionTemplate; import java.io.ByteArrayOutputStream; import java.time.Instant; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; @Service @@ -31,6 +33,7 @@ public class OcrTrainingService { private final OcrClient ocrClient; private final OcrHealthClient ocrHealthClient; private final TranscriptionBlockRepository blockRepository; + private final TransactionTemplate txTemplate; public record TrainingInfoResponse( int availableBlocks, @@ -46,37 +49,40 @@ public class OcrTrainingService { // Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique // index) prevents concurrent training API calls, but cannot prevent two OCR service replicas // from diverging on model state. Deploy as a single instance only. See ADR-001. - @Transactional public OcrTrainingRun triggerTraining(UUID triggeredBy) { - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, - "A training run is already in progress"); - } + // Short transaction: guard check + create RUNNING row, then commit immediately. + // The DB connection is released before the OCR HTTP call, which can take several minutes. + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, + "A training run is already in progress"); + } - var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); - if (eligibleBlocks.size() < 5) { - throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, - "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")"); - } + var eligibleBlocks = trainingDataExportService.queryEligibleBlocks(); + if (eligibleBlocks.size() < 5) { + throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, + "At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")"); + } - long documentCount = eligibleBlocks.stream() - .map(b -> b.getDocumentId()) - .distinct() - .count(); + long documentCount = eligibleBlocks.stream() + .map(b -> b.getDocumentId()) + .distinct() + .count(); - OcrTrainingRun run = OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .blockCount(eligibleBlocks.size()) - .documentCount((int) documentCount) - .modelName("german_kurrent") - .triggeredBy(triggeredBy) - .build(); - run = trainingRunRepository.save(run); + OcrTrainingRun newRun = OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .blockCount(eligibleBlocks.size()) + .documentCount((int) documentCount) + .modelName("german_kurrent") + .triggeredBy(triggeredBy) + .build(); + return trainingRunRepository.save(newRun); + })); String runId = run.getId().toString(); MDC.put("trainingRunId", runId); log.info("Started training run {} with {} blocks from {} documents", - runId, eligibleBlocks.size(), documentCount); + runId, run.getBlockCount(), run.getDocumentCount()); try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -86,58 +92,64 @@ public class OcrTrainingService { log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length); OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes); - run.setStatus(TrainingStatus.DONE); - run.setCompletedAt(Instant.now()); - run.setCer(result.cer()); - run.setLoss(result.loss()); - run.setAccuracy(result.accuracy()); - run.setEpochs(result.epochs()); - run = trainingRunRepository.save(run); - log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs()); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setCer(result.cer()); + run.setLoss(result.loss()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + OcrTrainingRun updated = trainingRunRepository.save(run); + log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs()); + return updated; + })); } catch (Exception e) { - run.setStatus(TrainingStatus.FAILED); - run.setErrorMessage(e.getMessage()); - run.setCompletedAt(Instant.now()); - run = trainingRunRepository.save(run); - log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + OcrTrainingRun failed = trainingRunRepository.save(run); + log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e); + return failed; + })); } finally { MDC.remove("trainingRunId"); } - - return run; } - @Transactional public OcrTrainingRun triggerSegTraining(UUID triggeredBy) { - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, - "A training run is already in progress"); - } + // Same pattern as triggerTraining: narrow transactions around DB writes only. + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING, + "A training run is already in progress"); + } - var segBlocks = segmentationTrainingExportService.querySegmentationBlocks(); - if (segBlocks.size() < 5) { - throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, - "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")"); - } + var segBlocks = segmentationTrainingExportService.querySegmentationBlocks(); + if (segBlocks.size() < 5) { + throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR, + "At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")"); + } - long documentCount = segBlocks.stream() - .map(b -> b.getDocumentId()) - .distinct() - .count(); + long documentCount = segBlocks.stream() + .map(b -> b.getDocumentId()) + .distinct() + .count(); - OcrTrainingRun run = OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .blockCount(segBlocks.size()) - .documentCount((int) documentCount) - .modelName("blla") - .triggeredBy(triggeredBy) - .build(); - run = trainingRunRepository.save(run); + OcrTrainingRun newRun = OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .blockCount(segBlocks.size()) + .documentCount((int) documentCount) + .modelName("blla") + .triggeredBy(triggeredBy) + .build(); + return trainingRunRepository.save(newRun); + })); String runId = run.getId().toString(); MDC.put("trainingRunId", runId); log.info("Started segmentation training run {} with {} segments from {} documents", - runId, segBlocks.size(), documentCount); + runId, run.getBlockCount(), run.getDocumentCount()); try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -147,24 +159,28 @@ public class OcrTrainingService { log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length); OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes); - run.setStatus(TrainingStatus.DONE); - run.setCompletedAt(Instant.now()); - run.setLoss(result.loss()); - run.setAccuracy(result.accuracy()); - run.setEpochs(result.epochs()); - run = trainingRunRepository.save(run); - log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs()); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setLoss(result.loss()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + OcrTrainingRun updated = trainingRunRepository.save(run); + log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs()); + return updated; + })); } catch (Exception e) { - run.setStatus(TrainingStatus.FAILED); - run.setErrorMessage(e.getMessage()); - run.setCompletedAt(Instant.now()); - run = trainingRunRepository.save(run); - log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e); + return Objects.requireNonNull(txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + OcrTrainingRun failed = trainingRunRepository.save(run); + log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e); + return failed; + })); } finally { MDC.remove("trainingRunId"); } - - return run; } public TrainingInfoResponse getTrainingInfo() { @@ -174,7 +190,7 @@ public class OcrTrainingService { .distinct() .count(); - int totalOcrBlocks = blockRepository.findAll().size(); + int totalOcrBlocks = (int) blockRepository.count(); int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size(); List recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc(); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index f94e6e8f..7d83188f 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -9,6 +9,8 @@ import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionTemplate; import java.time.Instant; import java.util.List; @@ -29,6 +31,7 @@ class OcrTrainingServiceTest { OcrClient ocrClient; OcrHealthClient healthClient; TranscriptionBlockRepository blockRepository; + TransactionTemplate txTemplate; OcrTrainingService service; @BeforeEach @@ -39,10 +42,17 @@ class OcrTrainingServiceTest { ocrClient = mock(OcrClient.class); healthClient = mock(OcrHealthClient.class); blockRepository = mock(TranscriptionBlockRepository.class); + txTemplate = mock(TransactionTemplate.class); - service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository); + // Execute transaction callbacks inline so unit tests run without a real DataSource + when(txTemplate.execute(any())).thenAnswer(inv -> { + TransactionCallback callback = inv.getArgument(0); + return callback.doInTransaction(null); + }); - when(blockRepository.findAll()).thenReturn(List.of()); + service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + + when(blockRepository.count()).thenReturn(0L); when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); }