feat(ocr): add POST /api/ocr/train-sender endpoint for manual sender training
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,9 +7,11 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
||||
import org.raddatz.familienarchiv.dto.TrainingHistoryResponse;
|
||||
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
|
||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||
import org.raddatz.familienarchiv.dto.TriggerSenderTrainingDTO;
|
||||
import org.raddatz.familienarchiv.model.AppUser;
|
||||
import org.raddatz.familienarchiv.model.OcrJob;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.security.Permission;
|
||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||
import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||
@@ -17,6 +19,7 @@ import org.raddatz.familienarchiv.service.OcrProgressService;
|
||||
import org.raddatz.familienarchiv.service.OcrService;
|
||||
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
||||
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
|
||||
import org.raddatz.familienarchiv.service.SenderModelService;
|
||||
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
||||
import org.raddatz.familienarchiv.service.UserService;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -44,6 +47,7 @@ public class OcrController {
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final SegmentationTrainingExportService segmentationTrainingExportService;
|
||||
private final OcrTrainingService ocrTrainingService;
|
||||
private final SenderModelService senderModelService;
|
||||
|
||||
@PostMapping("/api/documents/{documentId}/ocr")
|
||||
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||
@@ -148,6 +152,17 @@ public class OcrController {
|
||||
return ocrTrainingService.getSenderTrainingHistory(personId);
|
||||
}
|
||||
|
||||
@PostMapping("/api/ocr/train-sender")
|
||||
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public OcrTrainingRun triggerSenderTraining(@RequestBody TriggerSenderTrainingDTO dto) {
|
||||
OcrTrainingRun run = senderModelService.triggerManualSenderTraining(dto.personId());
|
||||
if (run.getStatus() == TrainingStatus.RUNNING) {
|
||||
senderModelService.runSenderTraining(dto.personId());
|
||||
}
|
||||
return run;
|
||||
}
|
||||
|
||||
private UUID resolveUserId(Authentication authentication) {
|
||||
if (authentication == null || !authentication.isAuthenticated()) return null;
|
||||
try {
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
package org.raddatz.familienarchiv.dto;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
public record TriggerSenderTrainingDTO(UUID personId) {}
|
||||
@@ -33,6 +33,7 @@ public class SenderModelService {
|
||||
private final OcrClient ocrClient;
|
||||
private final TransactionTemplate txTemplate;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final PersonService personService;
|
||||
|
||||
@Value("${ocr.sender-model.activation-threshold:100}")
|
||||
private int activationThreshold;
|
||||
@@ -50,6 +51,22 @@ public class SenderModelService {
|
||||
return senderModelRepository.findAll();
|
||||
}
|
||||
|
||||
public OcrTrainingRun triggerManualSenderTraining(UUID personId) {
|
||||
personService.getById(personId);
|
||||
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
|
||||
TrainingStatus targetStatus = runNow ? TrainingStatus.RUNNING : TrainingStatus.QUEUED;
|
||||
return trainingRunRepository.findFirstByPersonIdAndStatus(personId, targetStatus)
|
||||
.orElseThrow(() -> new IllegalStateException(
|
||||
"Expected " + targetStatus + " row for person " + personId));
|
||||
}
|
||||
|
||||
@Async
|
||||
public void runSenderTraining(UUID personId) {
|
||||
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
triggerSenderTraining(personId, (int) correctedLines);
|
||||
}
|
||||
|
||||
/**
|
||||
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
|
||||
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
|
||||
|
||||
@@ -10,6 +10,7 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.raddatz.familienarchiv.model.Person;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
@@ -30,6 +31,7 @@ class SenderModelServiceTest {
|
||||
OcrClient ocrClient;
|
||||
TransactionTemplate txTemplate;
|
||||
TrainingDataExportService trainingDataExportService;
|
||||
PersonService personService;
|
||||
SenderModelService service;
|
||||
|
||||
UUID personId = UUID.randomUUID();
|
||||
@@ -42,6 +44,7 @@ class SenderModelServiceTest {
|
||||
ocrClient = mock(OcrClient.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
trainingDataExportService = mock(TrainingDataExportService.class);
|
||||
personService = mock(PersonService.class);
|
||||
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource.
|
||||
// lenient: not every test hits the txTemplate path, but the setup is shared.
|
||||
@@ -51,7 +54,7 @@ class SenderModelServiceTest {
|
||||
});
|
||||
|
||||
service = new SenderModelService(senderModelRepository, blockRepository,
|
||||
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
|
||||
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService, personService);
|
||||
ReflectionTestUtils.setField(service, "activationThreshold", 100);
|
||||
ReflectionTestUtils.setField(service, "retrainDelta", 50);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user