diff --git a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java index ae846f19..19f44ab5 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/controller/OcrController.java @@ -7,9 +7,11 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO; import org.raddatz.familienarchiv.dto.TrainingHistoryResponse; import org.raddatz.familienarchiv.dto.TrainingInfoResponse; import org.raddatz.familienarchiv.dto.TriggerOcrDTO; +import org.raddatz.familienarchiv.dto.TriggerSenderTrainingDTO; import org.raddatz.familienarchiv.model.AppUser; import org.raddatz.familienarchiv.model.OcrJob; import org.raddatz.familienarchiv.model.OcrTrainingRun; +import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.security.Permission; import org.raddatz.familienarchiv.security.RequirePermission; import org.raddatz.familienarchiv.service.OcrBatchService; @@ -17,6 +19,7 @@ import org.raddatz.familienarchiv.service.OcrProgressService; import org.raddatz.familienarchiv.service.OcrService; import org.raddatz.familienarchiv.service.OcrTrainingService; import org.raddatz.familienarchiv.service.SegmentationTrainingExportService; +import org.raddatz.familienarchiv.service.SenderModelService; import org.raddatz.familienarchiv.service.TrainingDataExportService; import org.raddatz.familienarchiv.service.UserService; import org.springframework.http.HttpHeaders; @@ -44,6 +47,7 @@ public class OcrController { private final TrainingDataExportService trainingDataExportService; private final SegmentationTrainingExportService segmentationTrainingExportService; private final OcrTrainingService ocrTrainingService; + private final SenderModelService senderModelService; @PostMapping("/api/documents/{documentId}/ocr") @ResponseStatus(HttpStatus.ACCEPTED) @@ -148,6 +152,17 @@ public class OcrController { return ocrTrainingService.getSenderTrainingHistory(personId); } + @PostMapping("/api/ocr/train-sender") + @ResponseStatus(HttpStatus.ACCEPTED) + @RequirePermission(Permission.ADMIN) + public OcrTrainingRun triggerSenderTraining(@RequestBody TriggerSenderTrainingDTO dto) { + OcrTrainingRun run = senderModelService.triggerManualSenderTraining(dto.personId()); + if (run.getStatus() == TrainingStatus.RUNNING) { + senderModelService.runSenderTraining(dto.personId()); + } + return run; + } + private UUID resolveUserId(Authentication authentication) { if (authentication == null || !authentication.isAuthenticated()) return null; try { diff --git a/backend/src/main/java/org/raddatz/familienarchiv/dto/TriggerSenderTrainingDTO.java b/backend/src/main/java/org/raddatz/familienarchiv/dto/TriggerSenderTrainingDTO.java new file mode 100644 index 00000000..baa8d0f3 --- /dev/null +++ b/backend/src/main/java/org/raddatz/familienarchiv/dto/TriggerSenderTrainingDTO.java @@ -0,0 +1,5 @@ +package org.raddatz.familienarchiv.dto; + +import java.util.UUID; + +public record TriggerSenderTrainingDTO(UUID personId) {} diff --git a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java index 69114b4d..4bcd5df0 100644 --- a/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java +++ b/backend/src/main/java/org/raddatz/familienarchiv/service/SenderModelService.java @@ -33,6 +33,7 @@ public class SenderModelService { private final OcrClient ocrClient; private final TransactionTemplate txTemplate; private final TrainingDataExportService trainingDataExportService; + private final PersonService personService; @Value("${ocr.sender-model.activation-threshold:100}") private int activationThreshold; @@ -50,6 +51,22 @@ public class SenderModelService { return senderModelRepository.findAll(); } + public OcrTrainingRun triggerManualSenderTraining(UUID personId) { + personService.getById(personId); + long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId); + boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines); + TrainingStatus targetStatus = runNow ? TrainingStatus.RUNNING : TrainingStatus.QUEUED; + return trainingRunRepository.findFirstByPersonIdAndStatus(personId, targetStatus) + .orElseThrow(() -> new IllegalStateException( + "Expected " + targetStatus + " row for person " + personId)); + } + + @Async + public void runSenderTraining(UUID personId) { + long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId); + triggerSenderTraining(personId, (int) correctedLines); + } + /** * Called after every MANUAL block save for HANDWRITING_KURRENT documents. * Checks activation and retrain thresholds; enqueues or starts sender training when met. diff --git a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java index 57318758..f4ae5f2d 100644 --- a/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java +++ b/backend/src/test/java/org/raddatz/familienarchiv/service/SenderModelServiceTest.java @@ -10,6 +10,7 @@ import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; +import org.raddatz.familienarchiv.model.Person; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionTemplate; @@ -30,6 +31,7 @@ class SenderModelServiceTest { OcrClient ocrClient; TransactionTemplate txTemplate; TrainingDataExportService trainingDataExportService; + PersonService personService; SenderModelService service; UUID personId = UUID.randomUUID(); @@ -42,6 +44,7 @@ class SenderModelServiceTest { ocrClient = mock(OcrClient.class); txTemplate = mock(TransactionTemplate.class); trainingDataExportService = mock(TrainingDataExportService.class); + personService = mock(PersonService.class); // Execute transaction callbacks inline so unit tests run without a real DataSource. // lenient: not every test hits the txTemplate path, but the setup is shared. @@ -51,7 +54,7 @@ class SenderModelServiceTest { }); service = new SenderModelService(senderModelRepository, blockRepository, - trainingRunRepository, ocrClient, txTemplate, trainingDataExportService); + trainingRunRepository, ocrClient, txTemplate, trainingDataExportService, personService); ReflectionTestUtils.setField(service, "activationThreshold", 100); ReflectionTestUtils.setField(service, "retrainDelta", 50); }