feat(ocr): per-sender specialized Kurrent models with automatic active-learning retraining #263

Merged
marcel merged 32 commits from feat/issue-253-sender-models into main 2026-04-18 12:30:57 +02:00
42 changed files with 1495 additions and 136 deletions

View File

@@ -48,7 +48,7 @@ jobs:
path: frontend/test-results/screenshots/
# ─── OCR Service Unit Tests ───────────────────────────────────────────────────
# Only spell_check.py and test_confidence.py — no ML stack required.
# Only spell_check.py, test_confidence.py, test_sender_registry.py — no ML stack required.
ocr-tests:
name: OCR Service Tests
runs-on: ubuntu-latest
@@ -60,11 +60,11 @@ jobs:
python-version: '3.11'
- name: Install test dependencies
run: pip install "pyspellchecker==0.9.0" pytest
run: pip install "pyspellchecker==0.9.0" pytest pytest-asyncio
working-directory: ocr-service
- name: Run OCR unit tests (no ML stack required)
run: python -m pytest test_spell_check.py test_confidence.py -v
run: python -m pytest test_spell_check.py test_confidence.py test_sender_registry.py -v
working-directory: ocr-service
# ─── Backend Unit & Slice Tests ───────────────────────────────────────────────

View File

@@ -59,6 +59,9 @@ public class OcrTrainingRun {
@Column(name = "triggered_by")
private UUID triggeredBy;
@Column(name = "person_id")
private UUID personId;
@CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)

View File

@@ -0,0 +1,56 @@
package org.raddatz.familienarchiv.model;

import com.fasterxml.jackson.annotation.JsonIgnore;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.UpdateTimestamp;

import java.time.Instant;
import java.util.UUID;

/**
 * A fine-tuned, per-sender Kurrent OCR model. There is at most one row per
 * person (person_id is unique); retraining overwrites the existing row rather
 * than creating a new one.
 */
@Entity
@Table(name = "sender_models")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class SenderModel {

    @Id
    @GeneratedValue(strategy = GenerationType.UUID)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID id;

    // The sender this model specializes on. Unique: one model per person.
    @Column(name = "person_id", nullable = false, unique = true)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private UUID personId;

    // Filesystem path of the trained model (e.g. /app/models/sender_{uuid}.mlmodel).
    // @JsonIgnore keeps the internal path out of API responses.
    @JsonIgnore
    @Column(name = "model_path", nullable = false)
    private String modelPath;

    // Accuracy reported by the most recent training run; null until reported.
    @Column
    private Double accuracy;

    // Character error rate reported by the most recent training run; null until reported.
    @Column
    private Double cer;

    // Count of manually corrected lines at the time of the last training.
    // Used by the retrain-delta check to decide when enough new corrections
    // have accumulated to justify retraining.
    @Column(name = "corrected_lines_at_training", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private int correctedLinesAtTraining;

    @CreationTimestamp
    @Column(name = "created_at", nullable = false, updatable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private Instant createdAt;

    @UpdateTimestamp
    @Column(name = "updated_at", nullable = false)
    @Schema(requiredMode = Schema.RequiredMode.REQUIRED)
    private Instant updatedAt;
}

View File

@@ -1,6 +1,7 @@
package org.raddatz.familienarchiv.model;
public enum TrainingStatus {
QUEUED,
RUNNING,
DONE,
FAILED

View File

@@ -12,5 +12,11 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
List<OcrTrainingRun> findTop10ByOrderByCreatedAtDesc();
Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status);
Optional<OcrTrainingRun> findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status);
boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status);
List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc();
}

View File

@@ -0,0 +1,12 @@
package org.raddatz.familienarchiv.repository;

import org.raddatz.familienarchiv.model.SenderModel;
import org.springframework.data.jpa.repository.JpaRepository;

import java.util.Optional;
import java.util.UUID;

/**
 * Spring Data repository for {@link SenderModel} rows — the per-sender
 * fine-tuned Kurrent OCR models (at most one per person).
 */
public interface SenderModelRepository extends JpaRepository<SenderModel, UUID> {

    /**
     * Looks up the trained model for a sender. Empty when no model has been
     * trained yet; person_id is unique, so at most one row can match.
     */
    Optional<SenderModel> findByPersonId(UUID personId);
}

View File

