refactor(ocr): move sender training methods from OcrTrainingService to SenderModelService
Eliminates cross-domain repository access: OcrTrainingService no longer holds SenderModelRepository. SenderModelService now owns the full sender training lifecycle (runOrQueueSenderTraining, triggerSenderTraining, promoteNextQueuedRun), removing the circular dependency risk. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,10 +5,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.SenderModel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.slf4j.MDC;
|
||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
@@ -30,7 +28,6 @@ import java.util.UUID;
|
||||
public class OcrTrainingService {
|
||||
|
||||
private final OcrTrainingRunRepository trainingRunRepository;
|
||||
private final SenderModelRepository senderModelRepository;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final SegmentationTrainingExportService segmentationTrainingExportService;
|
||||
private final OcrClient ocrClient;
|
||||
@@ -212,110 +209,6 @@ public class OcrTrainingService {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Called from SenderModelService (which is @Async). Inserts a QUEUED row when training
|
||||
* is already running (coalescing: at most one QUEUED per person_id). Runs immediately if idle.
|
||||
* Returns true if training was started, false if queued or skipped.
|
||||
*/
|
||||
@Transactional
|
||||
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
|
||||
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
|
||||
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
trainingRunRepository.save(OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.QUEUED)
|
||||
.personId(personId)
|
||||
.blockCount(blockCount)
|
||||
.documentCount(0)
|
||||
.modelName("sender_" + personId)
|
||||
.build());
|
||||
log.info("Queued sender training for person {} — training already running", personId);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Executes sender training synchronously. Caller must run this on a background thread. */
|
||||
public void triggerSenderTraining(UUID personId, int correctedLines) {
|
||||
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
|
||||
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
return trainingRunRepository.save(OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.personId(personId)
|
||||
.blockCount((int) blockCount)
|
||||
.documentCount(0)
|
||||
.modelName("sender_" + personId)
|
||||
.build());
|
||||
}));
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started sender training run {} for person {}", runId, personId);
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
trainingDataExportService.exportForSender(personId).writeTo(baos);
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
|
||||
|
||||
txTemplate.execute(status -> {
|
||||
SenderModel model = senderModelRepository.findByPersonId(personId)
|
||||
.orElseGet(() -> SenderModel.builder().personId(personId).build());
|
||||
model.setModelPath(outputModelPath);
|
||||
model.setCer(result.cer());
|
||||
model.setAccuracy(result.accuracy());
|
||||
model.setCorrectedLinesAtTraining(correctedLines);
|
||||
senderModelRepository.save(model);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
|
||||
return null;
|
||||
});
|
||||
} catch (Exception e) {
|
||||
txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
|
||||
return null;
|
||||
});
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
promoteNextQueuedRun();
|
||||
}
|
||||
}
|
||||
|
||||
private void promoteNextQueuedRun() {
|
||||
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
|
||||
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
|
||||
.map(queued -> {
|
||||
queued.setStatus(TrainingStatus.RUNNING);
|
||||
return trainingRunRepository.save(queued);
|
||||
}));
|
||||
|
||||
if (queuedOpt != null && queuedOpt.isPresent()) {
|
||||
OcrTrainingRun promoted = queuedOpt.get();
|
||||
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
|
||||
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
|
||||
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
|
||||
}
|
||||
}
|
||||
|
||||
@EventListener(ApplicationReadyEvent.class)
|
||||
@Transactional
|
||||
public void recoverOrphanedRuns() {
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.SenderModel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.slf4j.MDC;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class SenderModelService {
|
||||
|
||||
private final SenderModelRepository senderModelRepository;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final OcrTrainingRunRepository trainingRunRepository;
|
||||
private final OcrClient ocrClient;
|
||||
private final TransactionTemplate txTemplate;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
|
||||
@Value("${ocr.sender-model.activation-threshold:100}")
|
||||
private int activationThreshold;
|
||||
|
||||
@Value("${ocr.sender-model.retrain-delta:50}")
|
||||
private int retrainDelta;
|
||||
|
||||
/** Returns the model path if a trained sender model exists for this person. */
|
||||
public Optional<String> maybeGetModelPath(UUID personId) {
|
||||
return senderModelRepository.findByPersonId(personId)
|
||||
.map(SenderModel::getModelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
|
||||
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
|
||||
*/
|
||||
@Async
|
||||
public void checkAndTriggerTraining(UUID personId) {
|
||||
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
Optional<SenderModel> existing = senderModelRepository.findByPersonId(personId);
|
||||
|
||||
boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold;
|
||||
boolean shouldRetrain = existing.isPresent()
|
||||
&& (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta;
|
||||
|
||||
if (!shouldActivate && !shouldRetrain) {
|
||||
return;
|
||||
}
|
||||
|
||||
log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})",
|
||||
personId, correctedLines, shouldActivate, shouldRetrain);
|
||||
|
||||
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
|
||||
if (runNow) {
|
||||
triggerSenderTraining(personId, (int) correctedLines);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per
|
||||
* person_id). Runs immediately if idle. Returns true if training should start now, false if
|
||||
* queued or skipped.
|
||||
*/
|
||||
@Transactional
|
||||
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
|
||||
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
|
||||
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
trainingRunRepository.save(OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.QUEUED)
|
||||
.personId(personId)
|
||||
.blockCount(blockCount)
|
||||
.documentCount(0)
|
||||
.modelName("sender_" + personId)
|
||||
.build());
|
||||
log.info("Queued sender training for person {} — training already running", personId);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/** Executes sender training synchronously. Caller must run this on a background thread. */
|
||||
public void triggerSenderTraining(UUID personId, int correctedLines) {
|
||||
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
|
||||
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
return trainingRunRepository.save(OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.personId(personId)
|
||||
.blockCount((int) blockCount)
|
||||
.documentCount(0)
|
||||
.modelName("sender_" + personId)
|
||||
.build());
|
||||
}));
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started sender training run {} for person {}", runId, personId);
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
trainingDataExportService.exportForSender(personId).writeTo(baos);
|
||||
byte[] zipBytes = baos.toByteArray();
|
||||
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
|
||||
|
||||
txTemplate.execute(status -> {
|
||||
SenderModel model = senderModelRepository.findByPersonId(personId)
|
||||
.orElseGet(() -> SenderModel.builder().personId(personId).build());
|
||||
model.setModelPath(outputModelPath);
|
||||
model.setCer(result.cer());
|
||||
model.setAccuracy(result.accuracy());
|
||||
model.setCorrectedLinesAtTraining(correctedLines);
|
||||
senderModelRepository.save(model);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
|
||||
return null;
|
||||
});
|
||||
} catch (Exception e) {
|
||||
txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
|
||||
return null;
|
||||
});
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
promoteNextQueuedRun();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Promotes the oldest QUEUED sender run to RUNNING and triggers its training.
|
||||
* Called in the finally block of triggerSenderTraining, creating a sequential chain:
|
||||
* each run promotes the next only after it fully completes (success or failure).
|
||||
*/
|
||||
private void promoteNextQueuedRun() {
|
||||
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
|
||||
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
|
||||
.map(queued -> {
|
||||
queued.setStatus(TrainingStatus.RUNNING);
|
||||
return trainingRunRepository.save(queued);
|
||||
}));
|
||||
|
||||
if (queuedOpt != null && queuedOpt.isPresent()) {
|
||||
OcrTrainingRun promoted = queuedOpt.get();
|
||||
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
|
||||
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
|
||||
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,11 +4,9 @@ import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
@@ -27,7 +25,6 @@ import static org.mockito.Mockito.*;
|
||||
class OcrTrainingServiceTest {
|
||||
|
||||
OcrTrainingRunRepository runRepository;
|
||||
SenderModelRepository senderModelRepository;
|
||||
TrainingDataExportService exportService;
|
||||
SegmentationTrainingExportService segExportService;
|
||||
OcrClient ocrClient;
|
||||
@@ -39,7 +36,6 @@ class OcrTrainingServiceTest {
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
runRepository = mock(OcrTrainingRunRepository.class);
|
||||
senderModelRepository = mock(SenderModelRepository.class);
|
||||
exportService = mock(TrainingDataExportService.class);
|
||||
segExportService = mock(SegmentationTrainingExportService.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
@@ -53,7 +49,7 @@ class OcrTrainingServiceTest {
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
|
||||
when(blockRepository.count()).thenReturn(0L);
|
||||
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
@@ -233,48 +229,6 @@ class OcrTrainingServiceTest {
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── runOrQueueSenderTraining ────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
|
||||
UUID personId = UUID.randomUUID();
|
||||
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isTrue();
|
||||
verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_savesQueued_whenRunning() {
|
||||
UUID personId = UUID.randomUUID();
|
||||
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
|
||||
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
|
||||
when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(runRepository).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
|
||||
UUID personId = UUID.randomUUID();
|
||||
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(runRepository, never()).save(any());
|
||||
}
|
||||
|
||||
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
|
||||
@@ -2,23 +2,31 @@ package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.SenderModel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class SenderModelServiceTest {
|
||||
|
||||
SenderModelRepository senderModelRepository;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
OcrTrainingService ocrTrainingService;
|
||||
OcrTrainingRunRepository trainingRunRepository;
|
||||
OcrClient ocrClient;
|
||||
TransactionTemplate txTemplate;
|
||||
TrainingDataExportService trainingDataExportService;
|
||||
SenderModelService service;
|
||||
|
||||
UUID personId = UUID.randomUUID();
|
||||
@@ -27,8 +35,19 @@ class SenderModelServiceTest {
|
||||
void setUp() {
|
||||
senderModelRepository = mock(SenderModelRepository.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
ocrTrainingService = mock(OcrTrainingService.class);
|
||||
service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService);
|
||||
trainingRunRepository = mock(OcrTrainingRunRepository.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
trainingDataExportService = mock(TrainingDataExportService.class);
|
||||
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource
|
||||
when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||
TransactionCallback<?> callback = inv.getArgument(0);
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
service = new SenderModelService(senderModelRepository, blockRepository,
|
||||
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
|
||||
ReflectionTestUtils.setField(service, "activationThreshold", 100);
|
||||
ReflectionTestUtils.setField(service, "retrainDelta", 50);
|
||||
}
|
||||
@@ -40,20 +59,23 @@ class SenderModelServiceTest {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
|
||||
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).runOrQueueSenderTraining(personId, 100);
|
||||
}
|
||||
|
||||
// ─── Retrain delta ────────────────────────────────────────────────────────
|
||||
@@ -65,9 +87,10 @@ class SenderModelServiceTest {
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -76,11 +99,13 @@ class SenderModelServiceTest {
|
||||
.correctedLinesAtTraining(100).build();
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150);
|
||||
|
||||
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).runOrQueueSenderTraining(personId, 150);
|
||||
}
|
||||
|
||||
// ─── runNow flag ──────────────────────────────────────────────────────────
|
||||
@@ -89,21 +114,103 @@ class SenderModelServiceTest {
|
||||
void checkAndTriggerTraining_callsTrigger_whenRunNow() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
doNothing().when(spy).triggerSenderTraining(personId, 100);
|
||||
|
||||
verify(ocrTrainingService).triggerSenderTraining(personId, 100);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).triggerSenderTraining(personId, 100);
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
|
||||
verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt());
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy, never()).triggerSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
// ─── runOrQueueSenderTraining ─────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isTrue();
|
||||
verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_savesQueued_whenRunning() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
|
||||
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(trainingRunRepository).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(trainingRunRepository, never()).save(any());
|
||||
}
|
||||
|
||||
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
|
||||
UUID nextPersonId = UUID.randomUUID();
|
||||
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L);
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrTrainingRun r = inv.getArgument(0);
|
||||
if (r.getId() == null) r.setId(UUID.randomUUID());
|
||||
return r;
|
||||
});
|
||||
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
|
||||
when(ocrClient.trainSenderModel(any(), any()))
|
||||
.thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
OcrTrainingRun queued = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.QUEUED)
|
||||
.personId(nextPersonId).blockCount(5).documentCount(0)
|
||||
.modelName("sender_" + nextPersonId).build();
|
||||
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
|
||||
.thenReturn(Optional.of(queued));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L);
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
// Stub the recursive call to stop the chain after one promotion
|
||||
doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
|
||||
|
||||
spy.triggerSenderTraining(personId, 10);
|
||||
|
||||
// Queued run promoted to RUNNING
|
||||
verify(trainingRunRepository).save(argThat(r ->
|
||||
nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING));
|
||||
// Next training was kicked off
|
||||
verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user