From b1b7418404762bf2d20061cb4355e42efe3f8a20 Mon Sep 17 00:00:00 2001 From: Marcel Date: Sat, 18 Apr 2026 00:04:29 +0200 Subject: [PATCH] feat(ocr): promote TrainingInfoResponse to dto, add senderModels field Move TrainingInfoResponse from private nested record to dto/TrainingInfoResponse.java, add senderModels field, inject SenderModelService into OcrTrainingService so personNames covers all known senders rather than only recent-run participants. Co-Authored-By: Claude Sonnet 4.6 --- .../controller/OcrController.java | 3 +- .../dto/TrainingInfoResponse.java | 19 ++++++++++ .../service/OcrTrainingService.java | 28 ++++++-------- .../service/SenderModelService.java | 5 +++ .../controller/OcrControllerTest.java | 34 ++++++++++++----- .../service/OcrTrainingServiceTest.java | 38 +++++++++++++------ 6 files changed, 88 insertions(+), 39 deletions(-) create mode 100644 backend/src/main/java/org/raddatz/familienarchiv/dto/TrainingInfoResponse.java diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java index 4ada18e4..6eb12674 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -4,6 +4,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.raddatz.familienarchiv.dto.BatchOcrDTO; import org.raddatz.familienarchiv.dto.OcrStatusDTO; +import org.raddatz.familienarchiv.dto.TrainingInfoResponse; import org.raddatz.familienarchiv.dto.TriggerOcrDTO; import org.raddatz.familienarchiv.model.AppUser; import org.raddatz.familienarchiv.model.OcrJob; @@ -130,7 +131,7 @@ public class OcrController { @GetMapping("/api/ocr/training-info") @RequirePermission(Permission.ADMIN) - public OcrTrainingService.TrainingInfoResponse getTrainingInfo() { + public TrainingInfoResponse getTrainingInfo() { return ocrTrainingService.getTrainingInfo(); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/dto/TrainingInfoResponse.java b/backend/src/main/java/org/raddatz/familienarchiv/dto/TrainingInfoResponse.java new file mode 100644 index 00000000..6f1c6343 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/dto/TrainingInfoResponse.java @@ -0,0 +1,19 @@ +package org.raddatz.familienarchiv.dto; + +import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.SenderModel; + +import java.util.List; +import java.util.Map; + +public record TrainingInfoResponse( + int availableBlocks, + int totalOcrBlocks, + int availableDocuments, + int availableSegBlocks, + boolean ocrServiceAvailable, + OcrTrainingRun lastRun, + List runs, + Map personNames, + List senderModels +) {} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java index 20780985..02f2cda2 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/OcrTrainingService.java @@ -2,9 +2,11 @@ package org.raddatz.familienarchiv.service; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.raddatz.familienarchiv.dto.TrainingInfoResponse; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.ErrorCode; import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.SenderModel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; @@ -37,17 +39,7 @@ public class OcrTrainingService { private final TranscriptionBlockRepository blockRepository; private final TransactionTemplate txTemplate; private final PersonService personService; - - public record TrainingInfoResponse( - int availableBlocks, - int totalOcrBlocks, - int availableDocuments, - int availableSegBlocks, - boolean ocrServiceAvailable, - OcrTrainingRun lastRun, - List runs, - Map personNames - ) {} + private final SenderModelService senderModelService; private void assertNoRunningTraining() { if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) { @@ -202,14 +194,15 @@ public class OcrTrainingService { List recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc(); OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0); - List distinctPersonIds = recentRuns.stream() - .map(OcrTrainingRun::getPersonId) - .filter(Objects::nonNull) + List senderModels = senderModelService.getAllSenderModels(); + + List allPersonIds = senderModels.stream() + .map(SenderModel::getPersonId) .distinct() .toList(); Map personNames = new HashMap<>(); - if (!distinctPersonIds.isEmpty()) { - personService.getAllById(distinctPersonIds) + if (!allPersonIds.isEmpty()) { + personService.getAllById(allPersonIds) .forEach(p -> personNames.put(p.getId().toString(), p.getDisplayName())); } @@ -221,7 +214,8 @@ public class OcrTrainingService { ocrHealthClient.isHealthy(), lastRun, recentRuns, - personNames + personNames, + senderModels ); } diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java index be9318a7..69114b4d 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -17,6 +17,7 @@ import org.springframework.transaction.support.TransactionTemplate; import java.io.ByteArrayOutputStream; import java.time.Instant; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; @@ -45,6 +46,10 @@ public class SenderModelService { .map(SenderModel::getModelPath); } + public List getAllSenderModels() { + return senderModelRepository.findAll(); + } + /** * Called after every MANUAL block save for HANDWRITING_KURRENT documents. * Checks activation and retrain thresholds; enqueues or starts sender training when met. diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java index e65e490c..eee17ad9 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -5,6 +5,7 @@ import org.junit.jupiter.api.Test; import org.raddatz.familienarchiv.config.SecurityConfig; import org.raddatz.familienarchiv.dto.BatchOcrDTO; import org.raddatz.familienarchiv.dto.OcrStatusDTO; +import org.raddatz.familienarchiv.dto.TrainingInfoResponse; import org.raddatz.familienarchiv.dto.TriggerOcrDTO; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.ErrorCode; @@ -49,6 +50,7 @@ class OcrControllerTest { @MockitoBean TrainingDataExportService trainingDataExportService; @MockitoBean SegmentationTrainingExportService segmentationTrainingExportService; @MockitoBean OcrTrainingService ocrTrainingService; + @MockitoBean SenderModelService senderModelService; @Test @WithMockUser(authorities = "WRITE_ALL") @@ -220,8 +222,7 @@ class OcrControllerTest { @Test @WithMockUser(authorities = "ADMIN") void getTrainingInfo_returns200_withInfo() throws Exception { - OcrTrainingService.TrainingInfoResponse info = - new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of()); + TrainingInfoResponse info = new TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of(), List.of()); when(ocrTrainingService.getTrainingInfo()).thenReturn(info); mockMvc.perform(get("/api/ocr/training-info")) @@ -237,9 +238,8 @@ class OcrControllerTest { OcrTrainingRun runWithPerson = OcrTrainingRun.builder() .id(UUID.randomUUID()).status(TrainingStatus.DONE) .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); - OcrTrainingService.TrainingInfoResponse info = - new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, - List.of(runWithPerson), Map.of()); + TrainingInfoResponse info = new TrainingInfoResponse(5, 20, 2, 3, true, null, + List.of(runWithPerson), Map.of(), List.of()); when(ocrTrainingService.getTrainingInfo()).thenReturn(info); mockMvc.perform(get("/api/ocr/training-info")) @@ -254,9 +254,8 @@ class OcrControllerTest { OcrTrainingRun runWithPerson = OcrTrainingRun.builder() .id(UUID.randomUUID()).status(TrainingStatus.DONE) .personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build(); - OcrTrainingService.TrainingInfoResponse info = - new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, - List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann")); + TrainingInfoResponse info = new TrainingInfoResponse(5, 20, 2, 3, true, null, + List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann"), List.of()); when(ocrTrainingService.getTrainingInfo()).thenReturn(info); mockMvc.perform(get("/api/ocr/training-info")) @@ -267,8 +266,7 @@ class OcrControllerTest { @Test @WithMockUser(authorities = "ADMIN") void getTrainingInfo_serializes_null_lastRun_as_json_null() throws Exception { - OcrTrainingService.TrainingInfoResponse info = - new OcrTrainingService.TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of()); + TrainingInfoResponse info = new TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of(), List.of()); when(ocrTrainingService.getTrainingInfo()).thenReturn(info); mockMvc.perform(get("/api/ocr/training-info")) @@ -276,6 +274,22 @@ class OcrControllerTest { .andExpect(jsonPath("$.lastRun").doesNotExist()); } + @Test + @WithMockUser(authorities = "ADMIN") + void getTrainingInfo_returns200_includingSenderModels() throws Exception { + UUID personId = UUID.randomUUID(); + SenderModel senderModel = SenderModel.builder() + .id(UUID.randomUUID()).personId(personId).correctedLinesAtTraining(80).build(); + TrainingInfoResponse info = new TrainingInfoResponse( + 5, 20, 2, 3, true, null, List.of(), Map.of(), List.of(senderModel)); + when(ocrTrainingService.getTrainingInfo()).thenReturn(info); + + mockMvc.perform(get("/api/ocr/training-info")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.senderModels").isArray()) + .andExpect(jsonPath("$.senderModels[0].personId").value(personId.toString())); + } + @Test @WithMockUser(authorities = "READ_ALL") void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java index 2e09eb69..c2d8ad95 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/OcrTrainingServiceTest.java @@ -2,9 +2,11 @@ package org.raddatz.familienarchiv.service; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.raddatz.familienarchiv.dto.TrainingInfoResponse; import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.model.OcrTrainingRun; import org.raddatz.familienarchiv.model.Person; +import org.raddatz.familienarchiv.model.SenderModel; import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; @@ -34,6 +36,7 @@ class OcrTrainingServiceTest { TranscriptionBlockRepository blockRepository; TransactionTemplate txTemplate; PersonService personService; + SenderModelService senderModelService; OcrTrainingService service; @BeforeEach @@ -46,6 +49,7 @@ class OcrTrainingServiceTest { blockRepository = mock(TranscriptionBlockRepository.class); txTemplate = mock(TransactionTemplate.class); personService = mock(PersonService.class); + senderModelService = mock(SenderModelService.class); // Execute transaction callbacks inline so unit tests run without a real DataSource when(txTemplate.execute(any())).thenAnswer(inv -> { @@ -53,11 +57,12 @@ class OcrTrainingServiceTest { return callback.doInTransaction(null); }); - service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService); + service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService, senderModelService); when(blockRepository.count()).thenReturn(0L); when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); + when(senderModelService.getAllSenderModels()).thenReturn(List.of()); } // ─── Concurrent guard ───────────────────────────────────────────────────── @@ -236,18 +241,14 @@ class OcrTrainingServiceTest { // ─── getTrainingInfo: batch person name resolution ──────────────────────── @Test - void getTrainingInfo_resolves_person_names_in_single_batch_call() { + void getTrainingInfo_resolves_person_names_from_all_senderModels_in_batch() { UUID personA = UUID.randomUUID(); UUID personB = UUID.randomUUID(); - List runs = List.of( - OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE) - .personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build(), - OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE) - .personId(personB).blockCount(5).documentCount(1).modelName("sender_b").build(), - OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE) - .personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build() + List models = List.of( + SenderModel.builder().id(UUID.randomUUID()).personId(personA).correctedLinesAtTraining(100).build(), + SenderModel.builder().id(UUID.randomUUID()).personId(personB).correctedLinesAtTraining(80).build() ); - when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(runs); + when(senderModelService.getAllSenderModels()).thenReturn(models); when(exportService.queryEligibleBlocks()).thenReturn(List.of()); Person pa = Person.builder().id(personA).firstName("Anna").lastName("Müller").build(); @@ -256,7 +257,7 @@ class OcrTrainingServiceTest { .thenReturn(List.of(pa, pb)); when(healthClient.isHealthy()).thenReturn(true); - OcrTrainingService.TrainingInfoResponse info = service.getTrainingInfo(); + TrainingInfoResponse info = service.getTrainingInfo(); verify(personService, never()).getById(any()); verify(personService, times(1)).getAllById(any()); @@ -264,6 +265,21 @@ class OcrTrainingServiceTest { assertThat(info.personNames()).containsKey(personB.toString()); } + @Test + void getTrainingInfo_includesSenderModels_inResponse() { + UUID personId = UUID.randomUUID(); + SenderModel model = SenderModel.builder() + .id(UUID.randomUUID()).personId(personId).correctedLinesAtTraining(120).build(); + when(senderModelService.getAllSenderModels()).thenReturn(List.of(model)); + when(exportService.queryEligibleBlocks()).thenReturn(List.of()); + when(healthClient.isHealthy()).thenReturn(false); + + TrainingInfoResponse info = service.getTrainingInfo(); + + assertThat(info.senderModels()).hasSize(1); + assertThat(info.senderModels().get(0).getPersonId()).isEqualTo(personId); + } + // ─── Orphan recovery ────────────────────────────────────────────────────── @Test