fix(ocr): eliminate race window in runOrQueueSenderTraining by creating RUNNING row atomically
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -14,6 +14,8 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
|
|||||||
|
|
||||||
Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status);
|
Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status);
|
||||||
|
|
||||||
|
Optional<OcrTrainingRun> findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status);
|
||||||
|
|
||||||
boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status);
|
boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status);
|
||||||
|
|
||||||
List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc();
|
List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc();
|
||||||
|
|||||||
@@ -72,9 +72,10 @@ public class SenderModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per
|
* Atomically checks the queue state and either creates a RUNNING row (returns true) or a
|
||||||
* person_id). Runs immediately if idle. Returns true if training should start now, false if
|
* QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and
|
||||||
* queued or skipped.
|
* RUNNING row creation — happen in one transaction, eliminating the race window that would
|
||||||
|
* otherwise exist between the check and a separate RUNNING row creation.
|
||||||
*/
|
*/
|
||||||
@Transactional
|
@Transactional
|
||||||
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
|
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
|
||||||
@@ -96,23 +97,29 @@ public class SenderModelService {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||||
|
trainingRunRepository.save(OcrTrainingRun.builder()
|
||||||
|
.status(TrainingStatus.RUNNING)
|
||||||
|
.personId(personId)
|
||||||
|
.blockCount((int) blockCount)
|
||||||
|
.documentCount(0)
|
||||||
|
.modelName("sender_" + personId)
|
||||||
|
.build());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Executes sender training synchronously. Caller must run this on a background thread. */
|
/**
|
||||||
|
* Executes sender training synchronously. Caller must run this on a background thread.
|
||||||
|
* The RUNNING row is expected to already exist — created atomically by
|
||||||
|
* runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs).
|
||||||
|
*/
|
||||||
public void triggerSenderTraining(UUID personId, int correctedLines) {
|
public void triggerSenderTraining(UUID personId, int correctedLines) {
|
||||||
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
|
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
|
||||||
|
|
||||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status ->
|
||||||
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
|
trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)
|
||||||
return trainingRunRepository.save(OcrTrainingRun.builder()
|
.orElseThrow(() -> new IllegalStateException(
|
||||||
.status(TrainingStatus.RUNNING)
|
"Expected RUNNING row for person " + personId + " but none found"))));
|
||||||
.personId(personId)
|
|
||||||
.blockCount((int) blockCount)
|
|
||||||
.documentCount(0)
|
|
||||||
.modelName("sender_" + personId)
|
|
||||||
.build());
|
|
||||||
}));
|
|
||||||
|
|
||||||
String runId = run.getId().toString();
|
String runId = run.getId().toString();
|
||||||
MDC.put("trainingRunId", runId);
|
MDC.put("trainingRunId", runId);
|
||||||
|
|||||||
@@ -176,12 +176,38 @@ class SenderModelServiceTest {
|
|||||||
verify(trainingRunRepository, never()).save(any());
|
verify(trainingRunRepository, never()).save(any());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void runOrQueueSenderTraining_createsRunningRow_atomically() {
|
||||||
|
// RUNNING row must be created inside the same @Transactional boundary as the idle check,
|
||||||
|
// eliminating the race window between the check and a separate triggerSenderTraining call.
|
||||||
|
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||||
|
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||||
|
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
|
||||||
|
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
|
||||||
|
OcrTrainingRun r = inv.getArgument(0);
|
||||||
|
if (r.getId() == null) r.setId(UUID.randomUUID());
|
||||||
|
return r;
|
||||||
|
});
|
||||||
|
|
||||||
|
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||||
|
|
||||||
|
assertThat(result).isTrue();
|
||||||
|
verify(trainingRunRepository).save(argThat(r ->
|
||||||
|
r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId())));
|
||||||
|
}
|
||||||
|
|
||||||
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
|
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
|
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
|
||||||
UUID nextPersonId = UUID.randomUUID();
|
UUID nextPersonId = UUID.randomUUID();
|
||||||
|
|
||||||
|
OcrTrainingRun runningRun = OcrTrainingRun.builder()
|
||||||
|
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||||
|
.personId(personId).blockCount(10).documentCount(0)
|
||||||
|
.modelName("sender_" + personId).build();
|
||||||
|
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
|
||||||
|
.thenReturn(Optional.of(runningRun));
|
||||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L);
|
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L);
|
||||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
|
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
|
||||||
OcrTrainingRun r = inv.getArgument(0);
|
OcrTrainingRun r = inv.getArgument(0);
|
||||||
|
|||||||
Reference in New Issue
Block a user