refactor(ocr): move sender training methods from OcrTrainingService to SenderModelService

Eliminates cross-domain repository access: OcrTrainingService no longer
holds SenderModelRepository. SenderModelService now owns the full sender
training lifecycle (runOrQueueSenderTraining, triggerSenderTraining,
promoteNextQueuedRun), removing the circular dependency risk.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 19:08:10 +02:00
parent 60c1ec7b5f
commit 386dc83958
4 changed files with 311 additions and 175 deletions

View File

@@ -5,10 +5,8 @@ import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.boot.context.event.ApplicationReadyEvent;
@@ -30,7 +28,6 @@ import java.util.UUID;
public class OcrTrainingService {
private final OcrTrainingRunRepository trainingRunRepository;
private final SenderModelRepository senderModelRepository;
private final TrainingDataExportService trainingDataExportService;
private final SegmentationTrainingExportService segmentationTrainingExportService;
private final OcrClient ocrClient;
@@ -212,110 +209,6 @@ public class OcrTrainingService {
);
}
/**
* Called from SenderModelService (which is @Async). Inserts a QUEUED row when training
* is already running (coalescing: at most one QUEUED per person_id). Runs immediately if idle.
* Returns true if training was started, false if queued or skipped.
*/
@Transactional
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
return false;
}
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.QUEUED)
.personId(personId)
.blockCount(blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
log.info("Queued sender training for person {} — training already running", personId);
return false;
}
return true;
}
/** Executes sender training synchronously. Caller must run this on a background thread. */
public void triggerSenderTraining(UUID personId, int correctedLines) {
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
return trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
}));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started sender training run {} for person {}", runId, personId);
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportForSender(personId).writeTo(baos);
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
txTemplate.execute(status -> {
SenderModel model = senderModelRepository.findByPersonId(personId)
.orElseGet(() -> SenderModel.builder().personId(personId).build());
model.setModelPath(outputModelPath);
model.setCer(result.cer());
model.setAccuracy(result.accuracy());
model.setCorrectedLinesAtTraining(correctedLines);
senderModelRepository.save(model);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
trainingRunRepository.save(run);
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
return null;
});
} catch (Exception e) {
txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
return null;
});
} finally {
MDC.remove("trainingRunId");
promoteNextQueuedRun();
}
}
private void promoteNextQueuedRun() {
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
.map(queued -> {
queued.setStatus(TrainingStatus.RUNNING);
return trainingRunRepository.save(queued);
}));
if (queuedOpt != null && queuedOpt.isPresent()) {
OcrTrainingRun promoted = queuedOpt.get();
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
}
}
@EventListener(ApplicationReadyEvent.class)
@Transactional
public void recoverOrphanedRuns() {

View File

@@ -0,0 +1,182 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class SenderModelService {
private final SenderModelRepository senderModelRepository;
private final TranscriptionBlockRepository blockRepository;
private final OcrTrainingRunRepository trainingRunRepository;
private final OcrClient ocrClient;
private final TransactionTemplate txTemplate;
private final TrainingDataExportService trainingDataExportService;
@Value("${ocr.sender-model.activation-threshold:100}")
private int activationThreshold;
@Value("${ocr.sender-model.retrain-delta:50}")
private int retrainDelta;
/** Returns the model path if a trained sender model exists for this person. */
public Optional<String> maybeGetModelPath(UUID personId) {
return senderModelRepository.findByPersonId(personId)
.map(SenderModel::getModelPath);
}
/**
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
*/
@Async
public void checkAndTriggerTraining(UUID personId) {
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
Optional<SenderModel> existing = senderModelRepository.findByPersonId(personId);
boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold;
boolean shouldRetrain = existing.isPresent()
&& (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta;
if (!shouldActivate && !shouldRetrain) {
return;
}
log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})",
personId, correctedLines, shouldActivate, shouldRetrain);
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
if (runNow) {
triggerSenderTraining(personId, (int) correctedLines);
}
}
/**
* Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per
* person_id). Runs immediately if idle. Returns true if training should start now, false if
* queued or skipped.
*/
@Transactional
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
return false;
}
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.QUEUED)
.personId(personId)
.blockCount(blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
log.info("Queued sender training for person {} — training already running", personId);
return false;
}
return true;
}
/** Executes sender training synchronously. Caller must run this on a background thread. */
public void triggerSenderTraining(UUID personId, int correctedLines) {
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
return trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
}));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started sender training run {} for person {}", runId, personId);
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportForSender(personId).writeTo(baos);
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
txTemplate.execute(status -> {
SenderModel model = senderModelRepository.findByPersonId(personId)
.orElseGet(() -> SenderModel.builder().personId(personId).build());
model.setModelPath(outputModelPath);
model.setCer(result.cer());
model.setAccuracy(result.accuracy());
model.setCorrectedLinesAtTraining(correctedLines);
senderModelRepository.save(model);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
trainingRunRepository.save(run);
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
return null;
});
} catch (Exception e) {
txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
return null;
});
} finally {
MDC.remove("trainingRunId");
promoteNextQueuedRun();
}
}
/**
* Promotes the oldest QUEUED sender run to RUNNING and triggers its training.
* Called in the finally block of triggerSenderTraining, creating a sequential chain:
* each run promotes the next only after it fully completes (success or failure).
*/
private void promoteNextQueuedRun() {
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
.map(queued -> {
queued.setStatus(TrainingStatus.RUNNING);
return trainingRunRepository.save(queued);
}));
if (queuedOpt != null && queuedOpt.isPresent()) {
OcrTrainingRun promoted = queuedOpt.get();
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
}
}
}

