test: add unit tests for SenderModelService, runOrQueueSenderTraining, and updateBlock hook

- SenderModelServiceTest: 6 tests covering activation threshold (99/100),
  retrain delta (149/150), runNow flag (queued vs triggered)
- OcrTrainingServiceTest: 3 tests for runOrQueueSenderTraining — idle returns
  true, running saves QUEUED, duplicate QUEUED coalesces
- TranscriptionServiceTest: 3 tests for updateBlock — sets source=MANUAL,
  triggers training for HANDWRITING_KURRENT with sender, skips when no sender

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 18:00:59 +02:00
parent bd23a76330
commit 7a342a07cf
3 changed files with 208 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
@@ -26,6 +27,7 @@ import static org.mockito.Mockito.*;
class OcrTrainingServiceTest {
OcrTrainingRunRepository runRepository;
SenderModelRepository senderModelRepository;
TrainingDataExportService exportService;
SegmentationTrainingExportService segExportService;
OcrClient ocrClient;
@@ -37,6 +39,7 @@ class OcrTrainingServiceTest {
@BeforeEach
void setUp() {
runRepository = mock(OcrTrainingRunRepository.class);
senderModelRepository = mock(SenderModelRepository.class);
exportService = mock(TrainingDataExportService.class);
segExportService = mock(SegmentationTrainingExportService.class);
ocrClient = mock(OcrClient.class);
@@ -50,10 +53,10 @@ class OcrTrainingServiceTest {
return callback.doInTransaction(null);
});
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
service = new OcrTrainingService(runRepository, senderModelRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
}
@@ -230,6 +233,48 @@ class OcrTrainingServiceTest {
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── runOrQueueSenderTraining ────────────────────────────────────────────
@Test
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isTrue();
verify(runRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
}
@Test
void runOrQueueSenderTraining_savesQueued_whenRunning() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(runRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
when(runRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(runRepository).save(argThat(r ->
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
}
@Test
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
UUID personId = UUID.randomUUID();
when(runRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(runRepository, never()).save(any());
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test

View File

@@ -0,0 +1,109 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.test.util.ReflectionTestUtils;
import java.util.Optional;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.*;
class SenderModelServiceTest {
SenderModelRepository senderModelRepository;
TranscriptionBlockRepository blockRepository;
OcrTrainingService ocrTrainingService;
SenderModelService service;
UUID personId = UUID.randomUUID();
@BeforeEach
void setUp() {
senderModelRepository = mock(SenderModelRepository.class);
blockRepository = mock(TranscriptionBlockRepository.class);
ocrTrainingService = mock(OcrTrainingService.class);
service = new SenderModelService(senderModelRepository, blockRepository, ocrTrainingService);
ReflectionTestUtils.setField(service, "activationThreshold", 100);
ReflectionTestUtils.setField(service, "retrainDelta", 50);
}
// ─── Activation threshold ─────────────────────────────────────────────────
@Test
void checkAndTriggerTraining_doesNothing_belowActivationThreshold() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
service.checkAndTriggerTraining(personId);
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
service.checkAndTriggerTraining(personId);
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 100);
}
// ─── Retrain delta ────────────────────────────────────────────────────────
@Test
void checkAndTriggerTraining_doesNothing_belowRetrainDelta() {
SenderModel existing = SenderModel.builder().personId(personId)
.correctedLinesAtTraining(100).build();
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
service.checkAndTriggerTraining(personId);
verify(ocrTrainingService, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
void checkAndTriggerTraining_triggersRetrain_atRetrainDelta() {
SenderModel existing = SenderModel.builder().personId(personId)
.correctedLinesAtTraining(100).build();
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
when(ocrTrainingService.runOrQueueSenderTraining(personId, 150)).thenReturn(false);
service.checkAndTriggerTraining(personId);
verify(ocrTrainingService).runOrQueueSenderTraining(personId, 150);
}
// ─── runNow flag ──────────────────────────────────────────────────────────
@Test
void checkAndTriggerTraining_callsTrigger_whenRunNow() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(true);
service.checkAndTriggerTraining(personId);
verify(ocrTrainingService).triggerSenderTraining(personId, 100);
}
@Test
void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(ocrTrainingService.runOrQueueSenderTraining(personId, 100)).thenReturn(false);
service.checkAndTriggerTraining(personId);
verify(ocrTrainingService, never()).triggerSenderTraining(any(), anyInt());
}
}

View File

@@ -13,6 +13,7 @@ import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
@@ -189,6 +190,57 @@ class TranscriptionServiceTest {
assertThat(result.getLabel()).isEqualTo("Anrede");
}
@Test
void updateBlock_setsSourceToManual() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).text("old").source(BlockSource.OCR).build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
TranscriptionBlock result = transcriptionService.updateBlock(
docId, blockId, new UpdateTranscriptionBlockDTO("new", null), UUID.randomUUID());
assertThat(result.getSource()).isEqualTo(BlockSource.MANUAL);
}
@Test
void updateBlock_triggersTraining_whenKurrentSenderPresent() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
UUID senderId = UUID.randomUUID();
Person sender = Person.builder().id(senderId).build();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).text("old").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).sender(sender).build());
transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());
verify(senderModelService).checkAndTriggerTraining(senderId);
}
@Test
void updateBlock_doesNotTriggerTraining_whenNoSender() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).text("old").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).build());
transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());
verify(senderModelService, never()).checkAndTriggerTraining(any());
}
// ─── deleteBlock ─────────────────────────────────────────────────────────────
@Test