refactor(ocr): move sender training methods from OcrTrainingService to SenderModelService

Eliminates cross-domain repository access: OcrTrainingService no longer
holds SenderModelRepository. SenderModelService now owns the full sender
training lifecycle (runOrQueueSenderTraining, triggerSenderTraining,
promoteNextQueuedRun), removing the circular dependency risk.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 19:08:10 +02:00
parent 60c1ec7b5f
commit 386dc83958
4 changed files with 311 additions and 175 deletions

View File

@@ -4,11 +4,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
@@ -27,7 +25,6 @@ import static org.mockito.Mockito.*;
class OcrTrainingServiceTest {
OcrTrainingRunRepository runRepository;
SenderModelRepository senderModelRepository;
TrainingDataExportService exportService;
SegmentationTrainingExportService segExportService;
OcrClient ocrClient;
@@ -39,7 +36,6 @@ class OcrTrainingServiceTest {
@BeforeEach
void setUp() {
runRepository = mock(OcrTrainingRunRepository.class);
senderModelRepository = mock(SenderModelRepository.class);
exportService = mock(TrainingDataExportService.class);
segExportService = mock(SegmentationTrainingExportService.class);
ocrClient = mock(OcrClient.class);
@@ -53,7 +49,7 @@ class OcrTrainingServiceTest {
return callback.doInTransaction(null);
});
service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
@@ -233,48 +229,6 @@ class OcrTrainingServiceTest {
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── runOrQueueSenderTraining ────────────────────────────────────────────
@Test
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isTrue();
verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
}
@Test
void runOrQueueSenderTraining_savesQueued_whenRunning() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(runRepository).save(argThat(r ->
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
}
@Test
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(runRepository, never()).save(any());
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test

View File

@@ -2,23 +2,31 @@ package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.util.Optional;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
class SenderModelServiceTest {
SenderModelRepository senderModelRepository;
TranscriptionBlockRepository blockRepository;
OcrTrainingService ocrTrainingService;
OcrTrainingRunRepository trainingRunRepository;
OcrClient ocrClient;
TransactionTemplate txTemplate;
TrainingDataExportService trainingDataExportService;
SenderModelService service;
UUID personId = UUID.randomUUID();
@@ -27,8 +35,19 @@ class SenderModelServiceTest {
void setUp() {
senderModelRepository = mock(SenderModelRepository.class);
blockRepository = mock(TranscriptionBlockRepository.class);
ocrTrainingService = mock(OcrTrainingService.class);
service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService);
trainingRunRepository = mock(OcrTrainingRunRepository.class);
ocrClient = mock(OcrClient.class);
txTemplate = mock(TransactionTemplate.class);
trainingDataExportService = mock(TrainingDataExportService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource
when(txTemplate.execute(any())).thenAnswer(inv -> {
TransactionCallback<?> callback = inv.getArgument(0);
return callback.doInTransaction(null);
});
service = new SenderModelService(senderModelRepository, blockRepository,
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
ReflectionTestUtils.setField(service, "activationThreshold", 100);
ReflectionTestUtils.setField(service, "retrainDelta", 50);
}
@@ -40,20 +59,23 @@ class SenderModelServiceTest {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
spy.checkAndTriggerTraining(personId);
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100);
spy.checkAndTriggerTraining(personId);
verify(spy).runOrQueueSenderTraining(personId, 100);
}
// ─── Retrain delta ────────────────────────────────────────────────────────
@@ -65,9 +87,10 @@ class SenderModelServiceTest {
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
spy.checkAndTriggerTraining(personId);
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
@@ -76,11 +99,13 @@ class SenderModelServiceTest {
.correctedLinesAtTraining(100).build();
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150);
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150);
spy.checkAndTriggerTraining(personId);
verify(spy).runOrQueueSenderTraining(personId, 150);
}
// ─── runNow flag ──────────────────────────────────────────────────────────
@@ -89,21 +114,103 @@ class SenderModelServiceTest {
void checkAndTriggerTraining_callsTrigger_whenRunNow() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100);
doNothing().when(spy).triggerSenderTraining(personId, 100);
verify(ocrTrainingService).triggerSenderTraining(personId, 100);
spy.checkAndTriggerTraining(personId);
verify(spy).triggerSenderTraining(personId, 100);
}
@Test
void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
service.checkAndTriggerTraining(personId);
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt());
spy.checkAndTriggerTraining(personId);
verify(spy, never()).triggerSenderTraining(any(), anyInt());
}
// ─── runOrQueueSenderTraining ─────────────────────────────────────────────
@Test
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isTrue();
verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
}
@Test
void runOrQueueSenderTraining_savesQueued_whenRunning() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(trainingRunRepository).save(argThat(r ->
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
}
@Test
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(trainingRunRepository, never()).save(any());
}
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
@Test
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
UUID nextPersonId = UUID.randomUUID();
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(10L);
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
OcrTrainingRun r = inv.getArgument(0);
if (r.getId() == null) r.setId(UUID.randomUUID());
return r;
});
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
when(ocrClient.trainSenderModel(any(), any()))
.thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
OcrTrainingRun queued = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.QUEUED)
.personId(nextPersonId).blockCount(5).documentCount(0)
.modelName("sender_" + nextPersonId).build();
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
.thenReturn(Optional.of(queued));
when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L);
SenderModelService spy = spy(service);
// Stub the recursive call to stop the chain after one promotion
doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
spy.triggerSenderTraining(personId, 10);
// Queued run promoted to RUNNING
verify(trainingRunRepository).save(argThat(r ->
nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING));
// Next training was kicked off
verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
}
}