@@ -3,6 +3,7 @@ package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import java.util.List;
import java.util.Optional;
@@ -37,4 +38,22 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels
""")
List<TranscriptionBlock> findSegmentationBlocks();
@Query("""
SELECT COUNT(b) FROM TranscriptionBlock b
JOIN Document d ON d.id = b.documentId
WHERE b.source = 'MANUAL'
AND d.sender.id = :personId
AND d.scriptType = 'HANDWRITING_KURRENT'
""")
long countManualKurrentBlocksByPerson(@Param("personId") UUID personId);
@Query("""
SELECT b FROM TranscriptionBlock b
JOIN Document d ON d.id = b.documentId
WHERE b.source = 'MANUAL'
AND d.sender.id = :personId
AND d.scriptType = 'HANDWRITING_KURRENT'
""")
List<TranscriptionBlock> findManualKurrentBlocksByPerson(@Param("personId") UUID personId);
}

View File

@@ -9,10 +9,12 @@ import org.raddatz.familienarchiv.repository.OcrJobRepository;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@Component
@RequiredArgsConstructor
@@ -29,6 +31,7 @@ public class OcrAsyncRunner {
private final OcrJobRepository ocrJobRepository;
private final OcrJobDocumentRepository ocrJobDocumentRepository;
private final OcrProgressService ocrProgressService;
private final SenderModelService senderModelService;
@Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
@@ -68,12 +71,18 @@ public class OcrAsyncRunner {
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
String senderModelPath = null;
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
}
AtomicInteger blockCounter = new AtomicInteger(0);
AtomicInteger currentPage = new AtomicInteger(0);
AtomicInteger skippedPages = new AtomicInteger(0);
AtomicInteger totalPages = new AtomicInteger(0);
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
final String finalSenderModelPath = senderModelPath;
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, finalSenderModelPath, event -> {
switch (event) {
case OcrStreamEvent.Start start -> {
totalPages.set(start.totalPages());
@@ -207,7 +216,25 @@ public class OcrAsyncRunner {
clearExistingBlocks(documentId);
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
List<OcrBlockResult> blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType());
String senderModelPath = null;
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
}
final AtomicReference<List<OcrBlockResult>> blocksRef = new AtomicReference<>();
final String finalSenderModelPath = senderModelPath;
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), null, finalSenderModelPath, event -> {
switch (event) {
case OcrStreamEvent.Page page -> {
blocksRef.compareAndSet(null, new ArrayList<>());
blocksRef.get().addAll(page.blocks());
}
default -> {}
}
});
List<OcrBlockResult> blocks = blocksRef.get() != null ? blocksRef.get() : List.of();
createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash());
}

View File

@@ -1,6 +1,7 @@
package org.raddatz.familienarchiv.service;
import org.raddatz.familienarchiv.model.ScriptType;
import org.springframework.lang.Nullable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
@@ -37,15 +38,27 @@ public interface OcrClient {
TrainingResult segtrainModel(byte[] trainingDataZip);
/**
* Stream OCR results page-by-page via NDJSON. Implementations should override
* this method. The default exists only for backward compatibility during migration
* — it calls extractBlocks() and synthesizes events from the collected result.
* Fine-tune the Kurrent model for a specific sender.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService.exportForSender()
* @param outputModelPath where to save the trained model (e.g. /app/models/sender_{uuid}.mlmodel)
* @return training result metrics
*/
TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath);
/**
* Stream OCR results page-by-page via NDJSON, optionally using a sender-specific model.
* The default implementation synthesizes events from extractBlocks() for backward compatibility.
* Implementations that support real streaming (e.g. RestClientOcrClient) override this.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
* @param senderModelPath optional path to a per-sender model file; null means use base model
*/
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
List<OcrRegion> regions,
@Nullable String senderModelPath,
Consumer<OcrStreamEvent> handler) {
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
@@ -62,4 +75,9 @@ public interface OcrClient {
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
}
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
streamBlocks(pdfUrl, scriptType, regions, null, handler);
}
}

View File

@@ -17,9 +17,11 @@ import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
@Service
@@ -34,6 +36,7 @@ public class OcrTrainingService {
private final OcrHealthClient ocrHealthClient;
private final TranscriptionBlockRepository blockRepository;
private final TransactionTemplate txTemplate;
private final PersonService personService;
public record TrainingInfoResponse(
int availableBlocks,
@@ -42,7 +45,8 @@ public class OcrTrainingService {
int availableSegBlocks,
boolean ocrServiceAvailable,
OcrTrainingRun lastRun,
List<OcrTrainingRun> runs
List<OcrTrainingRun> runs,
Map<String, String> personNames
) {}
private void assertNoRunningTraining() {
@@ -195,9 +199,20 @@ public class OcrTrainingService {
int totalOcrBlocks = (int) blockRepository.count();
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc();
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
List<UUID> distinctPersonIds = recentRuns.stream()
.map(OcrTrainingRun::getPersonId)
.filter(Objects::nonNull)
.distinct()
.toList();
Map<String, String> personNames = new HashMap<>();
if (!distinctPersonIds.isEmpty()) {
personService.getAllById(distinctPersonIds)
.forEach(p -> personNames.put(p.getId().toString(), p.getDisplayName()));
}
return new TrainingInfoResponse(
eligibleBlocks.size(),
totalOcrBlocks,
@@ -205,7 +220,8 @@ public class OcrTrainingService {
availableSegBlocks,
ocrHealthClient.isHealthy(),
lastRun,
recentRuns
recentRuns,
personNames
);
}
@@ -224,15 +240,4 @@ public class OcrTrainingService {
});
}
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
return Map.of(
"availableBlocks", info.availableBlocks(),
"totalOcrBlocks", info.totalOcrBlocks(),
"availableDocuments", info.availableDocuments(),
"availableSegBlocks", info.availableSegBlocks(),
"ocrServiceAvailable", info.ocrServiceAvailable(),
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
"runs", info.runs()
);
}
}

View File

@@ -14,6 +14,7 @@ import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@@ -102,6 +103,13 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.toList();
}
private RestClient.RequestBodySpec addTrainingAuth(RestClient.RequestBodySpec spec) {
if (trainingToken != null && !trainingToken.isBlank()) {
return spec.header("X-Training-Token", trainingToken);
}
return spec;
}
@Override
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@@ -114,15 +122,10 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post()
.uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA);
if (trainingToken != null && !trainingToken.isBlank()) {
spec = spec.header("X-Training-Token", trainingToken);
}
TrainingResultJson result = spec
TrainingResultJson result = addTrainingAuth(
trainingRestClient.post()
.uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body)
.retrieve()
.body(TrainingResultJson.class);
@@ -143,15 +146,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post()
.uri("/segtrain")
.contentType(MediaType.MULTIPART_FORM_DATA);
TrainingResultJson result = addTrainingAuth(
trainingRestClient.post()
.uri("/segtrain")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body)
.retrieve()
.body(TrainingResultJson.class);
if (trainingToken != null && !trainingToken.isBlank()) {
spec = spec.header("X-Training-Token", trainingToken);
}
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
}
TrainingResultJson result = spec
@Override
public OcrClient.TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@Override
public String getFilename() { return "sender-training-data.zip"; }
};
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
HttpHeaders partHeaders = new HttpHeaders();
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
body.add("output_model_path", outputModelPath);
TrainingResultJson result = addTrainingAuth(
trainingRestClient.post()
.uri("/train-sender")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body)
.retrieve()
.body(TrainingResultJson.class);
@@ -176,7 +199,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
@Override
public void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
List<OcrRegion> regions, @Nullable String senderModelPath,
Consumer<OcrStreamEvent> handler) {
String body;
try {
var requestMap = new java.util.LinkedHashMap<String, Object>();
@@ -186,6 +210,9 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (regions != null && !regions.isEmpty()) {
requestMap.put("regions", regions);
}
if (senderModelPath != null) {
requestMap.put("senderModelPath", senderModelPath);
}
body = NDJSON_MAPPER.writeValueAsString(requestMap);
} catch (IOException e) {
throw new RuntimeException("Failed to serialize OCR request", e);
@@ -204,7 +231,12 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (response.statusCode() == 404) {
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
handler.accept(new OcrStreamEvent.Start(0));
for (OcrBlockResult block : allBlocks) {
handler.accept(new OcrStreamEvent.Page(block.pageNumber(), List.of(block)));
}
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
return;
}

View File

@@ -0,0 +1,197 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class SenderModelService {
private final SenderModelRepository senderModelRepository;
private final TranscriptionBlockRepository blockRepository;
private final OcrTrainingRunRepository trainingRunRepository;
private final OcrClient ocrClient;
private final TransactionTemplate txTemplate;
private final TrainingDataExportService trainingDataExportService;
@Value("${ocr.sender-model.activation-threshold:100}")
private int activationThreshold;
@Value("${ocr.sender-model.retrain-delta:50}")
private int retrainDelta;
/** Returns the model path if a trained sender model exists for this person. */
public Optional<String> maybeGetModelPath(UUID personId) {
return senderModelRepository.findByPersonId(personId)
.map(SenderModel::getModelPath);
}
/**
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
*/
@Async
public void checkAndTriggerTraining(UUID personId) {
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
Optional<SenderModel> existing = senderModelRepository.findByPersonId(personId);
boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold;
boolean shouldRetrain = existing.isPresent()
&& (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta;
if (!shouldActivate && !shouldRetrain) {
return;
}
log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})",
personId, correctedLines, shouldActivate, shouldRetrain);
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
if (runNow) {
triggerSenderTraining(personId, (int) correctedLines);
}
}
/**
* Atomically checks the queue state and either creates a RUNNING row (returns true) or a
* QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and
* RUNNING row creation — happen in one transaction, eliminating the race window that would
* otherwise exist between the check and a separate RUNNING row creation.
*/
@Transactional
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
return false;
}
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.QUEUED)
.personId(personId)
.blockCount(blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
log.info("Queued sender training for person {} — training already running", personId);
return false;
}
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
return true;
}
/**
* Executes sender training synchronously. Caller must run this on a background thread.
* The RUNNING row is expected to already exist — created atomically by
* runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs).
*/
public void triggerSenderTraining(UUID personId, int correctedLines) {
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status ->
trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)
.orElseThrow(() -> new IllegalStateException(
"Expected RUNNING row for person " + personId + " but none found"))));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started sender training run {} for person {}", runId, personId);
try {
byte[] zipBytes = exportSenderData(personId);
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
txTemplate.execute(status -> {
SenderModel model = senderModelRepository.findByPersonId(personId)
.orElseGet(() -> SenderModel.builder().personId(personId).build());
model.setModelPath(outputModelPath);
model.setCer(result.cer());
model.setAccuracy(result.accuracy());
model.setCorrectedLinesAtTraining(correctedLines);
senderModelRepository.save(model);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
trainingRunRepository.save(run);
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
return null;
});
} catch (Exception e) {
txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
return null;
});
} finally {
MDC.remove("trainingRunId");
promoteNextQueuedRun();
}
}
private byte[] exportSenderData(UUID personId) throws java.io.IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportForSender(personId).writeTo(baos);
return baos.toByteArray();
}
/**
* Promotes the oldest QUEUED sender run to RUNNING and triggers its training.
* Called in the finally block of triggerSenderTraining, creating a sequential chain:
* each run promotes the next only after it fully completes (success or failure).
*
* This is intentionally tail-recursive via the @Async thread: the same thread holds the
* full queue drain, serialising all sender training runs naturally without an external
* scheduler. With N queued runs the thread stays occupied for N sequential trainings —
* acceptable because the @Async executor is dedicated to long-running background work.
*/
private void promoteNextQueuedRun() {
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
.map(queued -> {
queued.setStatus(TrainingStatus.RUNNING);
return trainingRunRepository.save(queued);
}));
if (queuedOpt != null && queuedOpt.isPresent()) {
OcrTrainingRun promoted = queuedOpt.get();
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
}
}
}

View File

@@ -38,10 +38,20 @@ public class TrainingDataExportService {
}
public StreamingResponseBody exportToZip() {
// Collect all data before entering the lambda — no open DB txn during streaming
List<TranscriptionBlock> blocks = queryEligibleBlocks();
return exportBlocksToZip(queryEligibleBlocks());
}
public List<TranscriptionBlock> queryBlocksForSender(UUID personId) {
return blockRepository.findManualKurrentBlocksByPerson(personId);
}
public StreamingResponseBody exportForSender(UUID personId) {
return exportBlocksToZip(queryBlocksForSender(personId));
}
private StreamingResponseBody exportBlocksToZip(List<TranscriptionBlock> blocks) {
if (blocks.isEmpty()) {
return out -> {}; // caller checks isEmpty() for 204 response
return out -> {};
}
// Group blocks by documentId so we only download each PDF once

View File

@@ -11,6 +11,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
@@ -35,6 +36,7 @@ public class TranscriptionService {
private final AnnotationRepository annotationRepository;
private final AnnotationService annotationService;
private final DocumentService documentService;
private final SenderModelService senderModelService;
public List<TranscriptionBlock> listBlocks(UUID documentId) {
return blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId);
@@ -122,6 +124,7 @@ public class TranscriptionService {
String text = sanitizeText(dto.getText());
block.setText(text);
block.setSource(BlockSource.MANUAL);
if (dto.getLabel() != null) {
block.setLabel(dto.getLabel());
}
@@ -129,6 +132,12 @@ public class TranscriptionService {
TranscriptionBlock saved = blockRepository.save(block);
saveVersion(saved, userId);
Document doc = documentService.getDocumentById(documentId);
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelService.checkAndTriggerTraining(doc.getSender().getId());
}
return saved;
}

View File

@@ -77,3 +77,8 @@ app:
tags: 10
summary: 11
transcription: 13
ocr:
sender-model:
activation-threshold: 100
retrain-delta: 50

View File

@@ -0,0 +1,10 @@
-- Per-sender fine-tuned Kurrent OCR models (maps to the SenderModel entity).
-- At most one model per sender: person_id is UNIQUE, and retraining updates
-- the existing row instead of inserting a new one.
CREATE TABLE sender_models (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    -- Deleting the person also discards their specialized model row.
    person_id UUID NOT NULL UNIQUE REFERENCES persons(id) ON DELETE CASCADE,
    model_path TEXT NOT NULL,
    -- Metrics from the most recent training run; NULL until reported.
    accuracy DOUBLE PRECISION,
    cer DOUBLE PRECISION,
    -- Corrected-line count at last training; drives the retrain-delta check.
    corrected_lines_at_training INT NOT NULL DEFAULT 0,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

View File

@@ -0,0 +1,2 @@
-- Attach training runs to a sender for per-sender model training.
-- Nullable: global (non-sender) training runs keep person_id = NULL.
-- ON DELETE SET NULL preserves the run history when a person is deleted.
ALTER TABLE ocr_training_runs
ADD COLUMN person_id UUID REFERENCES persons(id) ON DELETE SET NULL;

View File

@@ -0,0 +1,3 @@
-- DB-level guard against duplicate QUEUED sender-training rows per person,
-- backing the application-side existsByPersonIdAndStatus check against races.
-- Partial index: only rows WHERE status = 'QUEUED' are constrained, so any
-- number of DONE/FAILED (and RUNNING) rows per person remains allowed.
CREATE UNIQUE INDEX idx_training_runs_queued_per_person
ON ocr_training_runs(person_id)
WHERE status = 'QUEUED';

View File

@@ -23,6 +23,8 @@ import org.springframework.test.web.servlet.MockMvc;
import java.util.List;
import java.util.UUID;
import java.util.Map;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq;
@@ -219,7 +221,7 @@ class OcrControllerTest {
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_withInfo() throws Exception {
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of());
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
@@ -228,6 +230,52 @@ class OcrControllerTest {
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_and_omits_personName_when_service_provides_empty_map() throws Exception {
UUID personId = UUID.randomUUID();
OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null,
List.of(runWithPerson), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.personNames").isEmpty());
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_includesPersonName_whenPersonIdResolves() throws Exception {
UUID personId = UUID.randomUUID();
OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null,
List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann"));
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.personNames." + personId).value("Max Mustermann"));
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_serializes_null_lastRun_as_json_null() throws Exception {
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.lastRun").doesNotExist());
}
@Test
@WithMockUser(authorities = "READ_ALL")
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {

View File

@@ -217,8 +217,71 @@ class MigrationIntegrationTest {
assertThat(stored).isEqualTo("sage");
}
// ─── V42: idx_training_runs_queued_per_person partial unique index ────────
@Test
@Transactional(propagation = Propagation.NOT_SUPPORTED)
void v42_partialUniqueIndex_preventsTwoQueuedRunsForSamePerson() {
UUID personId = createPerson("V42Test", "SamePerson");
jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_x', ?)
""", personId);
assertThatThrownBy(() ->
jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_x', ?)
""", personId)
).isInstanceOf(DataIntegrityViolationException.class);
jdbc.update("DELETE FROM ocr_training_runs WHERE person_id = ?", personId);
jdbc.update("DELETE FROM persons WHERE id = ?", personId);
}
@Test
void v42_partialUniqueIndex_allowsQueuedRunsForDifferentPersons() {
    // The partial unique index keys on person_id, so two different senders may
    // each have one QUEUED training run at the same time.
    UUID firstPerson = createPerson("V42Test", "PersonA");
    UUID secondPerson = createPerson("V42Test", "PersonB");

    int insertedFirst = jdbc.update("""
        INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
        VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_a', ?)
        """, firstPerson);
    int insertedSecond = jdbc.update("""
        INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
        VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_b', ?)
        """, secondPerson);

    assertThat(insertedFirst).isEqualTo(1);
    assertThat(insertedSecond).isEqualTo(1);
}
@Test
void v42_partialUniqueIndex_allowsMultipleDoneRunsForSamePerson() {
    // The index is partial on status = 'QUEUED'; completed runs are unrestricted,
    // so one person may accumulate any number of DONE rows over time.
    UUID donePerson = createPerson("V42Test", "DonePerson");

    int firstInsert = jdbc.update("""
        INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
        VALUES (gen_random_uuid(), 'DONE', 10, 2, 'sender_x', ?)
        """, donePerson);
    int secondInsert = jdbc.update("""
        INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
        VALUES (gen_random_uuid(), 'DONE', 15, 3, 'sender_x', ?)
        """, donePerson);

    assertThat(firstInsert).isEqualTo(1);
    assertThat(secondInsert).isEqualTo(1);
}
// ─── helpers ─────────────────────────────────────────────────────────────
/** Inserts a person row directly via JDBC and returns its generated id. */
private UUID createPerson(String firstName, String lastName) {
    final UUID personId = UUID.randomUUID();
    jdbc.update("INSERT INTO persons (id, first_name, last_name) VALUES (?, ?, ?)", personId, firstName, lastName);
    return personId;
}
private UUID createTag(String name) {
UUID id = UUID.randomUUID();
jdbc.update("INSERT INTO tag (id, name) VALUES (?, ?)", id, name);

View File

@@ -31,6 +31,7 @@ class OcrAsyncRunnerTest {
@Mock OcrJobRepository ocrJobRepository;
@Mock OcrJobDocumentRepository ocrJobDocumentRepository;
@Mock OcrProgressService ocrProgressService;
@Mock SenderModelService senderModelService;
@InjectMocks OcrAsyncRunner ocrAsyncRunner;
@@ -42,7 +43,12 @@ class OcrAsyncRunnerTest {
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of());
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(0));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.processDocument(docId, doc, userId);
@@ -59,9 +65,15 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null))));
handler.accept(new OcrStreamEvent.Done(2, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -83,8 +95,14 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null))));
handler.accept(new OcrStreamEvent.Done(1, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -112,12 +130,12 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -142,7 +160,7 @@ class OcrAsyncRunnerTest {
when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -174,7 +192,7 @@ class OcrAsyncRunnerTest {
List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
@@ -185,7 +203,7 @@ class OcrAsyncRunnerTest {
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Done(3, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -215,14 +233,14 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Error(1, "failed"));
handler.accept(new OcrStreamEvent.Page(2, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -251,14 +269,14 @@ class OcrAsyncRunnerTest {
List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(5));
handler.accept(new OcrStreamEvent.Preprocessing(1));
progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -287,13 +305,13 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(2));
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

View File

@@ -21,6 +21,7 @@ class OcrClientDefaultStreamTest {
}
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();
@@ -52,6 +53,7 @@ class OcrClientDefaultStreamTest {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
@Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
};
List<OcrStreamEvent> events = new ArrayList<>();

View File

@@ -4,11 +4,12 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingLabel;
import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.raddatz.familienarchiv.service.PersonService;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
@@ -32,6 +33,7 @@ class OcrTrainingServiceTest {
OcrHealthClient healthClient;
TranscriptionBlockRepository blockRepository;
TransactionTemplate txTemplate;
PersonService personService;
OcrTrainingService service;
@BeforeEach
@@ -43,6 +45,7 @@ class OcrTrainingServiceTest {
healthClient = mock(OcrHealthClient.class);
blockRepository = mock(TranscriptionBlockRepository.class);
txTemplate = mock(TransactionTemplate.class);
personService = mock(PersonService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource
when(txTemplate.execute(any())).thenAnswer(inv -> {
@@ -50,10 +53,10 @@ class OcrTrainingServiceTest {
return callback.doInTransaction(null);
});
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService);
when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
}
@@ -230,6 +233,37 @@ class OcrTrainingServiceTest {
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
}
// ─── getTrainingInfo: batch person name resolution ────────────────────────
@Test
void getTrainingInfo_resolves_person_names_in_single_batch_call() {
    // Three runs across two distinct senders: name resolution must collapse into
    // a single batched getAllById call with the de-duplicated id set — never one
    // getById call per run.
    UUID annaId = UUID.randomUUID();
    UUID karlId = UUID.randomUUID();
    List<OcrTrainingRun> recentRuns = List.of(
            OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
                    .personId(annaId).blockCount(5).documentCount(1).modelName("sender_a").build(),
            OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
                    .personId(karlId).blockCount(5).documentCount(1).modelName("sender_b").build(),
            OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
                    .personId(annaId).blockCount(5).documentCount(1).modelName("sender_a").build()
    );
    when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(recentRuns);
    when(exportService.queryEligibleBlocks()).thenReturn(List.of());
    Person anna = Person.builder().id(annaId).firstName("Anna").lastName("Müller").build();
    Person karl = Person.builder().id(karlId).firstName("Karl").lastName("Bauer").build();
    when(personService.getAllById(argThat(ids -> ids.size() == 2 && ids.containsAll(List.of(annaId, karlId)))))
            .thenReturn(List.of(anna, karl));
    when(healthClient.isHealthy()).thenReturn(true);

    OcrTrainingService.TrainingInfoResponse info = service.getTrainingInfo();

    verify(personService, never()).getById(any());
    verify(personService, times(1)).getAllById(any());
    assertThat(info.personNames()).containsKey(annaId.toString());
    assertThat(info.personNames()).containsKey(karlId.toString());
}
// ─── Orphan recovery ──────────────────────────────────────────────────────
@Test

View File

@@ -0,0 +1,41 @@
package org.raddatz.familienarchiv.service;
import com.sun.net.httpserver.HttpServer;
import org.junit.jupiter.api.Test;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import static org.assertj.core.api.Assertions.assertThat;
class RestClientOcrClientTrainingTest {
@Test
void trainSenderModel_postsToTrainSenderEndpoint() throws Exception {
List<String> capturedPaths = new CopyOnWriteArrayList<>();
HttpServer server = HttpServer.create(new InetSocketAddress(0), 0);
server.createContext("/", exchange -> {
capturedPaths.add(exchange.getRequestURI().getPath());
byte[] response = "{\"loss\":0.1,\"accuracy\":0.95,\"cer\":0.05,\"epochs\":3}"
.getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().set("Content-Type", "application/json");
exchange.sendResponseHeaders(200, response.length);
try (OutputStream os = exchange.getResponseBody()) {
os.write(response);
}
});
server.start();
try {
int port = server.getAddress().getPort();
RestClientOcrClient client = new RestClientOcrClient("http://localhost:" + port, "");
client.trainSenderModel(new byte[]{1, 2, 3}, "/app/models/sender_test.mlmodel");
assertThat(capturedPaths).contains("/train-sender");
} finally {
server.stop(0);
}
}
}

View File

@@ -0,0 +1,302 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
// Unit tests for SenderModelService: per-sender model activation thresholds,
// retrain deltas, the run-or-queue decision, and the train → complete → promote
// queue lifecycle. All collaborators are Mockito mocks; no Spring context.
@ExtendWith(MockitoExtension.class)
class SenderModelServiceTest {

    SenderModelRepository senderModelRepository;
    TranscriptionBlockRepository blockRepository;
    OcrTrainingRunRepository trainingRunRepository;
    OcrClient ocrClient;
    TransactionTemplate txTemplate;
    TrainingDataExportService trainingDataExportService;
    SenderModelService service;

    // Shared sender id reused across tests; each test re-stubs what it needs.
    UUID personId = UUID.randomUUID();

    @BeforeEach
    void setUp() {
        senderModelRepository = mock(SenderModelRepository.class);
        blockRepository = mock(TranscriptionBlockRepository.class);
        trainingRunRepository = mock(OcrTrainingRunRepository.class);
        ocrClient = mock(OcrClient.class);
        txTemplate = mock(TransactionTemplate.class);
        trainingDataExportService = mock(TrainingDataExportService.class);
        // Execute transaction callbacks inline so unit tests run without a real DataSource.
        // lenient: not every test hits the txTemplate path, but the setup is shared.
        lenient().when(txTemplate.execute(any())).thenAnswer(inv -> {
            TransactionCallback<?> callback = inv.getArgument(0);
            return callback.doInTransaction(null);
        });
        service = new SenderModelService(senderModelRepository, blockRepository,
                trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
        // @Value-injected thresholds have no Spring context here; set them reflectively.
        ReflectionTestUtils.setField(service, "activationThreshold", 100);
        ReflectionTestUtils.setField(service, "retrainDelta", 50);
    }

    // ─── Activation threshold ─────────────────────────────────────────────────

    @Test
    void checkAndTriggerTraining_doesNothing_belowActivationThreshold() {
        // 99 corrected lines < activationThreshold (100) and no existing model → no training.
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
        SenderModelService spy = spy(service);
        spy.checkAndTriggerTraining(personId);
        verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
    }

    @Test
    void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
        // Exactly at the threshold the first training is kicked off (inclusive boundary).
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
        SenderModelService spy = spy(service);
        doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
        spy.checkAndTriggerTraining(personId);
        verify(spy).runOrQueueSenderTraining(personId, 100);
    }

    // ─── Retrain delta ────────────────────────────────────────────────────────

    @Test
    void checkAndTriggerTraining_doesNothing_belowRetrainDelta() {
        // Model exists at 100 lines; 149 current lines is 49 new — below retrainDelta (50).
        SenderModel existing = SenderModel.builder().personId(personId)
                .correctedLinesAtTraining(100).build();
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
        SenderModelService spy = spy(service);
        spy.checkAndTriggerTraining(personId);
        verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
    }

    @Test
    void checkAndTriggerTraining_triggersRetrain_atRetrainDelta() {
        // 150 - 100 = 50 new corrected lines meets the delta exactly → retrain.
        SenderModel existing = SenderModel.builder().personId(personId)
                .correctedLinesAtTraining(100).build();
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
        SenderModelService spy = spy(service);
        doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150);
        spy.checkAndTriggerTraining(personId);
        verify(spy).runOrQueueSenderTraining(personId, 150);
    }

    // ─── runNow flag ──────────────────────────────────────────────────────────

    @Test
    void checkAndTriggerTraining_callsTrigger_whenRunNow() {
        // runOrQueueSenderTraining returning true means "idle, run now" →
        // checkAndTriggerTraining must follow up with triggerSenderTraining.
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
        SenderModelService spy = spy(service);
        doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100);
        doNothing().when(spy).triggerSenderTraining(personId, 100);
        spy.checkAndTriggerTraining(personId);
        verify(spy).triggerSenderTraining(personId, 100);
    }

    @Test
    void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
        // false means the run was queued behind another training → no immediate trigger.
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
        SenderModelService spy = spy(service);
        doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
        spy.checkAndTriggerTraining(personId);
        verify(spy, never()).triggerSenderTraining(any(), anyInt());
    }

    // ─── runOrQueueSenderTraining ─────────────────────────────────────────────

    @Test
    void runOrQueueSenderTraining_returnsTrue_whenIdle() {
        // No queued run for this person and nothing RUNNING globally → caller may run now.
        when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
        when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
        boolean result = service.runOrQueueSenderTraining(personId, 120);
        assertThat(result).isTrue();
        verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
    }

    @Test
    void runOrQueueSenderTraining_savesQueued_whenRunning() {
        // Another training is RUNNING → persist a QUEUED row for this person and return false.
        when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
        when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
                Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
                        .blockCount(5).documentCount(1).modelName("german_kurrent").build()));
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
        when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        boolean result = service.runOrQueueSenderTraining(personId, 120);
        assertThat(result).isFalse();
        verify(trainingRunRepository).save(argThat(r ->
                r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
    }

    @Test
    void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
        // A QUEUED run already exists for this person → de-duplicate: save nothing.
        when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
        boolean result = service.runOrQueueSenderTraining(personId, 120);
        assertThat(result).isFalse();
        verify(trainingRunRepository, never()).save(any());
    }

    @Test
    void runOrQueueSenderTraining_createsRunningRow_atomically() {
        // RUNNING row must be created inside the same @Transactional boundary as the idle check,
        // eliminating the race window between the check and a separate triggerSenderTraining call.
        when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
        when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
        when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
        when(trainingRunRepository.save(any())).thenAnswer(inv -> {
            OcrTrainingRun r = inv.getArgument(0);
            if (r.getId() == null) r.setId(UUID.randomUUID());
            return r;
        });
        boolean result = service.runOrQueueSenderTraining(personId, 120);
        assertThat(result).isTrue();
        verify(trainingRunRepository).save(argThat(r ->
                r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId())));
    }

    // ─── triggerSenderTraining — happy path ──────────────────────────────────

    @Test
    void triggerSenderTraining_savesModelRecord_onSuccess() throws Exception {
        // Pre-existing RUNNING row (created by runOrQueueSenderTraining) is picked up here.
        OcrTrainingRun runningRun = OcrTrainingRun.builder()
                .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
                .personId(personId).blockCount(10).documentCount(0)
                .modelName("sender_" + personId).build();
        when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
                .thenReturn(Optional.of(runningRun));
        when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
        when(ocrClient.trainSenderModel(any(), any()))
                .thenReturn(new OcrClient.TrainingResult(0.08, 0.92, 0.08, 5));
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
        when(senderModelRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
                .thenReturn(Optional.empty());
        service.triggerSenderTraining(personId, 120);
        // SenderModel record carries the model path and the CER reported by the OCR service.
        verify(senderModelRepository).save(argThat(m ->
                ("/app/models/sender_" + personId + ".mlmodel").equals(m.getModelPath())
                        && Double.valueOf(0.08).equals(m.getCer())
                        && personId.equals(m.getPersonId())));
        verify(trainingRunRepository, atLeastOnce()).save(argThat(r ->
                r.getStatus() == TrainingStatus.DONE
                        && Double.valueOf(0.08).equals(r.getCer())
                        && r.getCompletedAt() != null));
    }

    // ─── triggerSenderTraining — failure path ────────────────────────────────

    @Test
    void triggerSenderTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
        OcrTrainingRun runningRun = OcrTrainingRun.builder()
                .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
                .personId(personId).blockCount(10).documentCount(0)
                .modelName("sender_" + personId).build();
        when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
                .thenReturn(Optional.of(runningRun));
        when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
        when(ocrClient.trainSenderModel(any(), any()))
                .thenThrow(new RuntimeException("OCR service timeout"));
        when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
        when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
                .thenReturn(Optional.empty());
        service.triggerSenderTraining(personId, 120);
        // Run ends FAILED with an error message and completion timestamp;
        // no SenderModel row may be written for a failed training.
        verify(trainingRunRepository, atLeastOnce()).save(argThat(r ->
                r.getStatus() == TrainingStatus.FAILED
                        && r.getErrorMessage() != null
                        && r.getCompletedAt() != null));
        verify(senderModelRepository, never()).save(any());
    }

    // ─── triggerSenderTraining — queue promotion ──────────────────────────────

    @Test
    void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
        UUID nextPersonId = UUID.randomUUID();
        OcrTrainingRun runningRun = OcrTrainingRun.builder()
                .id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
                .personId(personId).blockCount(10).documentCount(0)
                .modelName("sender_" + personId).build();
        when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
                .thenReturn(Optional.of(runningRun));
        when(trainingRunRepository.save(any())).thenAnswer(inv -> {
            OcrTrainingRun r = inv.getArgument(0);
            if (r.getId() == null) r.setId(UUID.randomUUID());
            return r;
        });
        when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
        when(ocrClient.trainSenderModel(any(), any()))
                .thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
        when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
        // Oldest QUEUED run (FIFO by createdAt) belongs to a different sender.
        OcrTrainingRun queued = OcrTrainingRun.builder()
                .id(UUID.randomUUID()).status(TrainingStatus.QUEUED)
                .personId(nextPersonId).blockCount(5).documentCount(0)
                .modelName("sender_" + nextPersonId).build();
        when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
                .thenReturn(Optional.of(queued));
        when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L);
        SenderModelService spy = spy(service);
        // Stub the recursive call to stop the chain after one promotion
        doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
        spy.triggerSenderTraining(personId, 10);
        // Queued run promoted to RUNNING
        verify(trainingRunRepository).save(argThat(r ->
                nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING));
        // Next training was kicked off
        verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
    }
}

View File

@@ -22,6 +22,7 @@ class TranscriptionServiceGuidedTest {
AnnotationRepository annotationRepository;
AnnotationService annotationService;
DocumentService documentService;
SenderModelService senderModelService;
TranscriptionService service;
UUID docId = UUID.randomUUID();
@@ -35,9 +36,10 @@ class TranscriptionServiceGuidedTest {
annotationRepository = mock(AnnotationRepository.class);
annotationService = mock(AnnotationService.class);
documentService = mock(DocumentService.class);
senderModelService = mock(SenderModelService.class);
service = new TranscriptionService(blockRepository, versionRepository,
annotationRepository, annotationService, documentService);
annotationRepository, annotationService, documentService, senderModelService);
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));

View File

@@ -13,6 +13,8 @@ import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
import org.raddatz.familienarchiv.repository.AnnotationRepository;
@@ -39,6 +41,7 @@ class TranscriptionServiceTest {
@Mock AnnotationRepository annotationRepository;
@Mock AnnotationService annotationService;
@Mock DocumentService documentService;
@Mock SenderModelService senderModelService;
@InjectMocks TranscriptionService transcriptionService;
// ─── getBlock ────────────────────────────────────────────────────────────────
@@ -156,6 +159,8 @@ class TranscriptionServiceTest {
.id(blockId).documentId(docId).text("old").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("new text", null);
@@ -175,6 +180,8 @@ class TranscriptionServiceTest {
.id(blockId).documentId(docId).text("text").label("old label").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("text", "Anrede");
@@ -183,6 +190,57 @@ class TranscriptionServiceTest {
assertThat(result.getLabel()).isEqualTo("Anrede");
}
@Test
void updateBlock_setsSourceToManual() {
    // A human edit reclassifies an OCR-produced block as MANUAL.
    UUID documentId = UUID.randomUUID();
    UUID targetBlockId = UUID.randomUUID();
    TranscriptionBlock ocrBlock = TranscriptionBlock.builder()
            .id(targetBlockId).documentId(documentId).text("old").source(BlockSource.OCR).build();
    when(blockRepository.findByIdAndDocumentId(targetBlockId, documentId)).thenReturn(Optional.of(ocrBlock));
    when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    when(documentService.getDocumentById(any())).thenReturn(
            Document.builder().scriptType(ScriptType.TYPEWRITER).build());

    TranscriptionBlock updated = transcriptionService.updateBlock(
            documentId, targetBlockId, new UpdateTranscriptionBlockDTO("new", null), UUID.randomUUID());

    assertThat(updated.getSource()).isEqualTo(BlockSource.MANUAL);
}
@Test
void updateBlock_triggersTraining_whenKurrentSenderPresent() {
    // Kurrent document with a known sender: the manual correction must kick off
    // the sender-specific active-learning check.
    UUID documentId = UUID.randomUUID();
    UUID targetBlockId = UUID.randomUUID();
    UUID kurrentSenderId = UUID.randomUUID();
    Person kurrentSender = Person.builder().id(kurrentSenderId).build();
    TranscriptionBlock existingBlock = TranscriptionBlock.builder()
            .id(targetBlockId).documentId(documentId).text("old").build();
    when(blockRepository.findByIdAndDocumentId(targetBlockId, documentId)).thenReturn(Optional.of(existingBlock));
    when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    when(documentService.getDocumentById(any())).thenReturn(
            Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).sender(kurrentSender).build());

    transcriptionService.updateBlock(
            documentId, targetBlockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());

    verify(senderModelService).checkAndTriggerTraining(kurrentSenderId);
}
@Test
void updateBlock_doesNotTriggerTraining_whenNoSender() {
    // Kurrent document without a sender: there is no per-sender model to train.
    UUID documentId = UUID.randomUUID();
    UUID targetBlockId = UUID.randomUUID();
    TranscriptionBlock existingBlock = TranscriptionBlock.builder()
            .id(targetBlockId).documentId(documentId).text("old").build();
    when(blockRepository.findByIdAndDocumentId(targetBlockId, documentId)).thenReturn(Optional.of(existingBlock));
    when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
    when(documentService.getDocumentById(any())).thenReturn(
            Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).build());

    transcriptionService.updateBlock(
            documentId, targetBlockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());

    verify(senderModelService, never()).checkAndTriggerTraining(any());
}
// ─── deleteBlock ─────────────────────────────────────────────────────────────
@Test

View File

@@ -97,6 +97,7 @@ services:
DETECTOR_BATCH_SIZE: "8"
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
OCR_MAX_CACHED_MODELS: "2" # LRU cache; each model ~500 MB, so 2 = ~1 GB resident
networks:
- archive-net
healthcheck:

View File

@@ -330,6 +330,7 @@
"comment_time_days": "vor {count} Tag(en)",
"comment_panel_title": "Kommentare",
"comment_panel_close": "Schließen",
"comp_dismiss": "Schließen",
"doc_panel_tab_metadata": "Metadaten",
"doc_panel_tab_transcription": "Transkription",
"doc_panel_tab_discussion": "Diskussion",
@@ -541,6 +542,7 @@
"training_start_btn": "Training starten",
"training_in_progress": "…",
"training_success": "Training wurde gestartet und abgeschlossen.",
"training_start_failed": "Training konnte nicht gestartet werden. Bitte versuche es erneut.",
"training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).",
"training_service_down": "OCR-Dienst ist nicht erreichbar.",
"training_history_heading": "Verlauf",
@@ -561,6 +563,11 @@
"transcription_block_segmentation_only": "Nur Segmentierung",
"training_chip_kurrent": "Kurrent-Erkennung",
"training_chip_segmentation": "Segmentierung",
"training_col_type": "Typ",
"training_type_base": "Basis",
"training_type_personalized": "Personalisiert",
"training_col_person": "Absender",
"training_status_queued": "Warteschlange",
"mission_control_heading": "Was braucht Aufmerksamkeit?",
"mission_control_segmentation_heading": "Text markieren",
"mission_control_segmentation_description": "Textbereiche markieren — keine Vorkenntnisse nötig",

View File

@@ -330,6 +330,7 @@
"comment_time_days": "{count} day(s) ago",
"comment_panel_title": "Comments",
"comment_panel_close": "Close",
"comp_dismiss": "Dismiss",
"doc_panel_tab_metadata": "Metadata",
"doc_panel_tab_transcription": "Transcription",
"doc_panel_tab_discussion": "Discussion",
@@ -541,6 +542,7 @@
"training_start_btn": "Start training",
"training_in_progress": "…",
"training_success": "Training started and completed.",
"training_start_failed": "Failed to start training. Please try again.",
"training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).",
"training_service_down": "OCR service is unavailable.",
"training_history_heading": "History",
@@ -561,6 +563,11 @@
"transcription_block_segmentation_only": "Segmentation only",
"training_chip_kurrent": "Kurrent recognition",
"training_chip_segmentation": "Segmentation",
"training_col_type": "Type",
"training_type_base": "Base",
"training_type_personalized": "Personalized",
"training_col_person": "Sender",
"training_status_queued": "Queued",
"mission_control_heading": "What needs attention?",
"mission_control_segmentation_heading": "Mark text",
"mission_control_segmentation_description": "Mark text areas — no prior knowledge needed",

View File

@@ -330,6 +330,7 @@
"comment_time_days": "hace {count} día(s)",
"comment_panel_title": "Comentarios",
"comment_panel_close": "Cerrar",
"comp_dismiss": "Cerrar",
"doc_panel_tab_metadata": "Metadatos",
"doc_panel_tab_transcription": "Transcripción",
"doc_panel_tab_discussion": "Discusión",
@@ -541,6 +542,7 @@
"training_start_btn": "Iniciar entrenamiento",
"training_in_progress": "…",
"training_success": "Entrenamiento iniciado y completado.",
"training_start_failed": "No se pudo iniciar el entrenamiento. Por favor, inténtalo de nuevo.",
"training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).",
"training_service_down": "El servicio OCR no está disponible.",
"training_history_heading": "Historial",
@@ -561,6 +563,11 @@
"transcription_block_segmentation_only": "Solo segmentación",
"training_chip_kurrent": "Reconocimiento Kurrent",
"training_chip_segmentation": "Segmentación",
"training_col_type": "Tipo",
"training_type_base": "Base",
"training_type_personalized": "Personalizado",
"training_col_person": "Remitente",
"training_status_queued": "En cola",
"mission_control_heading": "¿Qué necesita atención?",
"mission_control_segmentation_heading": "Marcar texto",
"mission_control_segmentation_description": "Marcar áreas de texto — sin conocimientos previos",

View File

@@ -1,26 +1,16 @@
<script lang="ts">
import TrainingHistory from './TrainingHistory.svelte';
import { m } from '$lib/paraglide/messages.js';
interface Run {
id: string;
status: 'RUNNING' | 'DONE' | 'FAILED';
blockCount: number;
documentCount: number;
modelName: string;
errorMessage?: string;
triggeredBy?: string;
createdAt: string;
completedAt?: string;
}
import type { TrainingRun } from '$lib/types/training.js';
interface TrainingInfo {
availableBlocks?: number;
totalOcrBlocks?: number;
availableDocuments?: number;
ocrServiceAvailable?: boolean;
lastRun?: Run | null;
runs?: Run[];
lastRun?: TrainingRun | null;
runs?: TrainingRun[];
personNames?: Record<string, string>;
}
interface Props {
@@ -31,6 +21,7 @@ let { trainingInfo }: Props = $props();
let training = $state(false);
let successMessage = $state<string | null>(null);
let errorMessage = $state<string | null>(null);
const available = $derived(trainingInfo?.availableBlocks ?? 0);
const tooFewBlocks = $derived(available < 5);
@@ -40,6 +31,7 @@ const disabled = $derived(training || tooFewBlocks || serviceDown);
async function startTraining() {
training = true;
successMessage = null;
errorMessage = null;
try {
const res = await fetch('/api/ocr/train', { method: 'POST' });
if (res.ok) {
@@ -47,7 +39,11 @@ async function startTraining() {
setTimeout(() => {
successMessage = null;
}, 5000);
} else {
errorMessage = m.training_start_failed();
}
} catch {
errorMessage = m.training_start_failed();
} finally {
training = false;
}
@@ -82,11 +78,26 @@ async function startTraining() {
{/if}
{#if successMessage}
<p class="mt-2 text-xs text-green-700">{successMessage}</p>
<p class="mt-2 flex items-center gap-2 text-xs text-green-700" aria-live="polite">
{successMessage}
<button
type="button"
class="-my-2 ml-1 inline-flex h-11 w-11 items-center justify-center rounded-sm underline hover:no-underline focus-visible:ring-2 focus-visible:ring-brand-navy focus-visible:outline-none"
onclick={() => (successMessage = null)}
aria-label={m.comp_dismiss()}>×</button
>
</p>
{/if}
{#if errorMessage}
<p class="mt-2 text-xs text-red-600" aria-live="assertive">{errorMessage}</p>
{/if}
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
{m.training_history_heading()}
</h3>
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
<TrainingHistory
runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')}
personNames={trainingInfo?.personNames ?? {}}
/>
</div>

View File

@@ -73,6 +73,24 @@ describe('OcrTrainingCard — enabled state', () => {
});
});
describe('OcrTrainingCard — success dismiss button', () => {
it('dismiss button has 44×44px touch target (h-11 w-11)', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: true }));
render(OcrTrainingCard, { trainingInfo: baseInfo });
const btn = page.getByRole('button', { name: /Training starten/i });
await btn.click();
const dismissBtn = page.getByRole('button', { name: /Schließen/i });
await expect.element(dismissBtn).toBeInTheDocument();
const el = await dismissBtn.element();
expect(el.classList.contains('h-11')).toBe(true);
expect(el.classList.contains('w-11')).toBe(true);
});
});
describe('OcrTrainingCard — in-flight state', () => {
it('shows "…" while POST is in-flight', async () => {
let resolveFetch!: (v: unknown) => void;

View File

@@ -1,24 +1,13 @@
<script lang="ts">
import { m } from '$lib/paraglide/messages.js';
interface Run {
id: string;
status: 'RUNNING' | 'DONE' | 'FAILED';
blockCount: number;
documentCount: number;
modelName: string;
errorMessage?: string;
triggeredBy?: string;
createdAt: string;
completedAt?: string;
cer?: number;
}
import type { TrainingRun } from '$lib/types/training.js';
interface Props {
runs: Run[];
runs: TrainingRun[];
personNames?: Record<string, string>;
}
let { runs }: Props = $props();
let { runs, personNames }: Props = $props();
const COLLAPSED_COUNT = 3;
let expanded = $state(false);
@@ -47,6 +36,8 @@ function formatCer(cer: number | undefined | null): string {
<tr class="border-b border-line text-xs font-bold tracking-widest text-ink-3 uppercase">
<th class="pb-2 text-left">{m.training_history_col_date()}</th>
<th class="pb-2 text-left">{m.training_history_col_status()}</th>
<th class="hidden pb-2 text-left md:table-cell">{m.training_col_type()}</th>
<th class="hidden pb-2 text-left md:table-cell">{m.training_col_person()}</th>
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
@@ -55,7 +46,7 @@ function formatCer(cer: number | undefined | null): string {
<tbody id="training-history-rows">
{#if runs.length === 0}
<tr>
<td colspan="5" class="py-4 text-center text-sm text-ink-2">
<td colspan="7" class="py-4 text-center text-sm text-ink-2">
{m.training_history_empty()}
</td>
</tr>
@@ -64,7 +55,14 @@ function formatCer(cer: number | undefined | null): string {
<tr class="border-b border-line/50 last:border-0">
<td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td>
<td class="py-2">
{#if run.status === 'DONE'}
{#if run.status === 'QUEUED'}
<span
class="inline-flex items-center gap-1 rounded-sm bg-amber-100 px-1.5 py-0.5 text-xs font-medium text-amber-700"
>
<span aria-hidden="true" class="h-1.5 w-1.5 rounded-full bg-amber-500"></span>
{m.training_status_queued()}
</span>
{:else if run.status === 'DONE'}
<span
class="inline-flex items-center gap-1 rounded-sm bg-green-100 px-1.5 py-0.5 text-xs font-medium text-green-700"
>
@@ -119,6 +117,17 @@ function formatCer(cer: number | undefined | null): string {
{m.training_status_running()}
</span>
{/if}
{#if run.personId && personNames?.[run.personId]}
<span class="mt-0.5 block text-xs text-ink-3 md:hidden"
>{personNames[run.personId]}</span
>
{/if}
</td>
<td class="hidden py-2 text-left text-ink-2 md:table-cell">
{run.personId ? m.training_type_personalized() : m.training_type_base()}
</td>
<td class="hidden py-2 text-left text-ink-2 md:table-cell">
{run.personId && personNames?.[run.personId] ? personNames[run.personId] : '—'}
</td>
<td class="py-2 text-right text-ink-2">{run.blockCount}</td>
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
@@ -137,7 +146,7 @@ function formatCer(cer: number | undefined | null): string {
type="button"
aria-expanded={expanded}
aria-controls="training-history-rows"
class="text-xs font-medium text-ink-3 transition-colors hover:text-ink"
class="rounded-sm py-2 text-xs font-medium text-ink-3 transition-colors hover:text-ink focus-visible:ring-2 focus-visible:ring-brand-navy focus-visible:ring-offset-1 focus-visible:outline-none"
onclick={() => (expanded = !expanded)}
>
{expanded ? m.comp_expandable_show_less() : m.comp_expandable_show_more()}

View File

@@ -50,3 +50,36 @@ describe('TrainingHistory — expand/collapse', () => {
.not.toBeInTheDocument();
});
});
describe('TrainingHistory — status badges', () => {
it('shows amber QUEUED badge for queued runs', async () => {
const run = { ...makeRun(0), status: 'QUEUED' as const };
render(TrainingHistory, { runs: [run] });
await expect.element(page.getByText(/Warteschlange/i)).toBeInTheDocument();
});
});
describe('TrainingHistory — type and person columns', () => {
it('shows "Basis" for runs without personId', async () => {
render(TrainingHistory, { runs: [makeRun(0)] });
await expect.element(page.getByText(/Basis/i)).toBeInTheDocument();
});
it('shows "Personalisiert" for runs with personId', async () => {
const run = { ...makeRun(0), personId: 'person-1' };
render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } });
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
});
it('shows person name from personNames for sender runs', async () => {
const run = { ...makeRun(0), personId: 'person-1' };
render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } });
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
// Name appears in both the mobile inline span and the desktop table cell
await expect.element(page.getByText('Karl Müller').first()).toBeInTheDocument();
});
});

View File

@@ -1374,7 +1374,7 @@ export interface components {
};
MergeTagDTO: {
/** Format: uuid */
targetId?: string;
targetId: string;
};
PersonNameAliasDTO: {
lastName: string;
@@ -1398,7 +1398,7 @@ export interface components {
/** Format: uuid */
id: string;
/** @enum {string} */
status: "RUNNING" | "DONE" | "FAILED";
status: "QUEUED" | "RUNNING" | "DONE" | "FAILED";
/** Format: int32 */
blockCount: number;
/** Format: int32 */
@@ -1415,6 +1415,8 @@ export interface components {
errorMessage?: string;
/** Format: uuid */
triggeredBy?: string;
/** Format: uuid */
personId?: string;
/** Format: date-time */
createdAt: string;
/** Format: date-time */
@@ -1596,11 +1598,11 @@ export interface components {
};
TagTreeNodeDTO: {
/** Format: uuid */
id?: string;
name?: string;
id: string;
name: string;
color?: string;
/** Format: int32 */
documentCount?: number;
documentCount: number;
children?: components["schemas"]["TagTreeNodeDTO"][];
/**
* Format: uuid
@@ -1619,7 +1621,6 @@ export interface components {
/** Format: uuid */
id?: string;
displayName?: string;
personType?: string;
firstName?: string;
lastName?: string;
/** Format: int64 */
@@ -1630,19 +1631,7 @@ export interface components {
deathYear?: number;
alias?: string;
notes?: string;
};
TrainingInfoResponse: {
/** Format: int32 */
availableBlocks?: number;
/** Format: int32 */
totalOcrBlocks?: number;
/** Format: int32 */
availableDocuments?: number;
/** Format: int32 */
availableSegBlocks?: number;
ocrServiceAvailable?: boolean;
lastRun?: components["schemas"]["OcrTrainingRun"];
runs?: components["schemas"]["OcrTrainingRun"][];
personType?: string;
};
StreamingResponseBody: unknown;
OcrJob: {
@@ -1671,13 +1660,11 @@ export interface components {
timeout?: number;
};
PageNotificationDTO: {
/** Format: int64 */
totalElements?: number;
/** Format: int32 */
totalPages?: number;
/** Format: int64 */
totalElements?: number;
pageable?: components["schemas"]["PageableObject"];
first?: boolean;
last?: boolean;
/** Format: int32 */
size?: number;
content?: components["schemas"]["NotificationDTO"][];
@@ -1686,6 +1673,8 @@ export interface components {
sort?: components["schemas"]["SortObject"];
/** Format: int32 */
numberOfElements?: number;
first?: boolean;
last?: boolean;
empty?: boolean;
};
PageableObject: {
@@ -3501,7 +3490,9 @@ export interface operations {
[name: string]: unknown;
};
content: {
"*/*": components["schemas"]["TrainingInfoResponse"];
"*/*": {
[key: string]: unknown;
};
};
};
};

View File

@@ -0,0 +1,13 @@
/**
 * One OCR training run as returned by the backend.
 * A run is either a base-model run (no personId) or a per-sender
 * personalized run (personId set).
 */
export interface TrainingRun {
  id: string;
  /** QUEUED waits for a worker; DONE and FAILED are terminal. */
  status: 'QUEUED' | 'RUNNING' | 'DONE' | 'FAILED';
  /** Sender (person) UUID — present only for personalized runs. */
  personId?: string;
  blockCount: number;
  documentCount: number;
  modelName: string;
  /** Set only when status is FAILED. */
  errorMessage?: string;
  /** UUID of the user who started the run, when triggered manually. */
  triggeredBy?: string;
  createdAt: string;
  completedAt?: string;
  /** Character error rate of the trained model, when reported. */
  cer?: number;
}

View File

@@ -1,13 +1,84 @@
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
from __future__ import annotations
import collections
import logging
import os
import threading
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from PIL.Image import Image
logger = logging.getLogger(__name__)
# Base Kurrent model — populated by load_models() at startup, None until then.
_model = None
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
# All sender models must resolve to a path under this directory
# (path-traversal guard used by the registry below).
_MODELS_DIR = os.path.realpath("/app/models")
# Upper bound on resident per-sender models; models are large, keep this small.
_MAX_CACHED_SENDER_MODELS = int(os.environ.get("OCR_MAX_CACHED_MODELS", "5"))
def _load_sender_model(path: str):
    """Load a Kraken model from disk. Extracted for testability."""
    # Imported lazily so unit tests can patch this function without the ML stack.
    from kraken.lib import models as kraken_models
    return kraken_models.load_any(path)
class _SenderModelRegistry:
    """Thread-safe LRU cache for per-sender Kraken models.

    Uses double-checked locking: model loading happens outside the lock to
    avoid blocking concurrent OCR requests. Under a concurrent miss both
    threads may load, but only the first insert wins and the loser's model
    is discarded, so at most one entry per path is stored.

    Fix over the original: cache keys are the os.path.realpath-resolved
    paths. Previously the guard resolved the path but the cache was keyed
    by the raw string, so symlinked or ``..``-containing aliases of the
    same model file could occupy duplicate cache slots and invalidate()
    could miss an alias.
    """

    def __init__(self, max_size: int):
        # max_size: number of models kept resident (each is large).
        self._max_size = max_size
        self._cache: collections.OrderedDict = collections.OrderedDict()
        self._lock = threading.Lock()

    def _resolved_key(self, model_path: str) -> str:
        """Resolve and validate: the path must lie within /app/models/."""
        resolved = os.path.realpath(model_path)
        if not resolved.startswith(_MODELS_DIR + os.sep) and resolved != _MODELS_DIR:
            raise ValueError(f"Sender model path not allowed: {model_path}")
        return resolved

    def get_model(self, model_path: str):
        """Return the cached model or load it. Validates path is within /app/models/.

        Raises ValueError for paths escaping /app/models/ (traversal guard).
        A failed load propagates and leaves no cache entry behind.
        """
        key = self._resolved_key(model_path)
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
                return self._cache[key]
        # Load outside the lock — may take seconds for a large model.
        new_model = _load_sender_model(model_path)
        with self._lock:
            if key in self._cache:
                # Another thread won the race; keep its entry, drop ours.
                self._cache.move_to_end(key)
                return self._cache[key]
            self._cache[key] = new_model
            self._cache.move_to_end(key)
            while len(self._cache) > self._max_size:
                self._cache.popitem(last=False)  # evict least-recently-used
        return new_model

    def invalidate(self, model_path: str) -> None:
        """Remove model from cache so the next request reloads from disk."""
        with self._lock:
            self._cache.pop(os.path.realpath(model_path), None)

    def size(self) -> int:
        """Number of models currently cached."""
        with self._lock:
            return len(self._cache)

    def _contains(self, model_path: str) -> bool:
        # Test helper: membership check without touching LRU order.
        with self._lock:
            return os.path.realpath(model_path) in self._cache


_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS)
def load_models():
"""Load the Kraken model at startup. Skips if model file is not present."""
@@ -29,10 +100,12 @@ def is_available() -> bool:
return _model is not None
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
def extract_page_blocks(image: Image, page_idx: int, language: str = "de",
sender_model_path: str | None = None) -> list[dict]:
"""Run Kraken segmentation + recognition on a single PIL image.
Returns block dicts for that page. Coordinates are normalized to [0, 1].
When sender_model_path is provided, the per-sender fine-tuned model is used.
"""
from kraken import blla, rpred
from confidence import words_from_characters
@@ -40,11 +113,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
if _model is None:
raise RuntimeError("Kraken model is not loaded")
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
page_w, page_h = image.size
blocks = []
baseline_seg = blla.segment(image)
pred_it = rpred.rpred(_model, image, baseline_seg)
pred_it = rpred.rpred(active_model, image, baseline_seg)
for record in pred_it:
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
@@ -79,13 +154,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
return blocks
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
def extract_region_text(image: Image, x: float, y: float, w: float, h: float,
sender_model_path: str | None = None) -> str:
"""Crop image to a normalized region and run Kraken recognition on the crop.
Used for guided OCR — skips full-page layout detection entirely.
A single synthetic baseline spanning the full crop width is used so that
blla.segment() (which crashes on small crops) is never called.
Coordinates are normalized to [0, 1].
When sender_model_path is provided, the per-sender fine-tuned model is used.
"""
from kraken import rpred
from kraken.containers import Segmentation, BaselineLine
@@ -93,6 +170,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
if _model is None:
raise RuntimeError("Kraken model is not loaded")
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
pw, ph = image.size
x1 = max(0, int(x * pw))
y1 = max(0, int(y * ph))
@@ -123,11 +202,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
regions={},
line_orders=[],
)
pred_it = rpred.rpred(_model, crop, synthetic_seg)
pred_it = rpred.rpred(active_model, crop, synthetic_seg)
return " ".join(r.prediction for r in pred_it)
def extract_blocks(images: list, language: str = "de") -> list[dict]:
def extract_blocks(images: list, language: str = "de",
sender_model_path: str | None = None) -> list[dict]:
"""Run Kraken segmentation + recognition on a list of PIL images.
Returns block dicts with pageNumber, x, y, width, height, polygon, text.
@@ -137,7 +217,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]:
all_blocks = []
for page_idx, image in enumerate(images, start=1):
all_blocks.extend(extract_page_blocks(image, page_idx, language))
all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path))
return all_blocks

