feat(ocr): add POST /api/ocr/train-sender endpoint for manual sender training
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,9 +7,11 @@ import org.raddatz.familienarchiv.dto.OcrStatusDTO;
|
|||||||
import org.raddatz.familienarchiv.dto.TrainingHistoryResponse;
|
import org.raddatz.familienarchiv.dto.TrainingHistoryResponse;
|
||||||
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
|
import org.raddatz.familienarchiv.dto.TrainingInfoResponse;
|
||||||
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
import org.raddatz.familienarchiv.dto.TriggerOcrDTO;
|
||||||
|
import org.raddatz.familienarchiv.dto.TriggerSenderTrainingDTO;
|
||||||
import org.raddatz.familienarchiv.model.AppUser;
|
import org.raddatz.familienarchiv.model.AppUser;
|
||||||
import org.raddatz.familienarchiv.model.OcrJob;
|
import org.raddatz.familienarchiv.model.OcrJob;
|
||||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||||
|
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||||
import org.raddatz.familienarchiv.security.Permission;
|
import org.raddatz.familienarchiv.security.Permission;
|
||||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||||
import org.raddatz.familienarchiv.service.OcrBatchService;
|
import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||||
@@ -17,6 +19,7 @@ import org.raddatz.familienarchiv.service.OcrProgressService;
|
|||||||
import org.raddatz.familienarchiv.service.OcrService;
|
import org.raddatz.familienarchiv.service.OcrService;
|
||||||
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
||||||
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
|
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
|
||||||
|
import org.raddatz.familienarchiv.service.SenderModelService;
|
||||||
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
||||||
import org.raddatz.familienarchiv.service.UserService;
|
import org.raddatz.familienarchiv.service.UserService;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
@@ -44,6 +47,7 @@ public class OcrController {
|
|||||||
private final TrainingDataExportService trainingDataExportService;
|
private final TrainingDataExportService trainingDataExportService;
|
||||||
private final SegmentationTrainingExportService segmentationTrainingExportService;
|
private final SegmentationTrainingExportService segmentationTrainingExportService;
|
||||||
private final OcrTrainingService ocrTrainingService;
|
private final OcrTrainingService ocrTrainingService;
|
||||||
|
private final SenderModelService senderModelService;
|
||||||
|
|
||||||
@PostMapping("/api/documents/{documentId}/ocr")
|
@PostMapping("/api/documents/{documentId}/ocr")
|
||||||
@ResponseStatus(HttpStatus.ACCEPTED)
|
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||||
@@ -148,6 +152,17 @@ public class OcrController {
|
|||||||
return ocrTrainingService.getSenderTrainingHistory(personId);
|
return ocrTrainingService.getSenderTrainingHistory(personId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@PostMapping("/api/ocr/train-sender")
|
||||||
|
@ResponseStatus(HttpStatus.ACCEPTED)
|
||||||
|
@RequirePermission(Permission.ADMIN)
|
||||||
|
public OcrTrainingRun triggerSenderTraining(@RequestBody TriggerSenderTrainingDTO dto) {
|
||||||
|
OcrTrainingRun run = senderModelService.triggerManualSenderTraining(dto.personId());
|
||||||
|
if (run.getStatus() == TrainingStatus.RUNNING) {
|
||||||
|
senderModelService.runSenderTraining(dto.personId());
|
||||||
|
}
|
||||||
|
return run;
|
||||||
|
}
|
||||||
|
|
||||||
private UUID resolveUserId(Authentication authentication) {
|
private UUID resolveUserId(Authentication authentication) {
|
||||||
if (authentication == null || !authentication.isAuthenticated()) return null;
|
if (authentication == null || !authentication.isAuthenticated()) return null;
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package org.raddatz.familienarchiv.dto;
|
||||||
|
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
/**
 * Request payload for manually triggering sender training
 * ({@code POST /api/ocr/train-sender}).
 *
 * @param personId id of the person whose sender model should be trained; no validation is
 *                 performed by this record, so the value may be {@code null}
 */
public record TriggerSenderTrainingDTO(UUID personId) {
}
|
||||||
@@ -33,6 +33,7 @@ public class SenderModelService {
|
|||||||
private final OcrClient ocrClient;
|
private final OcrClient ocrClient;
|
||||||
private final TransactionTemplate txTemplate;
|
private final TransactionTemplate txTemplate;
|
||||||
private final TrainingDataExportService trainingDataExportService;
|
private final TrainingDataExportService trainingDataExportService;
|
||||||
|
private final PersonService personService;
|
||||||
|
|
||||||
@Value("${ocr.sender-model.activation-threshold:100}")
|
@Value("${ocr.sender-model.activation-threshold:100}")
|
||||||
private int activationThreshold;
|
private int activationThreshold;
|
||||||
@@ -50,6 +51,22 @@ public class SenderModelService {
|
|||||||
return senderModelRepository.findAll();
|
return senderModelRepository.findAll();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public OcrTrainingRun triggerManualSenderTraining(UUID personId) {
|
||||||
|
personService.getById(personId);
|
||||||
|
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||||
|
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
|
||||||
|
TrainingStatus targetStatus = runNow ? TrainingStatus.RUNNING : TrainingStatus.QUEUED;
|
||||||
|
return trainingRunRepository.findFirstByPersonIdAndStatus(personId, targetStatus)
|
||||||
|
.orElseThrow(() -> new IllegalStateException(
|
||||||
|
"Expected " + targetStatus + " row for person " + personId));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Async
|
||||||
|
public void runSenderTraining(UUID personId) {
|
||||||
|
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||||
|
triggerSenderTraining(personId, (int) correctedLines);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
|
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
|
||||||
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
|
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
|
|||||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||||
|
import org.raddatz.familienarchiv.model.Person;
|
||||||
import org.springframework.test.util.ReflectionTestUtils;
|
import org.springframework.test.util.ReflectionTestUtils;
|
||||||
import org.springframework.transaction.support.TransactionCallback;
|
import org.springframework.transaction.support.TransactionCallback;
|
||||||
import org.springframework.transaction.support.TransactionTemplate;
|
import org.springframework.transaction.support.TransactionTemplate;
|
||||||
@@ -30,6 +31,7 @@ class SenderModelServiceTest {
|
|||||||
OcrClient ocrClient;
|
OcrClient ocrClient;
|
||||||
TransactionTemplate txTemplate;
|
TransactionTemplate txTemplate;
|
||||||
TrainingDataExportService trainingDataExportService;
|
TrainingDataExportService trainingDataExportService;
|
||||||
|
PersonService personService;
|
||||||
SenderModelService service;
|
SenderModelService service;
|
||||||
|
|
||||||
UUID personId = UUID.randomUUID();
|
UUID personId = UUID.randomUUID();
|
||||||
@@ -42,6 +44,7 @@ class SenderModelServiceTest {
|
|||||||
ocrClient = mock(OcrClient.class);
|
ocrClient = mock(OcrClient.class);
|
||||||
txTemplate = mock(TransactionTemplate.class);
|
txTemplate = mock(TransactionTemplate.class);
|
||||||
trainingDataExportService = mock(TrainingDataExportService.class);
|
trainingDataExportService = mock(TrainingDataExportService.class);
|
||||||
|
personService = mock(PersonService.class);
|
||||||
|
|
||||||
// Execute transaction callbacks inline so unit tests run without a real DataSource.
|
// Execute transaction callbacks inline so unit tests run without a real DataSource.
|
||||||
// lenient: not every test hits the txTemplate path, but the setup is shared.
|
// lenient: not every test hits the txTemplate path, but the setup is shared.
|
||||||
@@ -51,7 +54,7 @@ class SenderModelServiceTest {
|
|||||||
});
|
});
|
||||||
|
|
||||||
service = new SenderModelService(senderModelRepository, blockRepository,
|
service = new SenderModelService(senderModelRepository, blockRepository,
|
||||||
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
|
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService, personService);
|
||||||
ReflectionTestUtils.setField(service, "activationThreshold", 100);
|
ReflectionTestUtils.setField(service, "activationThreshold", 100);
|
||||||
ReflectionTestUtils.setField(service, "retrainDelta", 50);
|
ReflectionTestUtils.setField(service, "retrainDelta", 50);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user