diff --git a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java index 99ff81da..b3878164 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/repository/OcrTrainingRunRepository.java @@ -14,6 +14,8 @@ public interface OcrTrainingRunRepository extends JpaRepository findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status); + Optional findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status); + boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status); List findTop20ByOrderByCreatedAtDesc(); diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java index 465d6a8e..19a796db 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -72,9 +72,10 @@ public class SenderModelService { } /** - * Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per - * person_id). Runs immediately if idle. Returns true if training should start now, false if - * queued or skipped. + * Atomically checks the queue state and either creates a RUNNING row (returns true) or a + * QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and + * RUNNING row creation — happen in one transaction, eliminating the race window that would + * otherwise exist between the check and a separate RUNNING row creation. */ @Transactional public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { @@ -96,23 +97,29 @@ public class SenderModelService { return false; } + long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); + trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .personId(personId) + .blockCount((int) blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); return true; } - /** Executes sender training synchronously. Caller must run this on a background thread. */ + /** + * Executes sender training synchronously. Caller must run this on a background thread. + * The RUNNING row is expected to already exist — created atomically by + * runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs). + */ public void triggerSenderTraining(UUID personId, int correctedLines) { String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; - OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { - long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); - return trainingRunRepository.save(OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .personId(personId) - .blockCount((int) blockCount) - .documentCount(0) - .modelName("sender_" + personId) - .build()); - })); + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> + trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING) + .orElseThrow(() -> new IllegalStateException( + "Expected RUNNING row for person " + personId + " but none found")))); String runId = run.getId().toString(); MDC.put("trainingRunId", runId); diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index f46c5905..e08aa7b5 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -176,12 +176,38 @@ class SenderModelServiceTest { verify(trainingRunRepository, never()).save(any()); } + @Test + void runOrQueueSenderTraining_createsRunningRow_atomically() { + // RUNNING row must be created inside the same @Transactional boundary as the idle check, + // eliminating the race window between the check and a separate triggerSenderTraining call. + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> { + OcrTrainingRun r = inv.getArgument(0); + if (r.getId() == null) r.setId(UUID.randomUUID()); + return r; + }); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(trainingRunRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId()))); + } + // ─── triggerSenderTraining — queue promotion ────────────────────────────── @Test void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception { UUID nextPersonId = UUID.randomUUID(); + OcrTrainingRun runningRun = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(10).documentCount(0) + .modelName("sender_" + personId).build(); + when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)) + .thenReturn(Optional.of(runningRun)); when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L); when(trainingRunRepository.save(any())).thenAnswer(inv -> { OcrTrainingRun r = inv.getArgument(0);