View File

@@ -4,11 +4,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
@@ -27,7 +25,6 @@ import static org.mockito.Mockito.*;
class OcrTrainingServiceTest {
OcrTrainingRunRepository runRepository;
SenderModelRepository senderModelRepository;
TrainingDataExportService exportService;
SegmentationTrainingExportService segExportService;
OcrClient ocrClient;
@@ -39,7 +36,6 @@ class OcrTrainingServiceTest {
@BeforeEach
void setUp() {
runRepository = mock(OcrTrainingRunRepository.class);
senderModelRepository = mock(SenderModelRepository.class);
exportService = mock(TrainingDataExportService.class);
segExportService = mock(SegmentationTrainingExportService.class);
ocrClient = mock(OcrClient.class);
@@ -53,7 +49,7 @@ class OcrTrainingServiceTest {
return callback.doInTransaction(null);
});
service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
@@ -233,48 +229,6 @@ class OcrTrainingServiceTest {
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── runOrQueueSenderTraining ────────────────────────────────────────────
@Test
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isTrue();
verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
}
@Test
void runOrQueueSenderTraining_savesQueued_whenRunning() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(runRepository).save(argThat(r ->
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
}
@Test
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(runRepository, never()).save(any());
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test

View File

@@ -2,23 +2,31 @@ package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.util.Optional;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
class SenderModelServiceTest {
SenderModelRepository senderModelRepository;
TranscriptionBlockRepository blockRepository;
OcrTrainingService ocrTrainingService;
OcrTrainingRunRepository trainingRunRepository;
OcrClient ocrClient;
TransactionTemplate txTemplate;
TrainingDataExportService trainingDataExportService;
SenderModelService service;
UUID personId = UUID.randomUUID();
@@ -27,8 +35,19 @@ class SenderModelServiceTest {
void setUp() {
senderModelRepository = mock(SenderModelRepository.class);
blockRepository = mock(TranscriptionBlockRepository.class);
ocrTrainingService = mock(OcrTrainingService.class);
service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService);
trainingRunRepository = mock(OcrTrainingRunRepository.class);
ocrClient = mock(OcrClient.class);
txTemplate = mock(TransactionTemplate.class);
trainingDataExportService = mock(TrainingDataExportService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource
when(txTemplate.execute(any())).thenAnswer(inv -> {
TransactionCallback<?> callback = inv.getArgument(0);
return callback.doInTransaction(null);
});
service = new SenderModelService(senderModelRepository, blockRepository,
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
ReflectionTestUtils.setField(service, "activationThreshold", 100);
ReflectionTestUtils.setField(service, "retrainDelta", 50);
}
@@ -40,20 +59,23 @@ class SenderModelServiceTest {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
spy.checkAndTriggerTraining(personId);
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100);
spy.checkAndTriggerTraining(personId);
verify(spy).runOrQueueSenderTraining(personId, 100);
}
// ─── Retrain delta ────────────────────────────────────────────────────────
@@ -65,9 +87,10 @@ class SenderModelServiceTest {
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
spy.checkAndTriggerTraining(personId);
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
@@ -76,11 +99,13 @@ class SenderModelServiceTest {
.correctedLinesAtTraining(100).build();
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150);
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150);
spy.checkAndTriggerTraining(personId);
verify(spy).runOrQueueSenderTraining(personId, 150);
}
// ─── runNow flag ──────────────────────────────────────────────────────────
@@ -89,21 +114,103 @@ class SenderModelServiceTest {
void checkAndTriggerTraining_callsTrigger_whenRunNow() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100);
doNothing().when(spy).triggerSenderTraining(personId, 100);
verify(ocrTrainingService).triggerSenderTraining(personId, 100);
spy.checkAndTriggerTraining(personId);
verify(spy).triggerSenderTraining(personId, 100);
}
@Test
void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt());
spy.checkAndTriggerTraining(personId);
verify(spy, never()).triggerSenderTraining(any(), anyInt());
}
// ─── runOrQueueSenderTraining ─────────────────────────────────────────────
@Test
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isTrue();
verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
}
@Test
void runOrQueueSenderTraining_savesQueued_whenRunning() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(trainingRunRepository).save(argThat(r ->
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
}
@Test
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(trainingRunRepository, never()).save(any());
}
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
@Test
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
UUID nextPersonId = UUID.randomUUID();
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L);
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
OcrTrainingRun r = inv.getArgument(0);
if (r.getId() == null) r.setId(UUID.randomUUID());
return r;
});
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
when(ocrClient.trainSenderModel(any(), any()))
.thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
OcrTrainingRun queued = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.QUEUED)
.personId(nextPersonId).blockCount(5).documentCount(0)
.modelName("sender_" + nextPersonId).build();
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
.thenReturn(Optional.of(queued));
when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L);
SenderModelService spy = spy(service);
// Stub the recursive call to stop the chain after one promotion
doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
spy.triggerSenderTraining(personId, 10);
// Queued run promoted to RUNNING
verify(trainingRunRepository).save(argThat(r ->
nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING));
// Next training was kicked off
verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
}
}