refactor(ocr): move sender training methods from OcrTrainingService to SenderModelService
Eliminates cross-domain repository access: OcrTrainingService no longer holds SenderModelRepository. SenderModelService now owns the full sender training lifecycle (runOrQueueSenderTraining, triggerSenderTraining, promoteNextQueuedRun), removing the circular dependency risk. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,11 +4,9 @@ import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
@@ -27,7 +25,6 @@ import static org.mockito.Mockito.*;
|
||||
class OcrTrainingServiceTest {
|
||||
|
||||
OcrTrainingRunRepository runRepository;
|
||||
SenderModelRepository senderModelRepository;
|
||||
TrainingDataExportService exportService;
|
||||
SegmentationTrainingExportService segExportService;
|
||||
OcrClient ocrClient;
|
||||
@@ -39,7 +36,6 @@ class OcrTrainingServiceTest {
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
runRepository = mock(OcrTrainingRunRepository.class);
|
||||
senderModelRepository = mock(SenderModelRepository.class);
|
||||
exportService = mock(TrainingDataExportService.class);
|
||||
segExportService = mock(SegmentationTrainingExportService.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
@@ -53,7 +49,7 @@ class OcrTrainingServiceTest {
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
|
||||
when(blockRepository.count()).thenReturn(0L);
|
||||
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
@@ -233,48 +229,6 @@ class OcrTrainingServiceTest {
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── runOrQueueSenderTraining ────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
|
||||
UUID personId = UUID.randomUUID();
|
||||
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isTrue();
|
||||
verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_savesQueued_whenRunning() {
|
||||
UUID personId = UUID.randomUUID();
|
||||
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
|
||||
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
|
||||
when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(runRepository).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
|
||||
UUID personId = UUID.randomUUID();
|
||||
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(runRepository, never()).save(any());
|
||||
}
|
||||
|
||||
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
|
||||
@@ -2,23 +2,31 @@ package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.SenderModel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
class SenderModelServiceTest {
|
||||
|
||||
SenderModelRepository senderModelRepository;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
OcrTrainingService ocrTrainingService;
|
||||
OcrTrainingRunRepository trainingRunRepository;
|
||||
OcrClient ocrClient;
|
||||
TransactionTemplate txTemplate;
|
||||
TrainingDataExportService trainingDataExportService;
|
||||
SenderModelService service;
|
||||
|
||||
UUID personId = UUID.randomUUID();
|
||||
@@ -27,8 +35,19 @@ class SenderModelServiceTest {
|
||||
void setUp() {
|
||||
senderModelRepository = mock(SenderModelRepository.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
ocrTrainingService = mock(OcrTrainingService.class);
|
||||
service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService);
|
||||
trainingRunRepository = mock(OcrTrainingRunRepository.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
trainingDataExportService = mock(TrainingDataExportService.class);
|
||||
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource
|
||||
when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||
TransactionCallback<?> callback = inv.getArgument(0);
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
service = new SenderModelService(senderModelRepository, blockRepository,
|
||||
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
|
||||
ReflectionTestUtils.setField(service, "activationThreshold", 100);
|
||||
ReflectionTestUtils.setField(service, "retrainDelta", 50);
|
||||
}
|
||||
@@ -40,20 +59,23 @@ class SenderModelServiceTest {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
|
||||
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).runOrQueueSenderTraining(personId, 100);
|
||||
}
|
||||
|
||||
// ─── Retrain delta ────────────────────────────────────────────────────────
|
||||
@@ -65,9 +87,10 @@ class SenderModelServiceTest {
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -76,11 +99,13 @@ class SenderModelServiceTest {
|
||||
.correctedLinesAtTraining(100).build();
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150);
|
||||
|
||||
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).runOrQueueSenderTraining(personId, 150);
|
||||
}
|
||||
|
||||
// ─── runNow flag ──────────────────────────────────────────────────────────
|
||||
@@ -89,21 +114,103 @@ class SenderModelServiceTest {
|
||||
void checkAndTriggerTraining_callsTrigger_whenRunNow() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
doNothing().when(spy).triggerSenderTraining(personId, 100);
|
||||
|
||||
verify(ocrTrainingService).triggerSenderTraining(personId, 100);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).triggerSenderTraining(personId, 100);
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
|
||||
|
||||
service.checkAndTriggerTraining(personId);
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
|
||||
verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt());
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy, never()).triggerSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
// ─── runOrQueueSenderTraining ─────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isTrue();
|
||||
verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_savesQueued_whenRunning() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
|
||||
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(trainingRunRepository).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(trainingRunRepository, never()).save(any());
|
||||
}
|
||||
|
||||
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
|
||||
UUID nextPersonId = UUID.randomUUID();
|
||||
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L);
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrTrainingRun r = inv.getArgument(0);
|
||||
if (r.getId() == null) r.setId(UUID.randomUUID());
|
||||
return r;
|
||||
});
|
||||
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
|
||||
when(ocrClient.trainSenderModel(any(), any()))
|
||||
.thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
OcrTrainingRun queued = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.QUEUED)
|
||||
.personId(nextPersonId).blockCount(5).documentCount(0)
|
||||
.modelName("sender_" + nextPersonId).build();
|
||||
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
|
||||
.thenReturn(Optional.of(queued));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L);
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
// Stub the recursive call to stop the chain after one promotion
|
||||
doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
|
||||
|
||||
spy.triggerSenderTraining(personId, 10);
|
||||
|
||||
// Queued run promoted to RUNNING
|
||||
verify(trainingRunRepository).save(argThat(r ->
|
||||
nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING));
|
||||
// Next training was kicked off
|
||||
verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user