fix(ocr): eliminate race window in runOrQueueSenderTraining by creating RUNNING row atomically

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 20:11:56 +02:00
committed by marcel
parent 68ec66002a
commit 3ecda655c5
3 changed files with 49 additions and 14 deletions

View File

@@ -14,6 +14,8 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status); Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status);
Optional<OcrTrainingRun> findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status);
boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status); boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status);
List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc(); List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc();

View File

@@ -72,9 +72,10 @@ public class SenderModelService {
} }
/** /**
* Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per * Atomically checks the queue state and either creates a RUNNING row (returns true) or a
* person_id). Runs immediately if idle. Returns true if training should start now, false if * QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and
* queued or skipped. * RUNNING row creation — happen in one transaction, eliminating the race window that would
* otherwise exist between the check and a separate RUNNING row creation.
*/ */
@Transactional @Transactional
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
@@ -96,23 +97,29 @@ public class SenderModelService {
return false; return false;
} }
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
return true; return true;
} }
/** Executes sender training synchronously. Caller must run this on a background thread. */ /**
* Executes sender training synchronously. Caller must run this on a background thread.
* The RUNNING row is expected to already exist — created atomically by
* runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs).
*/
public void triggerSenderTraining(UUID personId, int correctedLines) { public void triggerSenderTraining(UUID personId, int correctedLines) {
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status ->
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)
return trainingRunRepository.save(OcrTrainingRun.builder() .orElseThrow(() -> new IllegalStateException(
.status(TrainingStatus.RUNNING) "Expected RUNNING row for person " + personId + " but none found"))));
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
}));
String runId = run.getId().toString(); String runId = run.getId().toString();
MDC.put("trainingRunId", runId); MDC.put("trainingRunId", runId);

View File

@@ -176,12 +176,38 @@ class SenderModelServiceTest {
verify(trainingRunRepository, never()).save(any()); verify(trainingRunRepository, never()).save(any());
} }
@Test
void runOrQueueSenderTraining_createsRunningRow_atomically() {
    // Regression guard for the race window: runOrQueueSenderTraining itself must persist the
    // RUNNING row (inside its own @Transactional boundary, together with the idle check)
    // rather than leaving that to a later, separate triggerSenderTraining call.

    // Idle queue: nothing is RUNNING, and this person has no QUEUED row.
    when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
    when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
    when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);

    // Echo the saved entity back, assigning an id when it does not have one yet.
    when(trainingRunRepository.save(any())).thenAnswer(invocation -> {
        OcrTrainingRun saved = invocation.getArgument(0);
        if (saved.getId() == null) {
            saved.setId(UUID.randomUUID());
        }
        return saved;
    });

    boolean startNow = service.runOrQueueSenderTraining(personId, 120);

    // true means "training should start now" ...
    assertThat(startNow).isTrue();
    // ... and a RUNNING row for this person must have been saved by the call itself.
    verify(trainingRunRepository).save(argThat(saved ->
            TrainingStatus.RUNNING == saved.getStatus() && personId.equals(saved.getPersonId())));
}
// ─── triggerSenderTraining — queue promotion ────────────────────────────── // ─── triggerSenderTraining — queue promotion ──────────────────────────────
@Test @Test
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception { void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
UUID nextPersonId = UUID.randomUUID(); UUID nextPersonId = UUID.randomUUID();
OcrTrainingRun runningRun = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.personId(personId).blockCount(10).documentCount(0)
.modelName("sender_" + personId).build();
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
.thenReturn(Optional.of(runningRun));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L); when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L);
when(trainingRunRepository.save(any())).thenAnswer(inv -> { when(trainingRunRepository.save(any())).thenAnswer(inv -> {
OcrTrainingRun r = inv.getArgument(0); OcrTrainingRun r = inv.getArgument(0);