View File

@@ -17,7 +17,7 @@ from urllib.parse import urlparse
import httpx
import pypdfium2 as pdfium
from fastapi import FastAPI, Header, HTTPException, UploadFile
from fastapi import FastAPI, Form, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
@@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest):
status_code=400,
detail="Kraken model not available — cannot process Kurrent script",
)
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
blocks = await asyncio.to_thread(
kraken_engine.extract_blocks, images, request.language, request.senderModelPath)
else:
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
@@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest):
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
blocks = []
sender_path = request.senderModelPath if use_kraken else None
for region in page_regions:
text = await asyncio.to_thread(
engine.extract_region_text, image,
region.x, region.y, region.width, region.height,
sender_path,
)
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
text = correct_text(text)
@@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest):
try:
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image)
sender_path = request.senderModelPath if use_kraken else None
blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language
engine.extract_page_blocks, image, page_idx, request.language, sender_path
)
for block in blocks:
@@ -443,6 +447,86 @@ async def train_model(
return result
@app.post("/train-sender")
async def train_sender_model(
    file: UploadFile,
    output_model_path: str = Form(...),
    x_training_token: str | None = Header(default=None),
):
    """Fine-tune a per-sender Kurrent model and save to output_model_path.

    output_model_path must be within /app/models/. The per-sender model cache
    is invalidated after training so the next OCR request picks up the new model.

    Args:
        file: ZIP upload containing PAGE-XML ground-truth files for one sender.
        output_model_path: Destination path for the trained model; must resolve
            to a path inside /app/models/ (path-traversal guard).
        x_training_token: Shared-secret header validated by _check_training_token.

    Returns:
        Dict with loss/accuracy/cer/epochs of the best checkpoint.

    Raises:
        HTTPException: 503 while models are loading, 400 for a bad output path,
            422 when the ZIP contains no ground-truth XML.
    """
    _check_training_token(x_training_token)
    if not _models_ready:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    models_dir = os.path.realpath("/app/models")
    resolved_output = os.path.realpath(output_model_path)
    if not resolved_output.startswith(models_dir + os.sep):
        raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")

    zip_bytes = await file.read()
    training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
    # Fix: the original "%s%s" format fused the run id and the output path
    # into one unreadable token in the log line.
    log.info("Starting sender training run %s -> %s", training_run_id, output_model_path)

    def _run_sender_training() -> dict:
        # Runs in a worker thread via asyncio.to_thread — blocking I/O and
        # subprocess calls are acceptable here.
        with tempfile.TemporaryDirectory() as tmp_dir:
            with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
                for entry in zf.namelist():
                    _validate_zip_entry(entry, tmp_dir)  # zip-slip guard
                zf.extractall(tmp_dir)
            ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
            if not ground_truth:
                raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
            log.info("Sender training on %d ground-truth pairs", len(ground_truth))
            checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
            os.makedirs(checkpoint_dir, exist_ok=True)
            cmd = [
                "ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
                "train",
                "-f", "page",
                "--weights-format", "coreml",
                "-o", checkpoint_dir,
                "-q", "fixed",
                "-N", "10",
                "-B", "1",
            ]
            if os.path.exists(KRAKEN_MODEL_PATH):
                # Fine-tune from the base Kurrent model when it is present.
                cmd += ["-i", KRAKEN_MODEL_PATH]
            cmd += ground_truth
            log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.returncode != 0:
                raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
            accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
            cer = round(1.0 - accuracy, 4) if accuracy is not None else None
            best_model = _find_best_model(checkpoint_dir)
            if best_model is None:
                raise RuntimeError("Sender training produced no model file")
            os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
            shutil.copy2(best_model, resolved_output)
            log.info("Saved sender model to %s", resolved_output)
            # Drop any stale cached copy so the next OCR request reloads it.
            kraken_engine._sender_registry.invalidate(output_model_path)
            return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}

    result = await asyncio.to_thread(_run_sender_training)
    return result
