feat(ocr): promote TrainingInfoResponse to dto, add senderModels field

Move TrainingInfoResponse from private nested record to dto/TrainingInfoResponse.java,
add senderModels field, inject SenderModelService into OcrTrainingService so personNames
covers all known senders rather than only recent-run participants.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-18 00:04:29 +02:00
parent a52c8bf079
commit b1b7418404
6 changed files with 88 additions and 39 deletions

View File

@@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.dto.BatchOcrDTO; import org.raddatz.familienarchiv.dto.BatchOcrDTO;
import org.raddatz.familienarchiv.dto.OcrStatusDTO; import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO; import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.model.AppUser; import org.raddatz.familienarchiv.model.AppUser;
import org.raddatz.familienarchiv.model.OcrJob; import org.raddatz.familienarchiv.model.OcrJob;
@@ -130,7 +131,7 @@ public class OcrController {
@GetMapping("/api/ocr/training-info") @GetMapping("/api/ocr/training-info")
@RequirePermission(Permission.ADMIN) @RequirePermission(Permission.ADMIN)
public OcrTrainingService.TrainingInfoResponse getTrainingInfo() { public TrainingInfoResponse getTrainingInfo() {
return ocrTrainingService.getTrainingInfo(); return ocrTrainingService.getTrainingInfo();
} }

View File

@@ -0,0 +1,19 @@
package org.raddatz.familienarchiv.dto;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import java.util.List;
import java.util.Map;
public record TrainingInfoResponse(
int availableBlocks,
int totalOcrBlocks,
int availableDocuments,
int availableSegBlocks,
boolean ocrServiceAvailable,
OcrTrainingRun lastRun,
List<OcrTrainingRun> runs,
Map<String, String> personNames,
List<SenderModel> senderModels
) {}

View File

@@ -2,9 +2,11 @@ package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.OcrTrainingRun; import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
@@ -37,17 +39,7 @@ public class OcrTrainingService {
private final TranscriptionBlockRepository blockRepository; private final TranscriptionBlockRepository blockRepository;
private final TransactionTemplate txTemplate; private final TransactionTemplate txTemplate;
private final PersonService personService; private final PersonService personService;
private final SenderModelService senderModelService;
public record TrainingInfoResponse(
int availableBlocks,
int totalOcrBlocks,
int availableDocuments,
int availableSegBlocks,
boolean ocrServiceAvailable,
OcrTrainingRun lastRun,
List<OcrTrainingRun> runs,
Map<String, String> personNames
) {}
private void assertNoRunningTraining() { private void assertNoRunningTraining() {
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
@@ -202,14 +194,15 @@ public class OcrTrainingService {
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc(); List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc();
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0); OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
List<UUID> distinctPersonIds = recentRuns.stream() List<SenderModel> senderModels = senderModelService.getAllSenderModels();
.map(OcrTrainingRun::getPersonId)
.filter(Objects::nonNull) List<UUID> allPersonIds = senderModels.stream()
.map(SenderModel::getPersonId)
.distinct() .distinct()
.toList(); .toList();
Map<String, String> personNames = new HashMap<>(); Map<String, String> personNames = new HashMap<>();
if (!distinctPersonIds.isEmpty()) { if (!allPersonIds.isEmpty()) {
personService.getAllById(distinctPersonIds) personService.getAllById(allPersonIds)
.forEach(p -> personNames.put(p.getId().toString(), p.getDisplayName())); .forEach(p -> personNames.put(p.getId().toString(), p.getDisplayName()));
} }
@@ -221,7 +214,8 @@ public class OcrTrainingService {
ocrHealthClient.isHealthy(), ocrHealthClient.isHealthy(),
lastRun, lastRun,
recentRuns, recentRuns,
personNames personNames,
senderModels
); );
} }

View File

@@ -17,6 +17,7 @@ import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.time.Instant; import java.time.Instant;
import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@@ -45,6 +46,10 @@ public class SenderModelService {
.map(SenderModel::getModelPath); .map(SenderModel::getModelPath);
} }
public List<SenderModel> getAllSenderModels() {
return senderModelRepository.findAll();
}
/** /**
* Called after every MANUAL block save for HANDWRITING_KURRENT documents. * Called after every MANUAL block save for HANDWRITING_KURRENT documents.
* Checks activation and retrain thresholds; enqueues or starts sender training when met. * Checks activation and retrain thresholds; enqueues or starts sender training when met.

View File

@@ -5,6 +5,7 @@ import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.config.SecurityConfig; import org.raddatz.familienarchiv.config.SecurityConfig;
import org.raddatz.familienarchiv.dto.BatchOcrDTO; import org.raddatz.familienarchiv.dto.BatchOcrDTO;
import org.raddatz.familienarchiv.dto.OcrStatusDTO; import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO; import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.exception.ErrorCode;
@@ -49,6 +50,7 @@ class OcrControllerTest {
@MockitoBean TrainingDataExportService trainingDataExportService; @MockitoBean TrainingDataExportService trainingDataExportService;
@MockitoBean SegmentationTrainingExportService segmentationTrainingExportService; @MockitoBean SegmentationTrainingExportService segmentationTrainingExportService;
@MockitoBean OcrTrainingService ocrTrainingService; @MockitoBean OcrTrainingService ocrTrainingService;
@MockitoBean SenderModelService senderModelService;
@Test @Test
@WithMockUser(authorities = "WRITE_ALL") @WithMockUser(authorities = "WRITE_ALL")
@@ -220,8 +222,7 @@ class OcrControllerTest {
@Test @Test
@WithMockUser(authorities = "ADMIN") @WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_withInfo() throws Exception { void getTrainingInfo_returns200_withInfo() throws Exception {
OcrTrainingService.TrainingInfoResponse info = TrainingInfoResponse info = new TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of(), List.of());
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info); when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info")) mockMvc.perform(get("/api/ocr/training-info"))
@@ -237,9 +238,8 @@ class OcrControllerTest {
OcrTrainingRun runWithPerson = OcrTrainingRun.builder() OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.DONE) .id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
OcrTrainingService.TrainingInfoResponse info = TrainingInfoResponse info = new TrainingInfoResponse(5, 20, 2, 3, true, null,
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(runWithPerson), Map.of(), List.of());
List.of(runWithPerson), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info); when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info")) mockMvc.perform(get("/api/ocr/training-info"))
@@ -254,9 +254,8 @@ class OcrControllerTest {
OcrTrainingRun runWithPerson = OcrTrainingRun.builder() OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.DONE) .id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
OcrTrainingService.TrainingInfoResponse info = TrainingInfoResponse info = new TrainingInfoResponse(5, 20, 2, 3, true, null,
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann"), List.of());
List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann"));
when(ocrTrainingService.getTrainingInfo()).thenReturn(info); when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info")) mockMvc.perform(get("/api/ocr/training-info"))
@@ -267,8 +266,7 @@ class OcrControllerTest {
@Test @Test
@WithMockUser(authorities = "ADMIN") @WithMockUser(authorities = "ADMIN")
void getTrainingInfo_serializes_null_lastRun_as_json_null() throws Exception { void getTrainingInfo_serializes_null_lastRun_as_json_null() throws Exception {
OcrTrainingService.TrainingInfoResponse info = TrainingInfoResponse info = new TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of(), List.of());
new OcrTrainingService.TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info); when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info")) mockMvc.perform(get("/api/ocr/training-info"))
@@ -276,6 +274,22 @@ class OcrControllerTest {
.andExpect(jsonPath("$.lastRun").doesNotExist()); .andExpect(jsonPath("$.lastRun").doesNotExist());
} }
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_includingSenderModels() throws Exception {
UUID personId = UUID.randomUUID();
SenderModel senderModel = SenderModel.builder()
.id(UUID.randomUUID()).personId(personId).correctedLinesAtTraining(80).build();
TrainingInfoResponse info = new TrainingInfoResponse(
5, 20, 2, 3, true, null, List.of(), Map.of(), List.of(senderModel));
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.senderModels").isArray())
.andExpect(jsonPath("$.senderModels[0].personId").value(personId.toString()));
}
@Test @Test
@WithMockUser(authorities = "READ_ALL") @WithMockUser(authorities = "READ_ALL")
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {

View File

@@ -2,9 +2,11 @@ package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun; import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.Person; import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
@@ -34,6 +36,7 @@ class OcrTrainingServiceTest {
TranscriptionBlockRepository blockRepository; TranscriptionBlockRepository blockRepository;
TransactionTemplate txTemplate; TransactionTemplate txTemplate;
PersonService personService; PersonService personService;
SenderModelService senderModelService;
OcrTrainingService service; OcrTrainingService service;
@BeforeEach @BeforeEach
@@ -46,6 +49,7 @@ class OcrTrainingServiceTest {
blockRepository = mock(TranscriptionBlockRepository.class); blockRepository = mock(TranscriptionBlockRepository.class);
txTemplate = mock(TransactionTemplate.class); txTemplate = mock(TransactionTemplate.class);
personService = mock(PersonService.class); personService = mock(PersonService.class);
senderModelService = mock(SenderModelService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource // Execute transaction callbacks inline so unit tests run without a real DataSource
when(txTemplate.execute(any())).thenAnswer(inv -> { when(txTemplate.execute(any())).thenAnswer(inv -> {
@@ -53,11 +57,12 @@ class OcrTrainingServiceTest {
return callback.doInTransaction(null); return callback.doInTransaction(null);
}); });
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService); service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService, senderModelService);
when(blockRepository.count()).thenReturn(0L); when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
when(senderModelService.getAllSenderModels()).thenReturn(List.of());
} }
// ─── Concurrent guard ───────────────────────────────────────────────────── // ─── Concurrent guard ─────────────────────────────────────────────────────
@@ -236,18 +241,14 @@ class OcrTrainingServiceTest {
// ─── getTrainingInfo: batch person name resolution ──────────────────────── // ─── getTrainingInfo: batch person name resolution ────────────────────────
@Test @Test
void getTrainingInfo_resolves_person_names_in_single_batch_call() { void getTrainingInfo_resolves_person_names_from_all_senderModels_in_batch() {
UUID personA = UUID.randomUUID(); UUID personA = UUID.randomUUID();
UUID personB = UUID.randomUUID(); UUID personB = UUID.randomUUID();
List<OcrTrainingRun> runs = List.of( List<SenderModel> models = List.of(
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE) SenderModel.builder().id(UUID.randomUUID()).personId(personA).correctedLinesAtTraining(100).build(),
.personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build(), SenderModel.builder().id(UUID.randomUUID()).personId(personB).correctedLinesAtTraining(80).build()
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personB).blockCount(5).documentCount(1).modelName("sender_b").build(),
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build()
); );
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(runs); when(senderModelService.getAllSenderModels()).thenReturn(models);
when(exportService.queryEligibleBlocks()).thenReturn(List.of()); when(exportService.queryEligibleBlocks()).thenReturn(List.of());
Person pa = Person.builder().id(personA).firstName("Anna").lastName("Müller").build(); Person pa = Person.builder().id(personA).firstName("Anna").lastName("Müller").build();
@@ -256,7 +257,7 @@ class OcrTrainingServiceTest {
.thenReturn(List.of(pa, pb)); .thenReturn(List.of(pa, pb));
when(healthClient.isHealthy()).thenReturn(true); when(healthClient.isHealthy()).thenReturn(true);
OcrTrainingService.TrainingInfoResponse info = service.getTrainingInfo(); TrainingInfoResponse info = service.getTrainingInfo();
verify(personService, never()).getById(any()); verify(personService, never()).getById(any());
verify(personService, times(1)).getAllById(any()); verify(personService, times(1)).getAllById(any());
@@ -264,6 +265,21 @@ class OcrTrainingServiceTest {
assertThat(info.personNames()).containsKey(personB.toString()); assertThat(info.personNames()).containsKey(personB.toString());
} }
@Test
void getTrainingInfo_includesSenderModels_inResponse() {
UUID personId = UUID.randomUUID();
SenderModel model = SenderModel.builder()
.id(UUID.randomUUID()).personId(personId).correctedLinesAtTraining(120).build();
when(senderModelService.getAllSenderModels()).thenReturn(List.of(model));
when(exportService.queryEligibleBlocks()).thenReturn(List.of());
when(healthClient.isHealthy()).thenReturn(false);
TrainingInfoResponse info = service.getTrainingInfo();
assertThat(info.senderModels()).hasSize(1);
assertThat(info.senderModels().get(0).getPersonId()).isEqualTo(personId);
}
// ─── Orphan recovery ────────────────────────────────────────────────────── // ─── Orphan recovery ──────────────────────────────────────────────────────
@Test @Test