refactor(ocr): move sender training methods from OcrTrainingService to SenderModelService

Eliminates cross-domain repository access: OcrTrainingService no longer
holds SenderModelRepository. SenderModelService now owns the full sender
training lifecycle (runOrQueueSenderTraining, triggerSenderTraining,
promoteNextQueuedRun), removing the circular dependency risk.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marcel
2026-04-17 19:08:10 +02:00
parent 60c1ec7b5f
commit 386dc83958
4 changed files with 311 additions and 175 deletions

View File

@@ -5,10 +5,8 @@ import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.boot.context.event.ApplicationReadyEvent;
@@ -30,7 +28,6 @@ import java.util.UUID;
public class OcrTrainingService {
private final OcrTrainingRunRepository trainingRunRepository;
private final SenderModelRepository senderModelRepository;
private final TrainingDataExportService trainingDataExportService;
private final SegmentationTrainingExportService segmentationTrainingExportService;
private final OcrClient ocrClient;
@@ -212,110 +209,6 @@ public class OcrTrainingService {
);
}
/**
* Called from SenderModelService (which is @Async). Inserts a QUEUED row when training
* is already running (coalescing: at most one QUEUED per person_id). Runs immediately if idle.
* Returns true if training was started, false if queued or skipped.
*/
@Transactional
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
return false;
}
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.QUEUED)
.personId(personId)
.blockCount(blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
log.info("Queued sender training for person {} — training already running", personId);
return false;
}
return true;
}
/** Executes sender training synchronously. Caller must run this on a background thread. */
public void triggerSenderTraining(UUID personId, int correctedLines) {
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
return trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
}));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started sender training run {} for person {}", runId, personId);
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportForSender(personId).writeTo(baos);
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
txTemplate.execute(status -> {
SenderModel model = senderModelRepository.findByPersonId(personId)
.orElseGet(() -> SenderModel.builder().personId(personId).build());
model.setModelPath(outputModelPath);
model.setCer(result.cer());
model.setAccuracy(result.accuracy());
model.setCorrectedLinesAtTraining(correctedLines);
senderModelRepository.save(model);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
trainingRunRepository.save(run);
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
return null;
});
} catch (Exception e) {
txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
return null;
});
} finally {
MDC.remove("trainingRunId");
promoteNextQueuedRun();
}
}
private void promoteNextQueuedRun() {
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
.map(queued -> {
queued.setStatus(TrainingStatus.RUNNING);
return trainingRunRepository.save(queued);
}));
if (queuedOpt != null && queuedOpt.isPresent()) {
OcrTrainingRun promoted = queuedOpt.get();
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
}
}
@EventListener(ApplicationReadyEvent.class)
@Transactional
public void recoverOrphanedRuns() {

View File

@@ -0,0 +1,182 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class SenderModelService {
private final SenderModelRepository senderModelRepository;
private final TranscriptionBlockRepository blockRepository;
private final OcrTrainingRunRepository trainingRunRepository;
private final OcrClient ocrClient;
private final TransactionTemplate txTemplate;
private final TrainingDataExportService trainingDataExportService;
@Value("${ocr.sender-model.activation-threshold:100}")
private int activationThreshold;
@Value("${ocr.sender-model.retrain-delta:50}")
private int retrainDelta;
/** Returns the model path if a trained sender model exists for this person. */
public Optional<String> maybeGetModelPath(UUID personId) {
return senderModelRepository.findByPersonId(personId)
.map(SenderModel::getModelPath);
}
/**
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
*/
@Async
public void checkAndTriggerTraining(UUID personId) {
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
Optional<SenderModel> existing = senderModelRepository.findByPersonId(personId);
boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold;
boolean shouldRetrain = existing.isPresent()
&& (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta;
if (!shouldActivate && !shouldRetrain) {
return;
}
log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})",
personId, correctedLines, shouldActivate, shouldRetrain);
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
if (runNow) {
triggerSenderTraining(personId, (int) correctedLines);
}
}
/**
* Inserts a QUEUED row when training is already running (coalescing: at most one QUEUED per
* person_id). Runs immediately if idle. Returns true if training should start now, false if
* queued or skipped.
*/
@Transactional
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
return false;
}
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.QUEUED)
.personId(personId)
.blockCount(blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
log.info("Queued sender training for person {} — training already running", personId);
return false;
}
return true;
}
/** Executes sender training synchronously. Caller must run this on a background thread. */
public void triggerSenderTraining(UUID personId, int correctedLines) {
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
return trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
}));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started sender training run {} for person {}", runId, personId);
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportForSender(personId).writeTo(baos);
byte[] zipBytes = baos.toByteArray();
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
txTemplate.execute(status -> {
SenderModel model = senderModelRepository.findByPersonId(personId)
.orElseGet(() -> SenderModel.builder().personId(personId).build());
model.setModelPath(outputModelPath);
model.setCer(result.cer());
model.setAccuracy(result.accuracy());
model.setCorrectedLinesAtTraining(correctedLines);
senderModelRepository.save(model);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
trainingRunRepository.save(run);
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
return null;
});
} catch (Exception e) {
txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
return null;
});
} finally {
MDC.remove("trainingRunId");
promoteNextQueuedRun();
}
}
/**
* Promotes the oldest QUEUED sender run to RUNNING and triggers its training.
* Called in the finally block of triggerSenderTraining, creating a sequential chain:
* each run promotes the next only after it fully completes (success or failure).
*/
private void promoteNextQueuedRun() {
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
.map(queued -> {
queued.setStatus(TrainingStatus.RUNNING);
return trainingRunRepository.save(queued);
}));
if (queuedOpt != null && queuedOpt.isPresent()) {
OcrTrainingRun promoted = queuedOpt.get();
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
}
}
}