@app.post("/segtrain")
async def segtrain_model(
file: UploadFile,

View File

@@ -19,6 +19,7 @@ class OcrRequest(BaseModel):
scriptType: str = "UNKNOWN"
language: str = "de"
regions: list[OcrRegion] | None = None
senderModelPath: str | None = None
class OcrBlock(BaseModel):

View File

@@ -0,0 +1,75 @@
"""Tests for the per-sender model LRU registry in engines/kraken.py."""
from unittest.mock import MagicMock, call, patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_registry(max_size=5):
    """Build a fresh registry with the given capacity."""
    from engines.kraken import _SenderModelRegistry as Registry

    return Registry(max_size=max_size)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_cache_hit_returns_same_object():
    """Second get_model call with the same path must return the cached object."""
    registry = _make_registry()
    path = "/app/models/sender_a.mlmodel"
    fake_model = MagicMock(name="model_a")
    with patch("engines.kraken._load_sender_model", return_value=fake_model) as loader:
        first = registry.get_model(path)
        second = registry.get_model(path)
    assert first is second
    assert loader.call_count == 1  # only loaded once despite two gets
def test_lru_eviction_removes_least_recently_used():
    """When the cache exceeds max_size, the oldest-accessed entry is evicted."""
    registry = _make_registry(max_size=2)
    paths = [f"/app/models/sender_{s}.mlmodel" for s in ("a", "b", "c")]
    with patch("engines.kraken._load_sender_model",
               side_effect=lambda p: MagicMock(name=p)):
        for p in paths:
            registry.get_model(p)  # third load must evict the first
    assert registry.size() == 2
    # 'a' was the least-recently-used entry and should be gone
    assert not registry._contains(paths[0])
def test_invalidate_removes_entry_from_cache():
    """invalidate() must evict the entry so the next get reloads from disk."""
    registry = _make_registry()
    path = "/app/models/sender_x.mlmodel"
    with patch("engines.kraken._load_sender_model",
               return_value=MagicMock(name="model_x")):
        registry.get_model(path)
    assert registry.size() == 1
    registry.invalidate(path)
    assert registry.size() == 0
def test_path_outside_models_dir_raises():
    """get_model must reject paths outside /app/models/ (path traversal guard)."""
    with pytest.raises(ValueError, match="not allowed"):
        _make_registry().get_model("/etc/passwd")
def test_load_failure_does_not_cache_broken_entry():
    """A failed load must not leave a broken entry in the cache."""
    registry = _make_registry()
    boom = RuntimeError("corrupt model")
    with patch("engines.kraken._load_sender_model", side_effect=boom):
        with pytest.raises(RuntimeError, match="corrupt model"):
            registry.get_model("/app/models/sender_bad.mlmodel")
    assert registry.size() == 0

View File

@@ -67,3 +67,42 @@ async def test_train_returns_403_when_token_wrong():
)
assert response.status_code == 403
# ─── /train-sender authentication ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_train_sender_returns_503_when_training_token_not_configured():
    """POST /train-sender must return 503 when TRAINING_TOKEN env var is empty.

    An empty token means the service was started without training configured.
    Allowing requests through would grant unauthenticated access to the
    training endpoint, contradicting the principle of failing closed.
    """
    form = {"output_model_path": "/app/models/sender_test.mlmodel"}
    upload = {"file": ("training.zip", _minimal_zip(), "application/zip")}
    with patch("main.TRAINING_TOKEN", ""), patch("main._models_ready", True):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/train-sender", data=form, files=upload)
    assert response.status_code == 503
@pytest.mark.asyncio
async def test_train_sender_returns_403_when_token_wrong():
    """POST /train-sender must return 403 when TRAINING_TOKEN is set but header is wrong."""
    form = {"output_model_path": "/app/models/sender_test.mlmodel"}
    upload = {"file": ("training.zip", _minimal_zip(), "application/zip")}
    wrong_header = {"X-Training-Token": "wrong-token"}
    with patch("main.TRAINING_TOKEN", "secret-token"), patch("main._models_ready", True):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/train-sender", data=form, files=upload, headers=wrong_header
            )
    assert response.status_code == 403