diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index c41542df..ff80e1e6 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -5,10 +5,8 @@ import lombok.extern.slf4j.Slf4j; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.model.OcrTrainingRun; -import org.raddatz.familienarchiv.model.SenderModel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; -import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.slf4j.MDC; import org.springframework.boot.context.event.ApplicationReadyEvent; @@ -30,7 +28,6 @@ import java.util.UUID; public class OcrTrainingService { private final OcrTrainingRunRepository trainingRunRepository; - private final SenderModelRepository senderModelRepository; private final TrainingDataExportService trainingDataExportService; private final SegmentationTrainingExportService segmentationTrainingExportService; private final OcrClient ocrClient; @@ -212,110 +209,6 @@ public class OcrTrainingService { ); } - /** - * Called from SenderModelService (which is @Async). Inserts a QUEUED row when training - * is already running (coalescing: at most one QUEUED per person_id). Runs immediately if idle. - * Returns true if training was started, false if queued or skipped. - */ - @Transactional - public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { - if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) { - log.info("Sender training already queued for person {} — skipping duplicate trigger", personId); - return false; - } - - if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { - int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId); - trainingRunRepository.save(OcrTrainingRun.builder() - .status(TrainingStatus.QUEUED) - .personId(personId) - .blockCount(blockCount) - .documentCount(0) - .modelName("sender_" + personId) - .build()); - log.info("Queued sender training for person {} — training already running", personId); - return false; - } - - return true; - } - - /** Executes sender training synchronously. Caller must run this on a background thread. */ - public void triggerSenderTraining(UUID personId, int correctedLines) { - String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; - - OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { - long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); - return trainingRunRepository.save(OcrTrainingRun.builder() - .status(TrainingStatus.RUNNING) - .personId(personId) - .blockCount((int) blockCount) - .documentCount(0) - .modelName("sender_" + personId) - .build()); - })); - - String runId = run.getId().toString(); - MDC.put("trainingRunId", runId); - log.info("Started sender training run {} for person {}", runId, personId); - - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - trainingDataExportService.exportForSender(personId).writeTo(baos); - byte[] zipBytes = baos.toByteArray(); - - log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length); - OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath); - - txTemplate.execute(status -> { - SenderModel model = senderModelRepository.findByPersonId(personId) - .orElseGet(() -> SenderModel.builder().personId(personId).build()); - model.setModelPath(outputModelPath); - model.setCer(result.cer()); - model.setAccuracy(result.accuracy()); - model.setCorrectedLinesAtTraining(correctedLines); - senderModelRepository.save(model); - - run.setStatus(TrainingStatus.DONE); - run.setCompletedAt(Instant.now()); - run.setCer(result.cer()); - run.setAccuracy(result.accuracy()); - run.setEpochs(result.epochs()); - trainingRunRepository.save(run); - log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer()); - return null; - }); - } catch (Exception e) { - txTemplate.execute(status -> { - run.setStatus(TrainingStatus.FAILED); - run.setErrorMessage(e.getMessage()); - run.setCompletedAt(Instant.now()); - trainingRunRepository.save(run); - log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e); - return null; - }); - } finally { - MDC.remove("trainingRunId"); - promoteNextQueuedRun(); - } - } - - private void promoteNextQueuedRun() { - Optional queuedOpt = txTemplate.execute(status -> - trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED) - .map(queued -> { - queued.setStatus(TrainingStatus.RUNNING); - return trainingRunRepository.save(queued); - })); - - if (queuedOpt != null && queuedOpt.isPresent()) { - OcrTrainingRun promoted = queuedOpt.get(); - log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId()); - long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId()); - triggerSenderTraining(promoted.getPersonId(), (int) freshCount); - } - } - @EventListener(ApplicationReadyEvent.class) @Transactional public void recoverOrphanedRuns() { diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java new file mode 100644 index 00000000..465d6a8e --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -0,0 +1,182 @@ +package org.raddatz.familienarchiv.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.model.TrainingStatus; +import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; +import org.raddatz.familienarchiv.repository.SenderModelRepository; +import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.slf4j.MDC; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.support.TransactionTemplate; + +import java.io.ByteArrayOutputStream; +import java.time.Instant; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; + +@Service +@RequiredArgsConstructor +@Slf4j +public class SenderModelService { + + private final SenderModelRepository senderModelRepository; + private final TranscriptionBlockRepository blockRepository; + private final OcrTrainingRunRepository trainingRunRepository; + private final OcrClient ocrClient; + private final TransactionTemplate txTemplate; + private final TrainingDataExportService trainingDataExportService; + + @Value("${ocr.sender-model.activation-threshold:100}") + private int activationThreshold; + + @Value("${ocr.sender-model.retrain-delta:50}") + private int retrainDelta; + + /** Returns the model path if a trained sender model exists for this person. */ + public Optional maybeGetModelPath(UUID personId) { + return senderModelRepository.findByPersonId(personId) + .map(SenderModel::getModelPath); + } + + /** + * Called after every MANUAL block save for HANDWRITING_KURRENT documents. + * Checks activation and retrain thresholds; enqueues or starts sender training when met. + */ + @Async + public void checkAndTriggerTraining(UUID personId) { + long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId); + Optional existing = senderModelRepository.findByPersonId(personId); + + boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold; + boolean shouldRetrain = existing.isPresent() + && (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta; + + if (!shouldActivate && !shouldRetrain) { + return; + } + + log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})", + personId, correctedLines, shouldActivate, shouldRetrain); + + boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines); + if (runNow) { + triggerSenderTraining(personId, (int) correctedLines); + } + } + + /** + * Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per + * person_id). Runs immediately if idle. Returns true if training should start now, false if + * queued or skipped. + */ + @Transactional + public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) { + if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) { + log.info("Sender training already queued for person {} — skipping duplicate trigger", personId); + return false; + } + + if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { + int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId); + trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.QUEUED) + .personId(personId) + .blockCount(blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + log.info("Queued sender training for person {} — training already running", personId); + return false; + } + + return true; + } + + /** Executes sender training synchronously. Caller must run this on a background thread. */ + public void triggerSenderTraining(UUID personId, int correctedLines) { + String outputModelPath = "/app/models/sender_" + personId + ".mlmodel"; + + OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> { + long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId); + return trainingRunRepository.save(OcrTrainingRun.builder() + .status(TrainingStatus.RUNNING) + .personId(personId) + .blockCount((int) blockCount) + .documentCount(0) + .modelName("sender_" + personId) + .build()); + })); + + String runId = run.getId().toString(); + MDC.put("trainingRunId", runId); + log.info("Started sender training run {} for person {}", runId, personId); + + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + trainingDataExportService.exportForSender(personId).writeTo(baos); + byte[] zipBytes = baos.toByteArray(); + + log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length); + OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath); + + txTemplate.execute(status -> { + SenderModel model = senderModelRepository.findByPersonId(personId) + .orElseGet(() -> SenderModel.builder().personId(personId).build()); + model.setModelPath(outputModelPath); + model.setCer(result.cer()); + model.setAccuracy(result.accuracy()); + model.setCorrectedLinesAtTraining(correctedLines); + senderModelRepository.save(model); + + run.setStatus(TrainingStatus.DONE); + run.setCompletedAt(Instant.now()); + run.setCer(result.cer()); + run.setAccuracy(result.accuracy()); + run.setEpochs(result.epochs()); + trainingRunRepository.save(run); + log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer()); + return null; + }); + } catch (Exception e) { + txTemplate.execute(status -> { + run.setStatus(TrainingStatus.FAILED); + run.setErrorMessage(e.getMessage()); + run.setCompletedAt(Instant.now()); + trainingRunRepository.save(run); + log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e); + return null; + }); + } finally { + MDC.remove("trainingRunId"); + promoteNextQueuedRun(); + } + } + + /** + * Promotes the oldest QUEUED sender run to RUNNING and triggers its training. + * Called in the finally block of triggerSenderTraining, creating a sequential chain: + * each run promotes the next only after it fully completes (success or failure). + */ + private void promoteNextQueuedRun() { + Optional queuedOpt = txTemplate.execute(status -> + trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED) + .map(queued -> { + queued.setStatus(TrainingStatus.RUNNING); + return trainingRunRepository.save(queued); + })); + + if (queuedOpt != null && queuedOpt.isPresent()) { + OcrTrainingRun promoted = queuedOpt.get(); + log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId()); + long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId()); + triggerSenderTraining(promoted.getPersonId(), (int) freshCount); + } + } +} diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index b89e1ae6..e6e43f00 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -4,11 +4,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.OcrTrainingRun; -import org.raddatz.familienarchiv.model.TrainingLabel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; -import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionTemplate; @@ -27,7 +25,6 @@ import static org.mockito.Mockito.*; class OcrTrainingServiceTest { OcrTrainingRunRepository runRepository; - SenderModelRepository senderModelRepository; TrainingDataExportService exportService; SegmentationTrainingExportService segExportService; OcrClient ocrClient; @@ -39,7 +36,6 @@ class OcrTrainingServiceTest { @BeforeEach void setUp() { runRepository = mock(OcrTrainingRunRepository.class); - senderModelRepository = mock(SenderModelRepository.class); exportService = mock(TrainingDataExportService.class); segExportService = mock(SegmentationTrainingExportService.class); ocrClient = mock(OcrClient.class); @@ -53,7 +49,7 @@ class OcrTrainingServiceTest { return callback.doInTransaction(null); }); - service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); + service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); when(blockRepository.count()).thenReturn(0L); when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); @@ -233,48 +229,6 @@ class OcrTrainingServiceTest { run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); } - // ─── runOrQueueSenderTraining ──────────────────────────────────────────── - - @Test - void runOrQueueSenderTraining_returnsTrue_whenIdle() { - UUID personId = UUID.randomUUID(); - when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); - when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); - - boolean result = service.runOrQueueSenderTraining(personId, 120); - - assertThat(result).isTrue(); - verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED)); - } - - @Test - void runOrQueueSenderTraining_savesQueued_whenRunning() { - UUID personId = UUID.randomUUID(); - when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); - when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn( - Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING) - .blockCount(5).documentCount(1).modelName("german_kurrent").build())); - when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); - when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); - - boolean result = service.runOrQueueSenderTraining(personId, 120); - - assertThat(result).isFalse(); - verify(runRepository).save(argThat(r -> - r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId()))); - } - - @Test - void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() { - UUID personId = UUID.randomUUID(); - when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true); - - boolean result = service.runOrQueueSenderTraining(personId, 120); - - assertThat(result).isFalse(); - verify(runRepository, never()).save(any()); - } - // ─── Orphan recovery ────────────────────────────────────────────────────── @Test diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index b3a3a5d9..f46c5905 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -2,23 +2,31 @@ package org.raddatz.familienarchiv.service; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.raddatz.familienarchiv.model.OcrTrainingRun; import org.raddatz.familienarchiv.model.SenderModel; +import org.raddatz.familienarchiv.model.TrainingStatus; +import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.transaction.support.TransactionCallback; +import org.springframework.transaction.support.TransactionTemplate; import java.util.Optional; import java.util.UUID; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; class SenderModelServiceTest { SenderModelRepository senderModelRepository; TranscriptionBlockRepository blockRepository; - OcrTrainingService ocrTrainingService; + OcrTrainingRunRepository trainingRunRepository; + OcrClient ocrClient; + TransactionTemplate txTemplate; + TrainingDataExportService trainingDataExportService; SenderModelService service; UUID personId = UUID.randomUUID(); @@ -27,8 +35,19 @@ class SenderModelServiceTest { void setUp() { senderModelRepository = mock(SenderModelRepository.class); blockRepository = mock(TranscriptionBlockRepository.class); - ocrTrainingService = mock(OcrTrainingService.class); - service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService); + trainingRunRepository = mock(OcrTrainingRunRepository.class); + ocrClient = mock(OcrClient.class); + txTemplate = mock(TransactionTemplate.class); + trainingDataExportService = mock(TrainingDataExportService.class); + + // Execute transaction callbacks inline so unit tests run without a real DataSource + when(txTemplate.execute(any())).thenAnswer(inv -> { + TransactionCallback callback = inv.getArgument(0); + return callback.doInTransaction(null); + }); + + service = new SenderModelService(senderModelRepository, blockRepository, + trainingRunRepository, ocrClient, txTemplate, trainingDataExportService); ReflectionTestUtils.setField(service, "activationThreshold", 100); ReflectionTestUtils.setField(service, "retrainDelta", 50); } @@ -40,20 +59,23 @@ class SenderModelServiceTest { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + spy.checkAndTriggerTraining(personId); - verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + verify(spy, never()).runOrQueueSenderTraining(any(), anyInt()); } @Test void checkAndTriggerTraining_triggersTraining_atActivationThreshold() { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100); - verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100); + spy.checkAndTriggerTraining(personId); + + verify(spy).runOrQueueSenderTraining(personId, 100); } // ─── Retrain delta ──────────────────────────────────────────────────────── @@ -65,9 +87,10 @@ class SenderModelServiceTest { when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + spy.checkAndTriggerTraining(personId); - verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt()); + verify(spy, never()).runOrQueueSenderTraining(any(), anyInt()); } @Test @@ -76,11 +99,13 @@ class SenderModelServiceTest { .correctedLinesAtTraining(100).build(); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing)); when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150); - verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150); + spy.checkAndTriggerTraining(personId); + + verify(spy).runOrQueueSenderTraining(personId, 150); } // ─── runNow flag ────────────────────────────────────────────────────────── @@ -89,21 +114,103 @@ class SenderModelServiceTest { void checkAndTriggerTraining_callsTrigger_whenRunNow() { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100); + doNothing().when(spy).triggerSenderTraining(personId, 100); - verify(ocrTrainingService).triggerSenderTraining(personId, 100); + spy.checkAndTriggerTraining(personId); + + verify(spy).triggerSenderTraining(personId, 100); } @Test void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() { when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L); when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); - when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false); - service.checkAndTriggerTraining(personId); + SenderModelService spy = spy(service); + doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100); - verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt()); + spy.checkAndTriggerTraining(personId); + + verify(spy, never()).triggerSenderTraining(any(), anyInt()); + } + + // ─── runOrQueueSenderTraining ───────────────────────────────────────────── + + @Test + void runOrQueueSenderTraining_returnsTrue_whenIdle() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty()); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isTrue(); + verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED)); + } + + @Test + void runOrQueueSenderTraining_savesQueued_whenRunning() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false); + when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn( + Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .blockCount(5).documentCount(1).modelName("german_kurrent").build())); + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(trainingRunRepository).save(argThat(r -> + r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId()))); + } + + @Test + void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() { + when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true); + + boolean result = service.runOrQueueSenderTraining(personId, 120); + + assertThat(result).isFalse(); + verify(trainingRunRepository, never()).save(any()); + } + + // ─── triggerSenderTraining — queue promotion ────────────────────────────── + + @Test + void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception { + UUID nextPersonId = UUID.randomUUID(); + + when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L); + when(trainingRunRepository.save(any())).thenAnswer(inv -> { + OcrTrainingRun r = inv.getArgument(0); + if (r.getId() == null) r.setId(UUID.randomUUID()); + return r; + }); + when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {}); + when(ocrClient.trainSenderModel(any(), any())) + .thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3)); + when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty()); + + OcrTrainingRun queued = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.QUEUED) + .personId(nextPersonId).blockCount(5).documentCount(0) + .modelName("sender_" + nextPersonId).build(); + when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)) + .thenReturn(Optional.of(queued)); + when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L); + + SenderModelService spy = spy(service); + // Stub the recursive call to stop the chain after one promotion + doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt()); + + spy.triggerSenderTraining(personId, 10); + + // Queued run promoted to RUNNING + verify(trainingRunRepository).save(argThat(r -> + nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING)); + // Next training was kicked off + verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt()); } }