fix(training): remove @Transactional from triggerTraining to avoid holding DB connection during OCR HTTP call
OcrTrainingService.triggerTraining() and triggerSegTraining() held a DB connection open for the entire ketos training run (potentially minutes), risking connection pool exhaustion. Replaced class-level @Transactional with TransactionTemplate for narrow DB writes: guard+create and result-record each run in their own short transaction; the HTTP call to the OCR service runs between them with no open connection. Also replaces blockRepository.findAll().size() with blockRepository.count() in getTrainingInfo() to avoid loading every block into heap on each poll. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,11 +13,13 @@ import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@@ -31,6 +33,7 @@ public class OcrTrainingService {
|
||||
private final OcrClient ocrClient;
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final TransactionTemplate txTemplate;
|
||||
|
||||
public record TrainingInfoResponse(
|
||||
int availableBlocks,
|
||||
@@ -46,37 +49,40 @@ public class OcrTrainingService {
|
||||
// Python OCR service after each run. The DB-level RUNNING constraint (V30 partial unique
|
||||
// index) prevents concurrent training API calls, but cannot prevent two OCR service replicas
|
||||
// from diverging on model state. Deploy as a single instance only. See ADR-001.
|
||||
@Transactional
|
||||
public OcrTrainingRun triggerTraining(UUID triggeredBy) {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
// Short transaction: guard check + create RUNNING row, then commit immediately.
|
||||
// The DB connection is released before the OCR HTTP call, which can take several minutes.
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
if (eligibleBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
|
||||
}
|
||||
var eligibleBlocks = trainingDataExportService.queryEligibleBlocks();
|
||||
if (eligibleBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible blocks are required to start training (found " + eligibleBlocks.size() + ")");
|
||||
}
|
||||
|
||||
long documentCount = eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
long documentCount = eligibleBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(eligibleBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("german_kurrent")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
run = trainingRunRepository.save(run);
|
||||
OcrTrainingRun newRun = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(eligibleBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("german_kurrent")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
return trainingRunRepository.save(newRun);
|
||||
}));
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started training run {} with {} blocks from {} documents",
|
||||
runId, eligibleBlocks.size(), documentCount);
|
||||
runId, run.getBlockCount(), run.getDocumentCount());
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
@@ -86,58 +92,64 @@ public class OcrTrainingService {
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.trainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
OcrTrainingRun updated = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Training completed — cer={} epochs={}", runId, result.cer(), result.epochs());
|
||||
return updated;
|
||||
}));
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
OcrTrainingRun failed = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Training failed: {}", runId, e.getMessage(), e);
|
||||
return failed;
|
||||
}));
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public OcrTrainingRun triggerSegTraining(UUID triggeredBy) {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
// Same pattern as triggerTraining: narrow transactions around DB writes only.
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
throw DomainException.conflict(ErrorCode.TRAINING_ALREADY_RUNNING,
|
||||
"A training run is already in progress");
|
||||
}
|
||||
|
||||
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||
if (segBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
|
||||
}
|
||||
var segBlocks = segmentationTrainingExportService.querySegmentationBlocks();
|
||||
if (segBlocks.size() < 5) {
|
||||
throw DomainException.badRequest(ErrorCode.VALIDATION_ERROR,
|
||||
"At least 5 eligible segments are required to start training (found " + segBlocks.size() + ")");
|
||||
}
|
||||
|
||||
long documentCount = segBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
long documentCount = segBlocks.stream()
|
||||
.map(b -> b.getDocumentId())
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
OcrTrainingRun run = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(segBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("blla")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
run = trainingRunRepository.save(run);
|
||||
OcrTrainingRun newRun = OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.blockCount(segBlocks.size())
|
||||
.documentCount((int) documentCount)
|
||||
.modelName("blla")
|
||||
.triggeredBy(triggeredBy)
|
||||
.build();
|
||||
return trainingRunRepository.save(newRun);
|
||||
}));
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started segmentation training run {} with {} segments from {} documents",
|
||||
runId, segBlocks.size(), documentCount);
|
||||
runId, run.getBlockCount(), run.getDocumentCount());
|
||||
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
@@ -147,24 +159,28 @@ public class OcrTrainingService {
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for segtrain", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.segtrainModel(zipBytes);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setLoss(result.loss());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
OcrTrainingRun updated = trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Segmentation training completed — epochs={}", runId, result.epochs());
|
||||
return updated;
|
||||
}));
|
||||
} catch (Exception e) {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
run = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||
return Objects.requireNonNull(txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
OcrTrainingRun failed = trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Segmentation training failed: {}", runId, e.getMessage(), e);
|
||||
return failed;
|
||||
}));
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
}
|
||||
|
||||
return run;
|
||||
}
|
||||
|
||||
public TrainingInfoResponse getTrainingInfo() {
|
||||
@@ -174,7 +190,7 @@ public class OcrTrainingService {
|
||||
.distinct()
|
||||
.count();
|
||||
|
||||
int totalOcrBlocks = blockRepository.findAll().size();
|
||||
int totalOcrBlocks = (int) blockRepository.count();
|
||||
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
|
||||
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop5ByOrderByCreatedAtDesc();
|
||||
|
||||
@@ -9,6 +9,8 @@ import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
@@ -29,6 +31,7 @@ class OcrTrainingServiceTest {
|
||||
OcrClient ocrClient;
|
||||
OcrHealthClient healthClient;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
TransactionTemplate txTemplate;
|
||||
OcrTrainingService service;
|
||||
|
||||
@BeforeEach
|
||||
@@ -39,10 +42,17 @@ class OcrTrainingServiceTest {
|
||||
ocrClient = mock(OcrClient.class);
|
||||
healthClient = mock(OcrHealthClient.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository);
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource
|
||||
when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||
TransactionCallback<?> callback = inv.getArgument(0);
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
when(blockRepository.findAll()).thenReturn(List.of());
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
|
||||
when(blockRepository.count()).thenReturn(0L);
|
||||
when(runRepository.findTop5ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user