feat(ocr): add POST /api/ocr/train-sender endpoint for manual sender training

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-18 00:16:02 +02:00
parent 178afcd496
commit c3fa09d12e
4 changed files with 41 additions and 1 deletions

View File

@@ -7,9 +7,11 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TrainingHistoryResponse; import org.raddatz.familienarchiv.dto.TrainingHistoryResponse;
import org.raddatz.familienarchiv.dto.TrainingInfoResponse; import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO; import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.dto.TriggerSenderTrainingDTO;
import org.raddatz.familienarchiv.model.AppUser; import org.raddatz.familienarchiv.model.AppUser;
import org.raddatz.familienarchiv.model.OcrJob; import org.raddatz.familienarchiv.model.OcrJob;
import org.raddatz.familienarchiv.model.OcrTrainingRun; import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.security.Permission; import org.raddatz.familienarchiv.security.Permission;
import org.raddatz.familienarchiv.security.RequirePermission; import org.raddatz.familienarchiv.security.RequirePermission;
import org.raddatz.familienarchiv.service.OcrBatchService; import org.raddatz.familienarchiv.service.OcrBatchService;
@@ -17,6 +19,7 @@ import org.raddatz.familienarchiv.service.OcrProgressService;
import org.raddatz.familienarchiv.service.OcrService; import org.raddatz.familienarchiv.service.OcrService;
import org.raddatz.familienarchiv.service.OcrTrainingService; import org.raddatz.familienarchiv.service.OcrTrainingService;
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService; import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
import org.raddatz.familienarchiv.service.SenderModelService;
import org.raddatz.familienarchiv.service.TrainingDataExportService; import org.raddatz.familienarchiv.service.TrainingDataExportService;
import org.raddatz.familienarchiv.service.UserService; import org.raddatz.familienarchiv.service.UserService;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
@@ -44,6 +47,7 @@ public class OcrController {
private final TrainingDataExportService trainingDataExportService; private final TrainingDataExportService trainingDataExportService;
private final SegmentationTrainingExportService segmentationTrainingExportService; private final SegmentationTrainingExportService segmentationTrainingExportService;
private final OcrTrainingService ocrTrainingService; private final OcrTrainingService ocrTrainingService;
private final SenderModelService senderModelService;
@PostMapping("/api/documents/{documentId}/ocr") @PostMapping("/api/documents/{documentId}/ocr")
@ResponseStatus(HttpStatus.ACCEPTED) @ResponseStatus(HttpStatus.ACCEPTED)
@@ -148,6 +152,17 @@ public class OcrController {
return ocrTrainingService.getSenderTrainingHistory(personId); return ocrTrainingService.getSenderTrainingHistory(personId);
} }
@PostMapping("/api/ocr/train-sender")
@ResponseStatus(HttpStatus.ACCEPTED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerSenderTraining(@RequestBody TriggerSenderTrainingDTO dto) {
OcrTrainingRun run = senderModelService.triggerManualSenderTraining(dto.personId());
if (run.getStatus() == TrainingStatus.RUNNING) {
senderModelService.runSenderTraining(dto.personId());
}
return run;
}
private UUID resolveUserId(Authentication authentication) { private UUID resolveUserId(Authentication authentication) {
if (authentication == null || !authentication.isAuthenticated()) return null; if (authentication == null || !authentication.isAuthenticated()) return null;
try { try {

View File

@@ -0,0 +1,5 @@
package org.raddatz.familienarchiv.dto;
import java.util.UUID;
public record TriggerSenderTrainingDTO(UUID personId) {}

View File

@@ -33,6 +33,7 @@ public class SenderModelService {
private final OcrClient ocrClient; private final OcrClient ocrClient;
private final TransactionTemplate txTemplate; private final TransactionTemplate txTemplate;
private final TrainingDataExportService trainingDataExportService; private final TrainingDataExportService trainingDataExportService;
private final PersonService personService;
@Value("${ocr.sender-model.activation-threshold:100}") @Value("${ocr.sender-model.activation-threshold:100}")
private int activationThreshold; private int activationThreshold;
@@ -50,6 +51,22 @@ public class SenderModelService {
return senderModelRepository.findAll(); return senderModelRepository.findAll();
} }
public OcrTrainingRun triggerManualSenderTraining(UUID personId) {
personService.getById(personId);
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
TrainingStatus targetStatus = runNow ? TrainingStatus.RUNNING : TrainingStatus.QUEUED;
return trainingRunRepository.findFirstByPersonIdAndStatus(personId, targetStatus)
.orElseThrow(() -> new IllegalStateException(
"Expected " + targetStatus + " row for person " + personId));
}
@Async
public void runSenderTraining(UUID personId) {
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
triggerSenderTraining(personId, (int) correctedLines);
}
/** /**
* Called after every MANUAL block save for HANDWRITING_KURRENT documents. * Called after every MANUAL block save for HANDWRITING_KURRENT documents.
* Checks activation and retrain thresholds; enqueues or starts sender training when met. * Checks activation and retrain thresholds; enqueues or starts sender training when met.

View File

@@ -10,6 +10,7 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository; import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.raddatz.familienarchiv.model.Person;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate; import org.springframework.transaction.support.TransactionTemplate;
@@ -30,6 +31,7 @@ class SenderModelServiceTest {
OcrClient ocrClient; OcrClient ocrClient;
TransactionTemplate txTemplate; TransactionTemplate txTemplate;
TrainingDataExportService trainingDataExportService; TrainingDataExportService trainingDataExportService;
PersonService personService;
SenderModelService service; SenderModelService service;
UUID personId = UUID.randomUUID(); UUID personId = UUID.randomUUID();
@@ -42,6 +44,7 @@ class SenderModelServiceTest {
ocrClient = mock(OcrClient.class); ocrClient = mock(OcrClient.class);
txTemplate = mock(TransactionTemplate.class); txTemplate = mock(TransactionTemplate.class);
trainingDataExportService = mock(TrainingDataExportService.class); trainingDataExportService = mock(TrainingDataExportService.class);
personService = mock(PersonService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource. // Execute transaction callbacks inline so unit tests run without a real DataSource.
// lenient: not every test hits the txTemplate path, but the setup is shared. // lenient: not every test hits the txTemplate path, but the setup is shared.
@@ -51,7 +54,7 @@ class SenderModelServiceTest {
}); });
service = new SenderModelService(senderModelRepository, blockRepository, service = new SenderModelService(senderModelRepository, blockRepository,
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService); trainingRunRepository, ocrClient, txTemplate, trainingDataExportService, personService);
ReflectionTestUtils.setField(service, "activationThreshold", 100); ReflectionTestUtils.setField(service, "activationThreshold", 100);
ReflectionTestUtils.setField(service, "retrainDelta", 50); ReflectionTestUtils.setField(service, "retrainDelta", 50);
} }