feat(ocr): add POST /api/ocr/train-sender endpoint for manual sender training

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-18 00:16:02 +02:00
parent 178afcd496
commit c3fa09d12e
4 changed files with 41 additions and 1 deletions

View File

@@ -7,9 +7,11 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
import org.raddatz.familienarchiv.dto.TrainingHistoryResponse;
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
import org.raddatz.familienarchiv.dto.TriggerSenderTrainingDTO;
import org.raddatz.familienarchiv.model.AppUser;
import org.raddatz.familienarchiv.model.OcrJob;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.security.Permission;
import org.raddatz.familienarchiv.security.RequirePermission;
import org.raddatz.familienarchiv.service.OcrBatchService;
@@ -17,6 +19,7 @@ import org.raddatz.familienarchiv.service.OcrProgressService;
import org.raddatz.familienarchiv.service.OcrService;
import org.raddatz.familienarchiv.service.OcrTrainingService;
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
import org.raddatz.familienarchiv.service.SenderModelService;
import org.raddatz.familienarchiv.service.TrainingDataExportService;
import org.raddatz.familienarchiv.service.UserService;
import org.springframework.http.HttpHeaders;
@@ -44,6 +47,7 @@ public class OcrController {
private final TrainingDataExportService trainingDataExportService;
private final SegmentationTrainingExportService segmentationTrainingExportService;
private final OcrTrainingService ocrTrainingService;
private final SenderModelService senderModelService;
@PostMapping("/api/documents/{documentId}/ocr")
@ResponseStatus(HttpStatus.ACCEPTED)
@@ -148,6 +152,17 @@ public class OcrController {
return ocrTrainingService.getSenderTrainingHistory(personId);
}
@PostMapping("/api/ocr/train-sender")
@ResponseStatus(HttpStatus.ACCEPTED)
@RequirePermission(Permission.ADMIN)
public OcrTrainingRun triggerSenderTraining(@RequestBody TriggerSenderTrainingDTO dto) {
OcrTrainingRun run = senderModelService.triggerManualSenderTraining(dto.personId());
if (run.getStatus() == TrainingStatus.RUNNING) {
senderModelService.runSenderTraining(dto.personId());
}
return run;
}
private UUID resolveUserId(Authentication authentication) {
if (authentication == null || !authentication.isAuthenticated()) return null;
try {

View File

@@ -0,0 +1,5 @@
package org.raddatz.familienarchiv.dto;
import java.util.UUID;
public record TriggerSenderTrainingDTO(UUID personId) {}

View File

@@ -33,6 +33,7 @@ public class SenderModelService {
private final OcrClient ocrClient;
private final TransactionTemplate txTemplate;
private final TrainingDataExportService trainingDataExportService;
private final PersonService personService;
@Value("${ocr.sender-model.activation-threshold:100}")
private int activationThreshold;
@@ -50,6 +51,22 @@ public class SenderModelService {
return senderModelRepository.findAll();
}
public OcrTrainingRun triggerManualSenderTraining(UUID personId) {
personService.getById(personId);
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
TrainingStatus targetStatus = runNow ? TrainingStatus.RUNNING : TrainingStatus.QUEUED;
return trainingRunRepository.findFirstByPersonIdAndStatus(personId, targetStatus)
.orElseThrow(() -> new IllegalStateException(
"Expected " + targetStatus + " row for person " + personId));
}
@Async
public void runSenderTraining(UUID personId) {
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
triggerSenderTraining(personId, (int) correctedLines);
}
/**
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
* Checks activation and retrain thresholds; enqueues or starts sender training when met.

View File

@@ -10,6 +10,7 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.raddatz.familienarchiv.model.Person;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
@@ -30,6 +31,7 @@ class SenderModelServiceTest {
OcrClient ocrClient;
TransactionTemplate txTemplate;
TrainingDataExportService trainingDataExportService;
PersonService personService;
SenderModelService service;
UUID personId = UUID.randomUUID();
@@ -42,6 +44,7 @@ class SenderModelServiceTest {
ocrClient = mock(OcrClient.class);
txTemplate = mock(TransactionTemplate.class);
trainingDataExportService = mock(TrainingDataExportService.class);
personService = mock(PersonService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource.
// lenient: not every test hits the txTemplate path, but the setup is shared.
@@ -51,7 +54,7 @@ class SenderModelServiceTest {
});
service = new SenderModelService(senderModelRepository, blockRepository,
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService, personService);
ReflectionTestUtils.setField(service, "activationThreshold", 100);
ReflectionTestUtils.setField(service, "retrainDelta", 50);
}