diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java index a6a65a7b..65db519c 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -11,7 +11,6 @@ import org.raddatz.familienarchiv.dto.TriggerSenderTrainingDTO; import org.raddatz.familienarchiv.model.AppUser; import org.raddatz.familienarchiv.model.OcrJob; import org.raddatz.familienarchiv.model.OcrTrainingRun; -import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.security.Permission; import org.raddatz.familienarchiv.security.RequirePermission; import org.raddatz.familienarchiv.service.OcrBatchService; @@ -156,11 +155,7 @@ public class OcrController { @ResponseStatus(HttpStatus.ACCEPTED) @RequirePermission(Permission.ADMIN) public OcrTrainingRun triggerSenderTraining(@Valid @RequestBody TriggerSenderTrainingDTO dto) { - OcrTrainingRun run = senderModelService.triggerManualSenderTraining(dto.personId()); - if (run.getStatus() == TrainingStatus.RUNNING) { - senderModelService.runSenderTraining(dto.personId()); - } - return run; + return senderModelService.triggerManualSenderTraining(dto.personId()); } private UUID resolveUserId(Authentication authentication) { diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java index 4bcd5df0..8621a153 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -9,7 +9,9 @@ import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.slf4j.MDC; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Lazy; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -35,6 +37,11 @@ public class SenderModelService { private final TrainingDataExportService trainingDataExportService; private final PersonService personService; + // Self-reference through the Spring proxy so @Async is honoured on self-calls. + @Lazy + @Autowired + private SenderModelService self; + @Value("${ocr.sender-model.activation-threshold:100}") private int activationThreshold; @@ -56,9 +63,13 @@ public class SenderModelService { long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId); boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines); TrainingStatus targetStatus = runNow ? TrainingStatus.RUNNING : TrainingStatus.QUEUED; - return trainingRunRepository.findFirstByPersonIdAndStatus(personId, targetStatus) + OcrTrainingRun run = trainingRunRepository.findFirstByPersonIdAndStatus(personId, targetStatus) .orElseThrow(() -> new IllegalStateException( "Expected " + targetStatus + " row for person " + personId)); + if (runNow) { + self.runSenderTraining(personId); + } + return run; } @Async diff --git a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java index 91757d5e..9fe3b28b 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/controller/OcrControllerTest.java @@ -30,6 +30,8 @@ import java.util.Map; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -429,6 +431,23 @@ class OcrControllerTest { .andExpect(jsonPath("$.status").value("QUEUED")); } + @Test + @WithMockUser(authorities = "ADMIN") + void triggerSenderTraining_doesNotCallRunSenderTraining_fromController() throws Exception { + UUID personId = UUID.randomUUID(); + OcrTrainingRun run = OcrTrainingRun.builder() + .id(UUID.randomUUID()).status(TrainingStatus.RUNNING) + .personId(personId).blockCount(5).documentCount(0).modelName("sender_" + personId).build(); + when(senderModelService.triggerManualSenderTraining(personId)).thenReturn(run); + + mockMvc.perform(post("/api/ocr/train-sender") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"personId\":\"" + personId + "\"}")) + .andExpect(status().isAccepted()); + + verify(senderModelService, never()).runSenderTraining(any()); + } + @Test @WithMockUser(authorities = "READ_ALL") void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index 7f4ee02a..3d999d5c 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -34,6 +34,7 @@ class SenderModelServiceTest { TransactionTemplate txTemplate; TrainingDataExportService trainingDataExportService; PersonService personService; + SenderModelService selfProxy; SenderModelService service; UUID personId = UUID.randomUUID(); @@ -47,6 +48,7 @@ class SenderModelServiceTest { txTemplate = mock(TransactionTemplate.class); trainingDataExportService = mock(TrainingDataExportService.class); personService = mock(PersonService.class); + selfProxy = mock(SenderModelService.class); // Execute transaction callbacks inline so unit tests run without a real DataSource. // lenient: not every test hits the txTemplate path, but the setup is shared. @@ -57,6 +59,7 @@ class SenderModelServiceTest { service = new SenderModelService(senderModelRepository, blockRepository, trainingRunRepository, ocrClient, txTemplate, trainingDataExportService, personService); + ReflectionTestUtils.setField(service, "self", selfProxy); ReflectionTestUtils.setField(service, "activationThreshold", 100); ReflectionTestUtils.setField(service, "retrainDelta", 50); }