feat(ocr): per-sender specialized Kurrent models with automatic active-learning retraining #263

Merged
marcel merged 32 commits from feat/issue-253-sender-models into main 2026-04-18 12:30:57 +02:00
42 changed files with 1495 additions and 136 deletions

View File

@@ -48,7 +48,7 @@ jobs:
path: frontend/test-results/screenshots/ path: frontend/test-results/screenshots/
# ─── OCR Service Unit Tests ─────────────────────────────────────────────────── # ─── OCR Service Unit Tests ───────────────────────────────────────────────────
# Only spell_check.py and test_confidence.py — no ML stack required. # Only spell_check.py, test_confidence.py, test_sender_registry.py — no ML stack required.
ocr-tests: ocr-tests:
name: OCR Service Tests name: OCR Service Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -60,11 +60,11 @@ jobs:
python-version: '3.11' python-version: '3.11'
- name: Install test dependencies - name: Install test dependencies
run: pip install "pyspellchecker==0.9.0" pytest run: pip install "pyspellchecker==0.9.0" pytest pytest-asyncio
working-directory: ocr-service working-directory: ocr-service
- name: Run OCR unit tests (no ML stack required) - name: Run OCR unit tests (no ML stack required)
run: python -m pytest test_spell_check.py test_confidence.py -v run: python -m pytest test_spell_check.py test_confidence.py test_sender_registry.py -v
working-directory: ocr-service working-directory: ocr-service
# ─── Backend Unit & Slice Tests ─────────────────────────────────────────────── # ─── Backend Unit & Slice Tests ───────────────────────────────────────────────

View File

@@ -59,6 +59,9 @@ public class OcrTrainingRun {
@Column(name = "triggered_by") @Column(name = "triggered_by")
private UUID triggeredBy; private UUID triggeredBy;
@Column(name = "person_id")
private UUID personId;
@CreationTimestamp @CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false) @Column(name = "created_at", nullable = false, updatable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED) @Schema(requiredMode = Schema.RequiredMode.REQUIRED)

View File

@@ -0,0 +1,56 @@
package org.raddatz.familienarchiv.model;
import com.fasterxml.jackson.annotation.JsonIgnore;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.persistence.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.UpdateTimestamp;
import java.time.Instant;
import java.util.UUID;
@Entity
@Table(name = "sender_models")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class SenderModel {
@Id
@GeneratedValue(strategy = GenerationType.UUID)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID id;
@Column(name = "person_id", nullable = false, unique = true)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private UUID personId;
@JsonIgnore
@Column(name = "model_path", nullable = false)
private String modelPath;
@Column
private Double accuracy;
@Column
private Double cer;
@Column(name = "corrected_lines_at_training", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private int correctedLinesAtTraining;
@CreationTimestamp
@Column(name = "created_at", nullable = false, updatable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private Instant createdAt;
@UpdateTimestamp
@Column(name = "updated_at", nullable = false)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
private Instant updatedAt;
}

View File

@@ -1,6 +1,7 @@
package org.raddatz.familienarchiv.model; package org.raddatz.familienarchiv.model;
public enum TrainingStatus { public enum TrainingStatus {
QUEUED,
RUNNING, RUNNING,
DONE, DONE,
FAILED FAILED

View File

@@ -12,5 +12,11 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status); Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
List<OcrTrainingRun> findTop10ByOrderByCreatedAtDesc(); Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status);
Optional<OcrTrainingRun> findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status);
boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status);
List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc();
} }

View File

@@ -0,0 +1,12 @@
package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.SenderModel;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.Optional;
import java.util.UUID;
public interface SenderModelRepository extends JpaRepository<SenderModel, UUID> {
Optional<SenderModel> findByPersonId(UUID personId);
}

View File

@@ -3,6 +3,7 @@ package org.raddatz.familienarchiv.repository;
import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query; import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@@ -37,4 +38,22 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels
""") """)
List<TranscriptionBlock> findSegmentationBlocks(); List<TranscriptionBlock> findSegmentationBlocks();
@Query("""
SELECT COUNT(b) FROM TranscriptionBlock b
JOIN Document d ON d.id = b.documentId
WHERE b.source = 'MANUAL'
AND d.sender.id = :personId
AND d.scriptType = 'HANDWRITING_KURRENT'
""")
long countManualKurrentBlocksByPerson(@Param("personId") UUID personId);
@Query("""
SELECT b FROM TranscriptionBlock b
JOIN Document d ON d.id = b.documentId
WHERE b.source = 'MANUAL'
AND d.sender.id = :personId
AND d.scriptType = 'HANDWRITING_KURRENT'
""")
List<TranscriptionBlock> findManualKurrentBlocksByPerson(@Param("personId") UUID personId);
} }

View File

@@ -9,10 +9,12 @@ import org.raddatz.familienarchiv.repository.OcrJobRepository;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@Component @Component
@RequiredArgsConstructor @RequiredArgsConstructor
@@ -29,6 +31,7 @@ public class OcrAsyncRunner {
private final OcrJobRepository ocrJobRepository; private final OcrJobRepository ocrJobRepository;
private final OcrJobDocumentRepository ocrJobDocumentRepository; private final OcrJobDocumentRepository ocrJobDocumentRepository;
private final OcrProgressService ocrProgressService; private final OcrProgressService ocrProgressService;
private final SenderModelService senderModelService;
@Async @Async
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) { public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
@@ -68,12 +71,18 @@ public class OcrAsyncRunner {
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath()); String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
String senderModelPath = null;
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
}
AtomicInteger blockCounter = new AtomicInteger(0); AtomicInteger blockCounter = new AtomicInteger(0);
AtomicInteger currentPage = new AtomicInteger(0); AtomicInteger currentPage = new AtomicInteger(0);
AtomicInteger skippedPages = new AtomicInteger(0); AtomicInteger skippedPages = new AtomicInteger(0);
AtomicInteger totalPages = new AtomicInteger(0); AtomicInteger totalPages = new AtomicInteger(0);
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> { final String finalSenderModelPath = senderModelPath;
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, finalSenderModelPath, event -> {
switch (event) { switch (event) {
case OcrStreamEvent.Start start -> { case OcrStreamEvent.Start start -> {
totalPages.set(start.totalPages()); totalPages.set(start.totalPages());
@@ -207,7 +216,25 @@ public class OcrAsyncRunner {
clearExistingBlocks(documentId); clearExistingBlocks(documentId);
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath()); String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
List<OcrBlockResult> blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType());
String senderModelPath = null;
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
}
final AtomicReference<List<OcrBlockResult>> blocksRef = new AtomicReference<>();
final String finalSenderModelPath = senderModelPath;
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), null, finalSenderModelPath, event -> {
switch (event) {
case OcrStreamEvent.Page page -> {
blocksRef.compareAndSet(null, new ArrayList<>());
blocksRef.get().addAll(page.blocks());
}
default -> {}
}
});
List<OcrBlockResult> blocks = blocksRef.get() != null ? blocksRef.get() : List.of();
createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash()); createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash());
} }

View File

@@ -1,6 +1,7 @@
package org.raddatz.familienarchiv.service; package org.raddatz.familienarchiv.service;
import org.raddatz.familienarchiv.model.ScriptType; import org.raddatz.familienarchiv.model.ScriptType;
import org.springframework.lang.Nullable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
@@ -37,15 +38,27 @@ public interface OcrClient {
TrainingResult segtrainModel(byte[] trainingDataZip); TrainingResult segtrainModel(byte[] trainingDataZip);
/** /**
* Stream OCR results page-by-page via NDJSON. Implementations should override * Fine-tune the Kurrent model for a specific sender.
* this method. The default exists only for backward compatibility during migration
* — it calls extractBlocks() and synthesizes events from the collected result.
* *
* @param regions optional list of pre-drawn annotation regions; when non-null, * @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService.exportForSender()
* the OCR service runs in guided mode (crop + recognize per region) * @param outputModelPath where to save the trained model (e.g. /app/models/sender_{uuid}.mlmodel)
* @return training result metrics
*/
TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath);
/**
* Stream OCR results page-by-page via NDJSON, optionally using a sender-specific model.
* The default implementation synthesizes events from extractBlocks() for backward compatibility.
* Implementations that support real streaming (e.g. RestClientOcrClient) override this.
*
* @param regions optional list of pre-drawn annotation regions; when non-null,
* the OCR service runs in guided mode (crop + recognize per region)
* @param senderModelPath optional path to a per-sender model file; null means use base model
*/ */
default void streamBlocks(String pdfUrl, ScriptType scriptType, default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) { List<OcrRegion> regions,
@Nullable String senderModelPath,
Consumer<OcrStreamEvent> handler) {
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType); List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>(); LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
@@ -62,4 +75,9 @@ public interface OcrClient {
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0)); handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
} }
default void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
streamBlocks(pdfUrl, scriptType, regions, null, handler);
}
} }

View File

@@ -17,9 +17,11 @@ import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.time.Instant; import java.time.Instant;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@Service @Service
@@ -34,6 +36,7 @@ public class OcrTrainingService {
private final OcrHealthClient ocrHealthClient; private final OcrHealthClient ocrHealthClient;
private final TranscriptionBlockRepository blockRepository; private final TranscriptionBlockRepository blockRepository;
private final TransactionTemplate txTemplate; private final TransactionTemplate txTemplate;
private final PersonService personService;
public record TrainingInfoResponse( public record TrainingInfoResponse(
int availableBlocks, int availableBlocks,
@@ -42,7 +45,8 @@ public class OcrTrainingService {
int availableSegBlocks, int availableSegBlocks,
boolean ocrServiceAvailable, boolean ocrServiceAvailable,
OcrTrainingRun lastRun, OcrTrainingRun lastRun,
List<OcrTrainingRun> runs List<OcrTrainingRun> runs,
Map<String, String> personNames
) {} ) {}
private void assertNoRunningTraining() { private void assertNoRunningTraining() {
@@ -195,9 +199,20 @@ public class OcrTrainingService {
int totalOcrBlocks = (int) blockRepository.count(); int totalOcrBlocks = (int) blockRepository.count();
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size(); int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc(); List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc();
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0); OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
List<UUID> distinctPersonIds = recentRuns.stream()
.map(OcrTrainingRun::getPersonId)
.filter(Objects::nonNull)
.distinct()
.toList();
Map<String, String> personNames = new HashMap<>();
if (!distinctPersonIds.isEmpty()) {
personService.getAllById(distinctPersonIds)
.forEach(p -> personNames.put(p.getId().toString(), p.getDisplayName()));
}
return new TrainingInfoResponse( return new TrainingInfoResponse(
eligibleBlocks.size(), eligibleBlocks.size(),
totalOcrBlocks, totalOcrBlocks,
@@ -205,7 +220,8 @@ public class OcrTrainingService {
availableSegBlocks, availableSegBlocks,
ocrHealthClient.isHealthy(), ocrHealthClient.isHealthy(),
lastRun, lastRun,
recentRuns recentRuns,
personNames
); );
} }
@@ -224,15 +240,4 @@ public class OcrTrainingService {
}); });
} }
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
return Map.of(
"availableBlocks", info.availableBlocks(),
"totalOcrBlocks", info.totalOcrBlocks(),
"availableDocuments", info.availableDocuments(),
"availableSegBlocks", info.availableSegBlocks(),
"ocrServiceAvailable", info.ocrServiceAvailable(),
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
"runs", info.runs()
);
}
} }

View File

@@ -14,6 +14,7 @@ import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
@@ -102,6 +103,13 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
.toList(); .toList();
} }
private RestClient.RequestBodySpec addTrainingAuth(RestClient.RequestBodySpec spec) {
if (trainingToken != null && !trainingToken.isBlank()) {
return spec.header("X-Training-Token", trainingToken);
}
return spec;
}
@Override @Override
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) { public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) { ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@@ -114,15 +122,10 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
partHeaders.setContentType(MediaType.parseMediaType("application/zip")); partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders)); body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post() TrainingResultJson result = addTrainingAuth(
.uri("/train") trainingRestClient.post()
.contentType(MediaType.MULTIPART_FORM_DATA); .uri("/train")
.contentType(MediaType.MULTIPART_FORM_DATA))
if (trainingToken != null && !trainingToken.isBlank()) {
spec = spec.header("X-Training-Token", trainingToken);
}
TrainingResultJson result = spec
.body(body) .body(body)
.retrieve() .retrieve()
.body(TrainingResultJson.class); .body(TrainingResultJson.class);
@@ -143,15 +146,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
partHeaders.setContentType(MediaType.parseMediaType("application/zip")); partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders)); body.add("file", new HttpEntity<>(zipResource, partHeaders));
var spec = trainingRestClient.post() TrainingResultJson result = addTrainingAuth(
.uri("/segtrain") trainingRestClient.post()
.contentType(MediaType.MULTIPART_FORM_DATA); .uri("/segtrain")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body)
.retrieve()
.body(TrainingResultJson.class);
if (trainingToken != null && !trainingToken.isBlank()) { if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
spec = spec.header("X-Training-Token", trainingToken); return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
} }
TrainingResultJson result = spec @Override
public OcrClient.TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath) {
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
@Override
public String getFilename() { return "sender-training-data.zip"; }
};
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
HttpHeaders partHeaders = new HttpHeaders();
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
body.add("file", new HttpEntity<>(zipResource, partHeaders));
body.add("output_model_path", outputModelPath);
TrainingResultJson result = addTrainingAuth(
trainingRestClient.post()
.uri("/train-sender")
.contentType(MediaType.MULTIPART_FORM_DATA))
.body(body) .body(body)
.retrieve() .retrieve()
.body(TrainingResultJson.class); .body(TrainingResultJson.class);
@@ -176,7 +199,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
@Override @Override
public void streamBlocks(String pdfUrl, ScriptType scriptType, public void streamBlocks(String pdfUrl, ScriptType scriptType,
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) { List<OcrRegion> regions, @Nullable String senderModelPath,
Consumer<OcrStreamEvent> handler) {
String body; String body;
try { try {
var requestMap = new java.util.LinkedHashMap<String, Object>(); var requestMap = new java.util.LinkedHashMap<String, Object>();
@@ -186,6 +210,9 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (regions != null && !regions.isEmpty()) { if (regions != null && !regions.isEmpty()) {
requestMap.put("regions", regions); requestMap.put("regions", regions);
} }
if (senderModelPath != null) {
requestMap.put("senderModelPath", senderModelPath);
}
body = NDJSON_MAPPER.writeValueAsString(requestMap); body = NDJSON_MAPPER.writeValueAsString(requestMap);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("Failed to serialize OCR request", e); throw new RuntimeException("Failed to serialize OCR request", e);
@@ -204,7 +231,12 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
if (response.statusCode() == 404) { if (response.statusCode() == 404) {
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr"); log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler); List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
handler.accept(new OcrStreamEvent.Start(0));
for (OcrBlockResult block : allBlocks) {
handler.accept(new OcrStreamEvent.Page(block.pageNumber(), List.of(block)));
}
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
return; return;
} }

View File

@@ -0,0 +1,197 @@
package org.raddatz.familienarchiv.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionTemplate;
import java.io.ByteArrayOutputStream;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
@Service
@RequiredArgsConstructor
@Slf4j
public class SenderModelService {
private final SenderModelRepository senderModelRepository;
private final TranscriptionBlockRepository blockRepository;
private final OcrTrainingRunRepository trainingRunRepository;
private final OcrClient ocrClient;
private final TransactionTemplate txTemplate;
private final TrainingDataExportService trainingDataExportService;
@Value("${ocr.sender-model.activation-threshold:100}")
private int activationThreshold;
@Value("${ocr.sender-model.retrain-delta:50}")
private int retrainDelta;
/** Returns the model path if a trained sender model exists for this person. */
public Optional<String> maybeGetModelPath(UUID personId) {
return senderModelRepository.findByPersonId(personId)
.map(SenderModel::getModelPath);
}
/**
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
*/
@Async
public void checkAndTriggerTraining(UUID personId) {
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
Optional<SenderModel> existing = senderModelRepository.findByPersonId(personId);
boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold;
boolean shouldRetrain = existing.isPresent()
&& (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta;
if (!shouldActivate && !shouldRetrain) {
return;
}
log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})",
personId, correctedLines, shouldActivate, shouldRetrain);
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
if (runNow) {
triggerSenderTraining(personId, (int) correctedLines);
}
}
/**
* Atomically checks the queue state and either creates a RUNNING row (returns true) or a
* QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and
* RUNNING row creation — happen in one transaction, eliminating the race window that would
* otherwise exist between the check and a separate RUNNING row creation.
*/
@Transactional
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
return false;
}
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.QUEUED)
.personId(personId)
.blockCount(blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
log.info("Queued sender training for person {} — training already running", personId);
return false;
}
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
trainingRunRepository.save(OcrTrainingRun.builder()
.status(TrainingStatus.RUNNING)
.personId(personId)
.blockCount((int) blockCount)
.documentCount(0)
.modelName("sender_" + personId)
.build());
return true;
}
/**
* Executes sender training synchronously. Caller must run this on a background thread.
* The RUNNING row is expected to already exist — created atomically by
* runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs).
*/
public void triggerSenderTraining(UUID personId, int correctedLines) {
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status ->
trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)
.orElseThrow(() -> new IllegalStateException(
"Expected RUNNING row for person " + personId + " but none found"))));
String runId = run.getId().toString();
MDC.put("trainingRunId", runId);
log.info("Started sender training run {} for person {}", runId, personId);
try {
byte[] zipBytes = exportSenderData(personId);
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
txTemplate.execute(status -> {
SenderModel model = senderModelRepository.findByPersonId(personId)
.orElseGet(() -> SenderModel.builder().personId(personId).build());
model.setModelPath(outputModelPath);
model.setCer(result.cer());
model.setAccuracy(result.accuracy());
model.setCorrectedLinesAtTraining(correctedLines);
senderModelRepository.save(model);
run.setStatus(TrainingStatus.DONE);
run.setCompletedAt(Instant.now());
run.setCer(result.cer());
run.setAccuracy(result.accuracy());
run.setEpochs(result.epochs());
trainingRunRepository.save(run);
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
return null;
});
} catch (Exception e) {
txTemplate.execute(status -> {
run.setStatus(TrainingStatus.FAILED);
run.setErrorMessage(e.getMessage());
run.setCompletedAt(Instant.now());
trainingRunRepository.save(run);
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
return null;
});
} finally {
MDC.remove("trainingRunId");
promoteNextQueuedRun();
}
}
private byte[] exportSenderData(UUID personId) throws java.io.IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
trainingDataExportService.exportForSender(personId).writeTo(baos);
return baos.toByteArray();
}
/**
* Promotes the oldest QUEUED sender run to RUNNING and triggers its training.
* Called in the finally block of triggerSenderTraining, creating a sequential chain:
* each run promotes the next only after it fully completes (success or failure).
*
* This is intentionally tail-recursive via the @Async thread: the same thread holds the
* full queue drain, serialising all sender training runs naturally without an external
* scheduler. With N queued runs the thread stays occupied for N sequential trainings —
* acceptable because the @Async executor is dedicated to long-running background work.
*/
private void promoteNextQueuedRun() {
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
.map(queued -> {
queued.setStatus(TrainingStatus.RUNNING);
return trainingRunRepository.save(queued);
}));
if (queuedOpt != null && queuedOpt.isPresent()) {
OcrTrainingRun promoted = queuedOpt.get();
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
}
}
}

View File

@@ -38,10 +38,20 @@ public class TrainingDataExportService {
} }
public StreamingResponseBody exportToZip() { public StreamingResponseBody exportToZip() {
// Collect all data before entering the lambda — no open DB txn during streaming return exportBlocksToZip(queryEligibleBlocks());
List<TranscriptionBlock> blocks = queryEligibleBlocks(); }
public List<TranscriptionBlock> queryBlocksForSender(UUID personId) {
return blockRepository.findManualKurrentBlocksByPerson(personId);
}
public StreamingResponseBody exportForSender(UUID personId) {
return exportBlocksToZip(queryBlocksForSender(personId));
}
private StreamingResponseBody exportBlocksToZip(List<TranscriptionBlock> blocks) {
if (blocks.isEmpty()) { if (blocks.isEmpty()) {
return out -> {}; // caller checks isEmpty() for 204 response return out -> {};
} }
// Group blocks by documentId so we only download each PDF once // Group blocks by documentId so we only download each PDF once

View File

@@ -11,6 +11,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode;
import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation; import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
import org.raddatz.familienarchiv.repository.AnnotationRepository; import org.raddatz.familienarchiv.repository.AnnotationRepository;
@@ -35,6 +36,7 @@ public class TranscriptionService {
private final AnnotationRepository annotationRepository; private final AnnotationRepository annotationRepository;
private final AnnotationService annotationService; private final AnnotationService annotationService;
private final DocumentService documentService; private final DocumentService documentService;
private final SenderModelService senderModelService;
public List<TranscriptionBlock> listBlocks(UUID documentId) { public List<TranscriptionBlock> listBlocks(UUID documentId) {
return blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId); return blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId);
@@ -122,6 +124,7 @@ public class TranscriptionService {
String text = sanitizeText(dto.getText()); String text = sanitizeText(dto.getText());
block.setText(text); block.setText(text);
block.setSource(BlockSource.MANUAL);
if (dto.getLabel() != null) { if (dto.getLabel() != null) {
block.setLabel(dto.getLabel()); block.setLabel(dto.getLabel());
} }
@@ -129,6 +132,12 @@ public class TranscriptionService {
TranscriptionBlock saved = blockRepository.save(block); TranscriptionBlock saved = blockRepository.save(block);
saveVersion(saved, userId); saveVersion(saved, userId);
Document doc = documentService.getDocumentById(documentId);
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
senderModelService.checkAndTriggerTraining(doc.getSender().getId());
}
return saved; return saved;
} }

View File

@@ -77,3 +77,8 @@ app:
tags: 10 tags: 10
summary: 11 summary: 11
transcription: 13 transcription: 13
ocr:
sender-model:
activation-threshold: 100
retrain-delta: 50

View File

@@ -0,0 +1,10 @@
CREATE TABLE sender_models (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
person_id UUID NOT NULL UNIQUE REFERENCES persons(id) ON DELETE CASCADE,
model_path TEXT NOT NULL,
accuracy DOUBLE PRECISION,
cer DOUBLE PRECISION,
corrected_lines_at_training INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

View File

@@ -0,0 +1,2 @@
ALTER TABLE ocr_training_runs
ADD COLUMN person_id UUID REFERENCES persons(id) ON DELETE SET NULL;

View File

@@ -0,0 +1,3 @@
CREATE UNIQUE INDEX idx_training_runs_queued_per_person
ON ocr_training_runs(person_id)
WHERE status = 'QUEUED';

View File

@@ -23,6 +23,8 @@ import org.springframework.test.web.servlet.MockMvc;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.Map;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
@@ -219,7 +221,7 @@ class OcrControllerTest {
@WithMockUser(authorities = "ADMIN") @WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_withInfo() throws Exception { void getTrainingInfo_returns200_withInfo() throws Exception {
OcrTrainingService.TrainingInfoResponse info = OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of()); new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info); when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info")) mockMvc.perform(get("/api/ocr/training-info"))
@@ -228,6 +230,52 @@ class OcrControllerTest {
.andExpect(jsonPath("$.ocrServiceAvailable").value(true)); .andExpect(jsonPath("$.ocrServiceAvailable").value(true));
} }
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_returns200_and_omits_personName_when_service_provides_empty_map() throws Exception {
UUID personId = UUID.randomUUID();
OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null,
List.of(runWithPerson), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.personNames").isEmpty());
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_includesPersonName_whenPersonIdResolves() throws Exception {
UUID personId = UUID.randomUUID();
OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null,
List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann"));
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.personNames." + personId).value("Max Mustermann"));
}
@Test
@WithMockUser(authorities = "ADMIN")
void getTrainingInfo_serializes_null_lastRun_as_json_null() throws Exception {
OcrTrainingService.TrainingInfoResponse info =
new OcrTrainingService.TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of());
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
mockMvc.perform(get("/api/ocr/training-info"))
.andExpect(status().isOk())
.andExpect(jsonPath("$.lastRun").doesNotExist());
}
@Test @Test
@WithMockUser(authorities = "READ_ALL") @WithMockUser(authorities = "READ_ALL")
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception { void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {

View File

@@ -217,8 +217,71 @@ class MigrationIntegrationTest {
assertThat(stored).isEqualTo("sage"); assertThat(stored).isEqualTo("sage");
} }
// ─── V42: idx_training_runs_queued_per_person partial unique index ────────
@Test
@Transactional(propagation = Propagation.NOT_SUPPORTED)
void v42_partialUniqueIndex_preventsTwoQueuedRunsForSamePerson() {
UUID personId = createPerson("V42Test", "SamePerson");
jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_x', ?)
""", personId);
assertThatThrownBy(() ->
jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_x', ?)
""", personId)
).isInstanceOf(DataIntegrityViolationException.class);
jdbc.update("DELETE FROM ocr_training_runs WHERE person_id = ?", personId);
jdbc.update("DELETE FROM persons WHERE id = ?", personId);
}
@Test
void v42_partialUniqueIndex_allowsQueuedRunsForDifferentPersons() {
UUID person1 = createPerson("V42Test", "PersonA");
UUID person2 = createPerson("V42Test", "PersonB");
int rows1 = jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_a', ?)
""", person1);
int rows2 = jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_b', ?)
""", person2);
assertThat(rows1).isEqualTo(1);
assertThat(rows2).isEqualTo(1);
}
@Test
void v42_partialUniqueIndex_allowsMultipleDoneRunsForSamePerson() {
UUID personId = createPerson("V42Test", "DonePerson");
int rows1 = jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'DONE', 10, 2, 'sender_x', ?)
""", personId);
int rows2 = jdbc.update("""
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
VALUES (gen_random_uuid(), 'DONE', 15, 3, 'sender_x', ?)
""", personId);
assertThat(rows1).isEqualTo(1);
assertThat(rows2).isEqualTo(1);
}
// ─── helpers ───────────────────────────────────────────────────────────── // ─── helpers ─────────────────────────────────────────────────────────────
private UUID createPerson(String firstName, String lastName) {
UUID id = UUID.randomUUID();
jdbc.update("INSERT INTO persons (id, first_name, last_name) VALUES (?, ?, ?)", id, firstName, lastName);
return id;
}
private UUID createTag(String name) { private UUID createTag(String name) {
UUID id = UUID.randomUUID(); UUID id = UUID.randomUUID();
jdbc.update("INSERT INTO tag (id, name) VALUES (?, ?)", id, name); jdbc.update("INSERT INTO tag (id, name) VALUES (?, ?)", id, name);

View File

@@ -31,6 +31,7 @@ class OcrAsyncRunnerTest {
@Mock OcrJobRepository ocrJobRepository; @Mock OcrJobRepository ocrJobRepository;
@Mock OcrJobDocumentRepository ocrJobDocumentRepository; @Mock OcrJobDocumentRepository ocrJobDocumentRepository;
@Mock OcrProgressService ocrProgressService; @Mock OcrProgressService ocrProgressService;
@Mock SenderModelService senderModelService;
@InjectMocks OcrAsyncRunner ocrAsyncRunner; @InjectMocks OcrAsyncRunner ocrAsyncRunner;
@@ -42,7 +43,12 @@ class OcrAsyncRunnerTest {
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build(); .fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of()); doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(0));
handler.accept(new OcrStreamEvent.Done(0, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.processDocument(docId, doc, userId); ocrAsyncRunner.processDocument(docId, doc, userId);
@@ -59,9 +65,15 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of( doAnswer(inv -> {
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null), Consumer<OcrStreamEvent> handler = inv.getArgument(4);
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null))); handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null))));
handler.accept(new OcrStreamEvent.Done(2, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build(); DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -83,8 +95,14 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of( doAnswer(inv -> {
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null))); Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null))));
handler.accept(new OcrStreamEvent.Done(1, 0));
return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build(); DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann); when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
@@ -112,12 +130,12 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> { doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3); Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(1)); handler.accept(new OcrStreamEvent.Start(1));
handler.accept(new OcrStreamEvent.Page(0, List.of())); handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0)); handler.accept(new OcrStreamEvent.Done(0, 0));
return null; return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any()); }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId); ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -142,7 +160,7 @@ class OcrAsyncRunnerTest {
when(documentService.getDocumentById(docId)).thenReturn(doc); when(documentService.getDocumentById(docId)).thenReturn(doc);
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any()); doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId); ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -174,7 +192,7 @@ class OcrAsyncRunnerTest {
List<String> progressMessages = new ArrayList<>(); List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> { doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3); Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(3)); handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of( handler.accept(new OcrStreamEvent.Page(0, List.of(
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null), new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
@@ -185,7 +203,7 @@ class OcrAsyncRunnerTest {
progressMessages.add(job.getProgressMessage()); progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Done(3, 0)); handler.accept(new OcrStreamEvent.Done(3, 0));
return null; return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any()); }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId); ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -215,14 +233,14 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> { doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3); Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(3)); handler.accept(new OcrStreamEvent.Start(3));
handler.accept(new OcrStreamEvent.Page(0, List.of())); handler.accept(new OcrStreamEvent.Page(0, List.of()));
handler.accept(new OcrStreamEvent.Error(1, "failed")); handler.accept(new OcrStreamEvent.Error(1, "failed"));
handler.accept(new OcrStreamEvent.Page(2, List.of())); handler.accept(new OcrStreamEvent.Page(2, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1)); handler.accept(new OcrStreamEvent.Done(0, 1));
return null; return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any()); }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId); ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -251,14 +269,14 @@ class OcrAsyncRunnerTest {
List<String> progressMessages = new ArrayList<>(); List<String> progressMessages = new ArrayList<>();
doAnswer(inv -> { doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3); Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(5)); handler.accept(new OcrStreamEvent.Start(5));
handler.accept(new OcrStreamEvent.Preprocessing(1)); handler.accept(new OcrStreamEvent.Preprocessing(1));
progressMessages.add(job.getProgressMessage()); progressMessages.add(job.getProgressMessage());
handler.accept(new OcrStreamEvent.Page(1, List.of())); handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 0)); handler.accept(new OcrStreamEvent.Done(0, 0));
return null; return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any()); }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId); ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
@@ -287,13 +305,13 @@ class OcrAsyncRunnerTest {
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned"); when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
doAnswer(inv -> { doAnswer(inv -> {
Consumer<OcrStreamEvent> handler = inv.getArgument(3); Consumer<OcrStreamEvent> handler = inv.getArgument(4);
handler.accept(new OcrStreamEvent.Start(2)); handler.accept(new OcrStreamEvent.Start(2));
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details")); handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
handler.accept(new OcrStreamEvent.Page(1, List.of())); handler.accept(new OcrStreamEvent.Page(1, List.of()));
handler.accept(new OcrStreamEvent.Done(0, 1)); handler.accept(new OcrStreamEvent.Done(0, 1));
return null; return null;
}).when(ocrClient).streamBlocks(any(), any(), any(), any()); }).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
ocrAsyncRunner.runSingleDocument(jobId, docId, userId); ocrAsyncRunner.runSingleDocument(jobId, docId, userId);

View File

@@ -21,6 +21,7 @@ class OcrClientDefaultStreamTest {
} }
@Override public TrainingResult trainModel(byte[] zip) { return null; } @Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; } @Override public TrainingResult segtrainModel(byte[] zip) { return null; }
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
}; };
List<OcrStreamEvent> events = new ArrayList<>(); List<OcrStreamEvent> events = new ArrayList<>();
@@ -52,6 +53,7 @@ class OcrClientDefaultStreamTest {
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); } @Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
@Override public TrainingResult trainModel(byte[] zip) { return null; } @Override public TrainingResult trainModel(byte[] zip) { return null; }
@Override public TrainingResult segtrainModel(byte[] zip) { return null; } @Override public TrainingResult segtrainModel(byte[] zip) { return null; }
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
}; };
List<OcrStreamEvent> events = new ArrayList<>(); List<OcrStreamEvent> events = new ArrayList<>();

View File

@@ -4,11 +4,12 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.raddatz.familienarchiv.exception.DomainException; import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.OcrTrainingRun; import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.TrainingLabel; import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.TrainingStatus; import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository; import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository; import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.raddatz.familienarchiv.service.PersonService;
import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate; import org.springframework.transaction.support.TransactionTemplate;
@@ -32,6 +33,7 @@ class OcrTrainingServiceTest {
OcrHealthClient healthClient; OcrHealthClient healthClient;
TranscriptionBlockRepository blockRepository; TranscriptionBlockRepository blockRepository;
TransactionTemplate txTemplate; TransactionTemplate txTemplate;
PersonService personService;
OcrTrainingService service; OcrTrainingService service;
@BeforeEach @BeforeEach
@@ -43,6 +45,7 @@ class OcrTrainingServiceTest {
healthClient = mock(OcrHealthClient.class); healthClient = mock(OcrHealthClient.class);
blockRepository = mock(TranscriptionBlockRepository.class); blockRepository = mock(TranscriptionBlockRepository.class);
txTemplate = mock(TransactionTemplate.class); txTemplate = mock(TransactionTemplate.class);
personService = mock(PersonService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource // Execute transaction callbacks inline so unit tests run without a real DataSource
when(txTemplate.execute(any())).thenAnswer(inv -> { when(txTemplate.execute(any())).thenAnswer(inv -> {
@@ -50,10 +53,10 @@ class OcrTrainingServiceTest {
return callback.doInTransaction(null); return callback.doInTransaction(null);
}); });
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate); service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService);
when(blockRepository.count()).thenReturn(0L); when(blockRepository.count()).thenReturn(0L);
when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of()); when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
when(segExportService.querySegmentationBlocks()).thenReturn(List.of()); when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
} }
@@ -230,6 +233,37 @@ class OcrTrainingServiceTest {
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null)); run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
} }
// ─── getTrainingInfo: batch person name resolution ────────────────────────
@Test
void getTrainingInfo_resolves_person_names_in_single_batch_call() {
UUID personA = UUID.randomUUID();
UUID personB = UUID.randomUUID();
List<OcrTrainingRun> runs = List.of(
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build(),
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personB).blockCount(5).documentCount(1).modelName("sender_b").build(),
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
.personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build()
);
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(runs);
when(exportService.queryEligibleBlocks()).thenReturn(List.of());
Person pa = Person.builder().id(personA).firstName("Anna").lastName("Müller").build();
Person pb = Person.builder().id(personB).firstName("Karl").lastName("Bauer").build();
when(personService.getAllById(argThat(ids -> ids.size() == 2 && ids.containsAll(List.of(personA, personB)))))
.thenReturn(List.of(pa, pb));
when(healthClient.isHealthy()).thenReturn(true);
OcrTrainingService.TrainingInfoResponse info = service.getTrainingInfo();
verify(personService, never()).getById(any());
verify(personService, times(1)).getAllById(any());
assertThat(info.personNames()).containsKey(personA.toString());
assertThat(info.personNames()).containsKey(personB.toString());
}
// ─── Orphan recovery ────────────────────────────────────────────────────── // ─── Orphan recovery ──────────────────────────────────────────────────────
@Test @Test

View File

@@ -0,0 +1,41 @@
package org.raddatz.familienarchiv.service;
import com.sun.net.httpserver.HttpServer;
import org.junit.jupiter.api.Test;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import static org.assertj.core.api.Assertions.assertThat;
class RestClientOcrClientTrainingTest {
@Test
void trainSenderModel_postsToTrainSenderEndpoint() throws Exception {
List<String> capturedPaths = new CopyOnWriteArrayList<>();
HttpServer server = HttpServer.create(new InetSocketAddress(0), 0);
server.createContext("/", exchange -> {
capturedPaths.add(exchange.getRequestURI().getPath());
byte[] response = "{\"loss\":0.1,\"accuracy\":0.95,\"cer\":0.05,\"epochs\":3}"
.getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().set("Content-Type", "application/json");
exchange.sendResponseHeaders(200, response.length);
try (OutputStream os = exchange.getResponseBody()) {
os.write(response);
}
});
server.start();
try {
int port = server.getAddress().getPort();
RestClientOcrClient client = new RestClientOcrClient("http://localhost:" + port, "");
client.trainSenderModel(new byte[]{1, 2, 3}, "/app/models/sender_test.mlmodel");
assertThat(capturedPaths).contains("/train-sender");
} finally {
server.stop(0);
}
}
}

View File

@@ -0,0 +1,302 @@
package org.raddatz.familienarchiv.service;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import org.raddatz.familienarchiv.model.OcrTrainingRun;
import org.raddatz.familienarchiv.model.SenderModel;
import org.raddatz.familienarchiv.model.TrainingStatus;
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
import org.raddatz.familienarchiv.repository.SenderModelRepository;
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class SenderModelServiceTest {
SenderModelRepository senderModelRepository;
TranscriptionBlockRepository blockRepository;
OcrTrainingRunRepository trainingRunRepository;
OcrClient ocrClient;
TransactionTemplate txTemplate;
TrainingDataExportService trainingDataExportService;
SenderModelService service;
UUID personId = UUID.randomUUID();
@BeforeEach
void setUp() {
senderModelRepository = mock(SenderModelRepository.class);
blockRepository = mock(TranscriptionBlockRepository.class);
trainingRunRepository = mock(OcrTrainingRunRepository.class);
ocrClient = mock(OcrClient.class);
txTemplate = mock(TransactionTemplate.class);
trainingDataExportService = mock(TrainingDataExportService.class);
// Execute transaction callbacks inline so unit tests run without a real DataSource.
// lenient: not every test hits the txTemplate path, but the setup is shared.
lenient().when(txTemplate.execute(any())).thenAnswer(inv -> {
TransactionCallback<?> callback = inv.getArgument(0);
return callback.doInTransaction(null);
});
service = new SenderModelService(senderModelRepository, blockRepository,
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
ReflectionTestUtils.setField(service, "activationThreshold", 100);
ReflectionTestUtils.setField(service, "retrainDelta", 50);
}
// ─── Activation threshold ─────────────────────────────────────────────────
@Test
void checkAndTriggerTraining_doesNothing_belowActivationThreshold() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
SenderModelService spy = spy(service);
spy.checkAndTriggerTraining(personId);
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
spy.checkAndTriggerTraining(personId);
verify(spy).runOrQueueSenderTraining(personId, 100);
}
// ─── Retrain delta ────────────────────────────────────────────────────────
@Test
void checkAndTriggerTraining_doesNothing_belowRetrainDelta() {
SenderModel existing = SenderModel.builder().personId(personId)
.correctedLinesAtTraining(100).build();
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
SenderModelService spy = spy(service);
spy.checkAndTriggerTraining(personId);
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
}
@Test
void checkAndTriggerTraining_triggersRetrain_atRetrainDelta() {
SenderModel existing = SenderModel.builder().personId(personId)
.correctedLinesAtTraining(100).build();
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150);
spy.checkAndTriggerTraining(personId);
verify(spy).runOrQueueSenderTraining(personId, 150);
}
// ─── runNow flag ──────────────────────────────────────────────────────────
@Test
void checkAndTriggerTraining_callsTrigger_whenRunNow() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
SenderModelService spy = spy(service);
doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100);
doNothing().when(spy).triggerSenderTraining(personId, 100);
spy.checkAndTriggerTraining(personId);
verify(spy).triggerSenderTraining(personId, 100);
}
@Test
void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
SenderModelService spy = spy(service);
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
spy.checkAndTriggerTraining(personId);
verify(spy, never()).triggerSenderTraining(any(), anyInt());
}
// ─── runOrQueueSenderTraining ─────────────────────────────────────────────
@Test
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isTrue();
verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
}
@Test
void runOrQueueSenderTraining_savesQueued_whenRunning() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(trainingRunRepository).save(argThat(r ->
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
}
@Test
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isFalse();
verify(trainingRunRepository, never()).save(any());
}
@Test
void runOrQueueSenderTraining_createsRunningRow_atomically() {
// RUNNING row must be created inside the same @Transactional boundary as the idle check,
// eliminating the race window between the check and a separate triggerSenderTraining call.
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
OcrTrainingRun r = inv.getArgument(0);
if (r.getId() == null) r.setId(UUID.randomUUID());
return r;
});
boolean result = service.runOrQueueSenderTraining(personId, 120);
assertThat(result).isTrue();
verify(trainingRunRepository).save(argThat(r ->
r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId())));
}
// ─── triggerSenderTraining — happy path ──────────────────────────────────
@Test
void triggerSenderTraining_savesModelRecord_onSuccess() throws Exception {
OcrTrainingRun runningRun = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.personId(personId).blockCount(10).documentCount(0)
.modelName("sender_" + personId).build();
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
.thenReturn(Optional.of(runningRun));
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
when(ocrClient.trainSenderModel(any(), any()))
.thenReturn(new OcrClient.TrainingResult(0.08, 0.92, 0.08, 5));
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
when(senderModelRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
.thenReturn(Optional.empty());
service.triggerSenderTraining(personId, 120);
verify(senderModelRepository).save(argThat(m ->
("/app/models/sender_" + personId + ".mlmodel").equals(m.getModelPath())
&& Double.valueOf(0.08).equals(m.getCer())
&& personId.equals(m.getPersonId())));
verify(trainingRunRepository, atLeastOnce()).save(argThat(r ->
r.getStatus() == TrainingStatus.DONE
&& Double.valueOf(0.08).equals(r.getCer())
&& r.getCompletedAt() != null));
}
// ─── triggerSenderTraining — failure path ────────────────────────────────
@Test
void triggerSenderTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
OcrTrainingRun runningRun = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.personId(personId).blockCount(10).documentCount(0)
.modelName("sender_" + personId).build();
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
.thenReturn(Optional.of(runningRun));
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
when(ocrClient.trainSenderModel(any(), any()))
.thenThrow(new RuntimeException("OCR service timeout"));
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
.thenReturn(Optional.empty());
service.triggerSenderTraining(personId, 120);
verify(trainingRunRepository, atLeastOnce()).save(argThat(r ->
r.getStatus() == TrainingStatus.FAILED
&& r.getErrorMessage() != null
&& r.getCompletedAt() != null));
verify(senderModelRepository, never()).save(any());
}
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
@Test
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
UUID nextPersonId = UUID.randomUUID();
OcrTrainingRun runningRun = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
.personId(personId).blockCount(10).documentCount(0)
.modelName("sender_" + personId).build();
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
.thenReturn(Optional.of(runningRun));
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
OcrTrainingRun r = inv.getArgument(0);
if (r.getId() == null) r.setId(UUID.randomUUID());
return r;
});
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
when(ocrClient.trainSenderModel(any(), any()))
.thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
OcrTrainingRun queued = OcrTrainingRun.builder()
.id(UUID.randomUUID()).status(TrainingStatus.QUEUED)
.personId(nextPersonId).blockCount(5).documentCount(0)
.modelName("sender_" + nextPersonId).build();
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
.thenReturn(Optional.of(queued));
when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L);
SenderModelService spy = spy(service);
// Stub the recursive call to stop the chain after one promotion
doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
spy.triggerSenderTraining(personId, 10);
// Queued run promoted to RUNNING
verify(trainingRunRepository).save(argThat(r ->
nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING));
// Next training was kicked off
verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
}
}

View File

@@ -22,6 +22,7 @@ class TranscriptionServiceGuidedTest {
AnnotationRepository annotationRepository; AnnotationRepository annotationRepository;
AnnotationService annotationService; AnnotationService annotationService;
DocumentService documentService; DocumentService documentService;
SenderModelService senderModelService;
TranscriptionService service; TranscriptionService service;
UUID docId = UUID.randomUUID(); UUID docId = UUID.randomUUID();
@@ -35,9 +36,10 @@ class TranscriptionServiceGuidedTest {
annotationRepository = mock(AnnotationRepository.class); annotationRepository = mock(AnnotationRepository.class);
annotationService = mock(AnnotationService.class); annotationService = mock(AnnotationService.class);
documentService = mock(DocumentService.class); documentService = mock(DocumentService.class);
senderModelService = mock(SenderModelService.class);
service = new TranscriptionService(blockRepository, versionRepository, service = new TranscriptionService(blockRepository, versionRepository,
annotationRepository, annotationService, documentService); annotationRepository, annotationService, documentService, senderModelService);
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));

View File

@@ -13,6 +13,8 @@ import org.raddatz.familienarchiv.exception.DomainException;
import org.raddatz.familienarchiv.model.BlockSource; import org.raddatz.familienarchiv.model.BlockSource;
import org.raddatz.familienarchiv.model.Document; import org.raddatz.familienarchiv.model.Document;
import org.raddatz.familienarchiv.model.DocumentAnnotation; import org.raddatz.familienarchiv.model.DocumentAnnotation;
import org.raddatz.familienarchiv.model.Person;
import org.raddatz.familienarchiv.model.ScriptType;
import org.raddatz.familienarchiv.model.TranscriptionBlock; import org.raddatz.familienarchiv.model.TranscriptionBlock;
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion; import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
import org.raddatz.familienarchiv.repository.AnnotationRepository; import org.raddatz.familienarchiv.repository.AnnotationRepository;
@@ -39,6 +41,7 @@ class TranscriptionServiceTest {
@Mock AnnotationRepository annotationRepository; @Mock AnnotationRepository annotationRepository;
@Mock AnnotationService annotationService; @Mock AnnotationService annotationService;
@Mock DocumentService documentService; @Mock DocumentService documentService;
@Mock SenderModelService senderModelService;
@InjectMocks TranscriptionService transcriptionService; @InjectMocks TranscriptionService transcriptionService;
// ─── getBlock ──────────────────────────────────────────────────────────────── // ─── getBlock ────────────────────────────────────────────────────────────────
@@ -156,6 +159,8 @@ class TranscriptionServiceTest {
.id(blockId).documentId(docId).text("old").build(); .id(blockId).documentId(docId).text("old").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("new text", null); UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("new text", null);
@@ -175,6 +180,8 @@ class TranscriptionServiceTest {
.id(blockId).documentId(docId).text("text").label("old label").build(); .id(blockId).documentId(docId).text("text").label("old label").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block)); when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0)); when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("text", "Anrede"); UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("text", "Anrede");
@@ -183,6 +190,57 @@ class TranscriptionServiceTest {
assertThat(result.getLabel()).isEqualTo("Anrede"); assertThat(result.getLabel()).isEqualTo("Anrede");
} }
@Test
void updateBlock_setsSourceToManual() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).text("old").source(BlockSource.OCR).build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
TranscriptionBlock result = transcriptionService.updateBlock(
docId, blockId, new UpdateTranscriptionBlockDTO("new", null), UUID.randomUUID());
assertThat(result.getSource()).isEqualTo(BlockSource.MANUAL);
}
@Test
void updateBlock_triggersTraining_whenKurrentSenderPresent() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
UUID senderId = UUID.randomUUID();
Person sender = Person.builder().id(senderId).build();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).text("old").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).sender(sender).build());
transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());
verify(senderModelService).checkAndTriggerTraining(senderId);
}
@Test
void updateBlock_doesNotTriggerTraining_whenNoSender() {
UUID docId = UUID.randomUUID();
UUID blockId = UUID.randomUUID();
TranscriptionBlock block = TranscriptionBlock.builder()
.id(blockId).documentId(docId).text("old").build();
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
when(documentService.getDocumentById(any())).thenReturn(
Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).build());
transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());
verify(senderModelService, never()).checkAndTriggerTraining(any());
}
// ─── deleteBlock ───────────────────────────────────────────────────────────── // ─── deleteBlock ─────────────────────────────────────────────────────────────
@Test @Test

View File

@@ -97,6 +97,7 @@ services:
DETECTOR_BATCH_SIZE: "8" DETECTOR_BATCH_SIZE: "8"
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency) OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page) OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
OCR_MAX_CACHED_MODELS: "2" # LRU cache; each model ~500 MB, so 2 = ~1 GB resident
networks: networks:
- archive-net - archive-net
healthcheck: healthcheck:

View File

@@ -330,6 +330,7 @@
"comment_time_days": "vor {count} Tag(en)", "comment_time_days": "vor {count} Tag(en)",
"comment_panel_title": "Kommentare", "comment_panel_title": "Kommentare",
"comment_panel_close": "Schließen", "comment_panel_close": "Schließen",
"comp_dismiss": "Schließen",
"doc_panel_tab_metadata": "Metadaten", "doc_panel_tab_metadata": "Metadaten",
"doc_panel_tab_transcription": "Transkription", "doc_panel_tab_transcription": "Transkription",
"doc_panel_tab_discussion": "Diskussion", "doc_panel_tab_discussion": "Diskussion",
@@ -541,6 +542,7 @@
"training_start_btn": "Training starten", "training_start_btn": "Training starten",
"training_in_progress": "…", "training_in_progress": "…",
"training_success": "Training wurde gestartet und abgeschlossen.", "training_success": "Training wurde gestartet und abgeschlossen.",
"training_start_failed": "Training konnte nicht gestartet werden. Bitte versuche es erneut.",
"training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).", "training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).",
"training_service_down": "OCR-Dienst ist nicht erreichbar.", "training_service_down": "OCR-Dienst ist nicht erreichbar.",
"training_history_heading": "Verlauf", "training_history_heading": "Verlauf",
@@ -561,6 +563,11 @@
"transcription_block_segmentation_only": "Nur Segmentierung", "transcription_block_segmentation_only": "Nur Segmentierung",
"training_chip_kurrent": "Kurrent-Erkennung", "training_chip_kurrent": "Kurrent-Erkennung",
"training_chip_segmentation": "Segmentierung", "training_chip_segmentation": "Segmentierung",
"training_col_type": "Typ",
"training_type_base": "Basis",
"training_type_personalized": "Personalisiert",
"training_col_person": "Absender",
"training_status_queued": "Warteschlange",
"mission_control_heading": "Was braucht Aufmerksamkeit?", "mission_control_heading": "Was braucht Aufmerksamkeit?",
"mission_control_segmentation_heading": "Text markieren", "mission_control_segmentation_heading": "Text markieren",
"mission_control_segmentation_description": "Textbereiche markieren — keine Vorkenntnisse nötig", "mission_control_segmentation_description": "Textbereiche markieren — keine Vorkenntnisse nötig",

View File

@@ -330,6 +330,7 @@
"comment_time_days": "{count} day(s) ago", "comment_time_days": "{count} day(s) ago",
"comment_panel_title": "Comments", "comment_panel_title": "Comments",
"comment_panel_close": "Close", "comment_panel_close": "Close",
"comp_dismiss": "Dismiss",
"doc_panel_tab_metadata": "Metadata", "doc_panel_tab_metadata": "Metadata",
"doc_panel_tab_transcription": "Transcription", "doc_panel_tab_transcription": "Transcription",
"doc_panel_tab_discussion": "Discussion", "doc_panel_tab_discussion": "Discussion",
@@ -541,6 +542,7 @@
"training_start_btn": "Start training", "training_start_btn": "Start training",
"training_in_progress": "…", "training_in_progress": "…",
"training_success": "Training started and completed.", "training_success": "Training started and completed.",
"training_start_failed": "Failed to start training. Please try again.",
"training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).", "training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).",
"training_service_down": "OCR service is unavailable.", "training_service_down": "OCR service is unavailable.",
"training_history_heading": "History", "training_history_heading": "History",
@@ -561,6 +563,11 @@
"transcription_block_segmentation_only": "Segmentation only", "transcription_block_segmentation_only": "Segmentation only",
"training_chip_kurrent": "Kurrent recognition", "training_chip_kurrent": "Kurrent recognition",
"training_chip_segmentation": "Segmentation", "training_chip_segmentation": "Segmentation",
"training_col_type": "Type",
"training_type_base": "Base",
"training_type_personalized": "Personalized",
"training_col_person": "Sender",
"training_status_queued": "Queued",
"mission_control_heading": "What needs attention?", "mission_control_heading": "What needs attention?",
"mission_control_segmentation_heading": "Mark text", "mission_control_segmentation_heading": "Mark text",
"mission_control_segmentation_description": "Mark text areas — no prior knowledge needed", "mission_control_segmentation_description": "Mark text areas — no prior knowledge needed",

View File

@@ -330,6 +330,7 @@
"comment_time_days": "hace {count} día(s)", "comment_time_days": "hace {count} día(s)",
"comment_panel_title": "Comentarios", "comment_panel_title": "Comentarios",
"comment_panel_close": "Cerrar", "comment_panel_close": "Cerrar",
"comp_dismiss": "Cerrar",
"doc_panel_tab_metadata": "Metadatos", "doc_panel_tab_metadata": "Metadatos",
"doc_panel_tab_transcription": "Transcripción", "doc_panel_tab_transcription": "Transcripción",
"doc_panel_tab_discussion": "Discusión", "doc_panel_tab_discussion": "Discusión",
@@ -541,6 +542,7 @@
"training_start_btn": "Iniciar entrenamiento", "training_start_btn": "Iniciar entrenamiento",
"training_in_progress": "…", "training_in_progress": "…",
"training_success": "Entrenamiento iniciado y completado.", "training_success": "Entrenamiento iniciado y completado.",
"training_start_failed": "No se pudo iniciar el entrenamiento. Por favor, inténtalo de nuevo.",
"training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).", "training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).",
"training_service_down": "El servicio OCR no está disponible.", "training_service_down": "El servicio OCR no está disponible.",
"training_history_heading": "Historial", "training_history_heading": "Historial",
@@ -561,6 +563,11 @@
"transcription_block_segmentation_only": "Solo segmentación", "transcription_block_segmentation_only": "Solo segmentación",
"training_chip_kurrent": "Reconocimiento Kurrent", "training_chip_kurrent": "Reconocimiento Kurrent",
"training_chip_segmentation": "Segmentación", "training_chip_segmentation": "Segmentación",
"training_col_type": "Tipo",
"training_type_base": "Base",
"training_type_personalized": "Personalizado",
"training_col_person": "Remitente",
"training_status_queued": "En cola",
"mission_control_heading": "¿Qué necesita atención?", "mission_control_heading": "¿Qué necesita atención?",
"mission_control_segmentation_heading": "Marcar texto", "mission_control_segmentation_heading": "Marcar texto",
"mission_control_segmentation_description": "Marcar áreas de texto — sin conocimientos previos", "mission_control_segmentation_description": "Marcar áreas de texto — sin conocimientos previos",

View File

@@ -1,26 +1,16 @@
<script lang="ts"> <script lang="ts">
import TrainingHistory from './TrainingHistory.svelte'; import TrainingHistory from './TrainingHistory.svelte';
import { m } from '$lib/paraglide/messages.js'; import { m } from '$lib/paraglide/messages.js';
import type { TrainingRun } from '$lib/types/training.js';
interface Run {
id: string;
status: 'RUNNING' | 'DONE' | 'FAILED';
blockCount: number;
documentCount: number;
modelName: string;
errorMessage?: string;
triggeredBy?: string;
createdAt: string;
completedAt?: string;
}
interface TrainingInfo { interface TrainingInfo {
availableBlocks?: number; availableBlocks?: number;
totalOcrBlocks?: number; totalOcrBlocks?: number;
availableDocuments?: number; availableDocuments?: number;
ocrServiceAvailable?: boolean; ocrServiceAvailable?: boolean;
lastRun?: Run | null; lastRun?: TrainingRun | null;
runs?: Run[]; runs?: TrainingRun[];
personNames?: Record<string, string>;
} }
interface Props { interface Props {
@@ -31,6 +21,7 @@ let { trainingInfo }: Props = $props();
let training = $state(false); let training = $state(false);
let successMessage = $state<string | null>(null); let successMessage = $state<string | null>(null);
let errorMessage = $state<string | null>(null);
const available = $derived(trainingInfo?.availableBlocks ?? 0); const available = $derived(trainingInfo?.availableBlocks ?? 0);
const tooFewBlocks = $derived(available < 5); const tooFewBlocks = $derived(available < 5);
@@ -40,6 +31,7 @@ const disabled = $derived(training || tooFewBlocks || serviceDown);
async function startTraining() { async function startTraining() {
training = true; training = true;
successMessage = null; successMessage = null;
errorMessage = null;
try { try {
const res = await fetch('/api/ocr/train', { method: 'POST' }); const res = await fetch('/api/ocr/train', { method: 'POST' });
if (res.ok) { if (res.ok) {
@@ -47,7 +39,11 @@ async function startTraining() {
setTimeout(() => { setTimeout(() => {
successMessage = null; successMessage = null;
}, 5000); }, 5000);
} else {
errorMessage = m.training_start_failed();
} }
} catch {
errorMessage = m.training_start_failed();
} finally { } finally {
training = false; training = false;
} }
@@ -82,11 +78,26 @@ async function startTraining() {
{/if} {/if}
{#if successMessage} {#if successMessage}
<p class="mt-2 text-xs text-green-700">{successMessage}</p> <p class="mt-2 flex items-center gap-2 text-xs text-green-700" aria-live="polite">
{successMessage}
<button
type="button"
class="-my-2 ml-1 inline-flex h-11 w-11 items-center justify-center rounded-sm underline hover:no-underline focus-visible:ring-2 focus-visible:ring-brand-navy focus-visible:outline-none"
onclick={() => (successMessage = null)}
aria-label={m.comp_dismiss()}>×</button
>
</p>
{/if}
{#if errorMessage}
<p class="mt-2 text-xs text-red-600" aria-live="assertive">{errorMessage}</p>
{/if} {/if}
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase"> <h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
{m.training_history_heading()} {m.training_history_heading()}
</h3> </h3>
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} /> <TrainingHistory
runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')}
personNames={trainingInfo?.personNames ?? {}}
/>
</div> </div>

View File

@@ -73,6 +73,24 @@ describe('OcrTrainingCard — enabled state', () => {
}); });
}); });
describe('OcrTrainingCard — success dismiss button', () => {
it('dismiss button has 44×44px touch target (h-11 w-11)', async () => {
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: true }));
render(OcrTrainingCard, { trainingInfo: baseInfo });
const btn = page.getByRole('button', { name: /Training starten/i });
await btn.click();
const dismissBtn = page.getByRole('button', { name: /Schließen/i });
await expect.element(dismissBtn).toBeInTheDocument();
const el = await dismissBtn.element();
expect(el.classList.contains('h-11')).toBe(true);
expect(el.classList.contains('w-11')).toBe(true);
});
});
describe('OcrTrainingCard — in-flight state', () => { describe('OcrTrainingCard — in-flight state', () => {
it('shows "…" while POST is in-flight', async () => { it('shows "…" while POST is in-flight', async () => {
let resolveFetch!: (v: unknown) => void; let resolveFetch!: (v: unknown) => void;

View File

@@ -1,24 +1,13 @@
<script lang="ts"> <script lang="ts">
import { m } from '$lib/paraglide/messages.js'; import { m } from '$lib/paraglide/messages.js';
import type { TrainingRun } from '$lib/types/training.js';
interface Run {
id: string;
status: 'RUNNING' | 'DONE' | 'FAILED';
blockCount: number;
documentCount: number;
modelName: string;
errorMessage?: string;
triggeredBy?: string;
createdAt: string;
completedAt?: string;
cer?: number;
}
interface Props { interface Props {
runs: Run[]; runs: TrainingRun[];
personNames?: Record<string, string>;
} }
let { runs }: Props = $props(); let { runs, personNames }: Props = $props();
const COLLAPSED_COUNT = 3; const COLLAPSED_COUNT = 3;
let expanded = $state(false); let expanded = $state(false);
@@ -47,6 +36,8 @@ function formatCer(cer: number | undefined | null): string {
<tr class="border-b border-line text-xs font-bold tracking-widest text-ink-3 uppercase"> <tr class="border-b border-line text-xs font-bold tracking-widest text-ink-3 uppercase">
<th class="pb-2 text-left">{m.training_history_col_date()}</th> <th class="pb-2 text-left">{m.training_history_col_date()}</th>
<th class="pb-2 text-left">{m.training_history_col_status()}</th> <th class="pb-2 text-left">{m.training_history_col_status()}</th>
<th class="hidden pb-2 text-left md:table-cell">{m.training_col_type()}</th>
<th class="hidden pb-2 text-left md:table-cell">{m.training_col_person()}</th>
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th> <th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th> <th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th> <th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
@@ -55,7 +46,7 @@ function formatCer(cer: number | undefined | null): string {
<tbody id="training-history-rows"> <tbody id="training-history-rows">
{#if runs.length === 0} {#if runs.length === 0}
<tr> <tr>
<td colspan="5" class="py-4 text-center text-sm text-ink-2"> <td colspan="7" class="py-4 text-center text-sm text-ink-2">
{m.training_history_empty()} {m.training_history_empty()}
</td> </td>
</tr> </tr>
@@ -64,7 +55,14 @@ function formatCer(cer: number | undefined | null): string {
<tr class="border-b border-line/50 last:border-0"> <tr class="border-b border-line/50 last:border-0">
<td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td> <td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td>
<td class="py-2"> <td class="py-2">
{#if run.status === 'DONE'} {#if run.status === 'QUEUED'}
<span
class="inline-flex items-center gap-1 rounded-sm bg-amber-100 px-1.5 py-0.5 text-xs font-medium text-amber-700"
>
<span aria-hidden="true" class="h-1.5 w-1.5 rounded-full bg-amber-500"></span>
{m.training_status_queued()}
</span>
{:else if run.status === 'DONE'}
<span <span
class="inline-flex items-center gap-1 rounded-sm bg-green-100 px-1.5 py-0.5 text-xs font-medium text-green-700" class="inline-flex items-center gap-1 rounded-sm bg-green-100 px-1.5 py-0.5 text-xs font-medium text-green-700"
> >
@@ -119,6 +117,17 @@ function formatCer(cer: number | undefined | null): string {
{m.training_status_running()} {m.training_status_running()}
</span> </span>
{/if} {/if}
{#if run.personId && personNames?.[run.personId]}
<span class="mt-0.5 block text-xs text-ink-3 md:hidden"
>{personNames[run.personId]}</span
>
{/if}
</td>
<td class="hidden py-2 text-left text-ink-2 md:table-cell">
{run.personId ? m.training_type_personalized() : m.training_type_base()}
</td>
<td class="hidden py-2 text-left text-ink-2 md:table-cell">
{run.personId && personNames?.[run.personId] ? personNames[run.personId] : '—'}
</td> </td>
<td class="py-2 text-right text-ink-2">{run.blockCount}</td> <td class="py-2 text-right text-ink-2">{run.blockCount}</td>
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td> <td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
@@ -137,7 +146,7 @@ function formatCer(cer: number | undefined | null): string {
type="button" type="button"
aria-expanded={expanded} aria-expanded={expanded}
aria-controls="training-history-rows" aria-controls="training-history-rows"
class="text-xs font-medium text-ink-3 transition-colors hover:text-ink" class="rounded-sm py-2 text-xs font-medium text-ink-3 transition-colors hover:text-ink focus-visible:ring-2 focus-visible:ring-brand-navy focus-visible:ring-offset-1 focus-visible:outline-none"
onclick={() => (expanded = !expanded)} onclick={() => (expanded = !expanded)}
> >
{expanded ? m.comp_expandable_show_less() : m.comp_expandable_show_more()} {expanded ? m.comp_expandable_show_less() : m.comp_expandable_show_more()}

View File

@@ -50,3 +50,36 @@ describe('TrainingHistory — expand/collapse', () => {
.not.toBeInTheDocument(); .not.toBeInTheDocument();
}); });
}); });
describe('TrainingHistory — status badges', () => {
it('shows amber QUEUED badge for queued runs', async () => {
const run = { ...makeRun(0), status: 'QUEUED' as const };
render(TrainingHistory, { runs: [run] });
await expect.element(page.getByText(/Warteschlange/i)).toBeInTheDocument();
});
});
describe('TrainingHistory — type and person columns', () => {
it('shows "Basis" for runs without personId', async () => {
render(TrainingHistory, { runs: [makeRun(0)] });
await expect.element(page.getByText(/Basis/i)).toBeInTheDocument();
});
it('shows "Personalisiert" for runs with personId', async () => {
const run = { ...makeRun(0), personId: 'person-1' };
render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } });
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
});
it('shows person name from personNames for sender runs', async () => {
const run = { ...makeRun(0), personId: 'person-1' };
render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } });
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
// Name appears in both the mobile inline span and the desktop table cell
await expect.element(page.getByText('Karl Müller').first()).toBeInTheDocument();
});
});

View File

@@ -1374,7 +1374,7 @@ export interface components {
}; };
MergeTagDTO: { MergeTagDTO: {
/** Format: uuid */ /** Format: uuid */
targetId?: string; targetId: string;
}; };
PersonNameAliasDTO: { PersonNameAliasDTO: {
lastName: string; lastName: string;
@@ -1398,7 +1398,7 @@ export interface components {
/** Format: uuid */ /** Format: uuid */
id: string; id: string;
/** @enum {string} */ /** @enum {string} */
status: "RUNNING" | "DONE" | "FAILED"; status: "QUEUED" | "RUNNING" | "DONE" | "FAILED";
/** Format: int32 */ /** Format: int32 */
blockCount: number; blockCount: number;
/** Format: int32 */ /** Format: int32 */
@@ -1415,6 +1415,8 @@ export interface components {
errorMessage?: string; errorMessage?: string;
/** Format: uuid */ /** Format: uuid */
triggeredBy?: string; triggeredBy?: string;
/** Format: uuid */
personId?: string;
/** Format: date-time */ /** Format: date-time */
createdAt: string; createdAt: string;
/** Format: date-time */ /** Format: date-time */
@@ -1596,11 +1598,11 @@ export interface components {
}; };
TagTreeNodeDTO: { TagTreeNodeDTO: {
/** Format: uuid */ /** Format: uuid */
id?: string; id: string;
name?: string; name: string;
color?: string; color?: string;
/** Format: int32 */ /** Format: int32 */
documentCount?: number; documentCount: number;
children?: components["schemas"]["TagTreeNodeDTO"][]; children?: components["schemas"]["TagTreeNodeDTO"][];
/** /**
* Format: uuid * Format: uuid
@@ -1619,7 +1621,6 @@ export interface components {
/** Format: uuid */ /** Format: uuid */
id?: string; id?: string;
displayName?: string; displayName?: string;
personType?: string;
firstName?: string; firstName?: string;
lastName?: string; lastName?: string;
/** Format: int64 */ /** Format: int64 */
@@ -1630,19 +1631,7 @@ export interface components {
deathYear?: number; deathYear?: number;
alias?: string; alias?: string;
notes?: string; notes?: string;
}; personType?: string;
TrainingInfoResponse: {
/** Format: int32 */
availableBlocks?: number;
/** Format: int32 */
totalOcrBlocks?: number;
/** Format: int32 */
availableDocuments?: number;
/** Format: int32 */
availableSegBlocks?: number;
ocrServiceAvailable?: boolean;
lastRun?: components["schemas"]["OcrTrainingRun"];
runs?: components["schemas"]["OcrTrainingRun"][];
}; };
StreamingResponseBody: unknown; StreamingResponseBody: unknown;
OcrJob: { OcrJob: {
@@ -1671,13 +1660,11 @@ export interface components {
timeout?: number; timeout?: number;
}; };
PageNotificationDTO: { PageNotificationDTO: {
/** Format: int64 */
totalElements?: number;
/** Format: int32 */ /** Format: int32 */
totalPages?: number; totalPages?: number;
/** Format: int64 */
totalElements?: number;
pageable?: components["schemas"]["PageableObject"]; pageable?: components["schemas"]["PageableObject"];
first?: boolean;
last?: boolean;
/** Format: int32 */ /** Format: int32 */
size?: number; size?: number;
content?: components["schemas"]["NotificationDTO"][]; content?: components["schemas"]["NotificationDTO"][];
@@ -1686,6 +1673,8 @@ export interface components {
sort?: components["schemas"]["SortObject"]; sort?: components["schemas"]["SortObject"];
/** Format: int32 */ /** Format: int32 */
numberOfElements?: number; numberOfElements?: number;
first?: boolean;
last?: boolean;
empty?: boolean; empty?: boolean;
}; };
PageableObject: { PageableObject: {
@@ -3501,7 +3490,9 @@ export interface operations {
[name: string]: unknown; [name: string]: unknown;
}; };
content: { content: {
"*/*": components["schemas"]["TrainingInfoResponse"]; "*/*": {
[key: string]: unknown;
};
}; };
}; };
}; };

View File

@@ -0,0 +1,13 @@
export interface TrainingRun {
id: string;
status: 'QUEUED' | 'RUNNING' | 'DONE' | 'FAILED';
personId?: string;
blockCount: number;
documentCount: number;
modelName: string;
errorMessage?: string;
triggeredBy?: string;
createdAt: string;
completedAt?: string;
cer?: number;
}

View File

@@ -1,13 +1,84 @@
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin.""" """Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
from __future__ import annotations
import collections
import logging import logging
import os import os
import threading
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from PIL.Image import Image
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_model = None _model = None
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel") _model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
_MODELS_DIR = os.path.realpath("/app/models")
_MAX_CACHED_SENDER_MODELS = int(os.environ.get("OCR_MAX_CACHED_MODELS", "5"))
def _load_sender_model(path: str):
"""Load a Kraken model from disk. Extracted for testability."""
from kraken.lib import models as kraken_models
return kraken_models.load_any(path)
class _SenderModelRegistry:
"""Thread-safe LRU cache for per-sender Kraken models.
Uses double-checked locking: model loading happens outside the lock to
avoid blocking concurrent OCR requests. At most one entry per path is
stored even under concurrent load.
"""
def __init__(self, max_size: int):
self._max_size = max_size
self._cache: collections.OrderedDict = collections.OrderedDict()
self._lock = threading.Lock()
def get_model(self, model_path: str):
"""Return the cached model or load it. Validates path is within /app/models/."""
resolved = os.path.realpath(model_path)
if not resolved.startswith(_MODELS_DIR + os.sep) and resolved != _MODELS_DIR:
raise ValueError(f"Sender model path not allowed: {model_path}")
with self._lock:
if model_path in self._cache:
self._cache.move_to_end(model_path)
return self._cache[model_path]
new_model = _load_sender_model(model_path)
with self._lock:
if model_path in self._cache:
self._cache.move_to_end(model_path)
return self._cache[model_path]
self._cache[model_path] = new_model
self._cache.move_to_end(model_path)
while len(self._cache) > self._max_size:
self._cache.popitem(last=False)
return new_model
def invalidate(self, model_path: str) -> None:
"""Remove model from cache so the next request reloads from disk."""
with self._lock:
self._cache.pop(model_path, None)
def size(self) -> int:
with self._lock:
return len(self._cache)
def _contains(self, model_path: str) -> bool:
with self._lock:
return model_path in self._cache
_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS)
def load_models(): def load_models():
"""Load the Kraken model at startup. Skips if model file is not present.""" """Load the Kraken model at startup. Skips if model file is not present."""
@@ -29,10 +100,12 @@ def is_available() -> bool:
return _model is not None return _model is not None
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]: def extract_page_blocks(image: Image, page_idx: int, language: str = "de",
sender_model_path: str | None = None) -> list[dict]:
"""Run Kraken segmentation + recognition on a single PIL image. """Run Kraken segmentation + recognition on a single PIL image.
Returns block dicts for that page. Coordinates are normalized to [0, 1]. Returns block dicts for that page. Coordinates are normalized to [0, 1].
When sender_model_path is provided, the per-sender fine-tuned model is used.
""" """
from kraken import blla, rpred from kraken import blla, rpred
from confidence import words_from_characters from confidence import words_from_characters
@@ -40,11 +113,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
if _model is None: if _model is None:
raise RuntimeError("Kraken model is not loaded") raise RuntimeError("Kraken model is not loaded")
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
page_w, page_h = image.size page_w, page_h = image.size
blocks = [] blocks = []
baseline_seg = blla.segment(image) baseline_seg = blla.segment(image)
pred_it = rpred.rpred(_model, image, baseline_seg) pred_it = rpred.rpred(active_model, image, baseline_seg)
for record in pred_it: for record in pred_it:
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else [] polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
@@ -79,13 +154,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
return blocks return blocks
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str: def extract_region_text(image: Image, x: float, y: float, w: float, h: float,
sender_model_path: str | None = None) -> str:
"""Crop image to a normalized region and run Kraken recognition on the crop. """Crop image to a normalized region and run Kraken recognition on the crop.
Used for guided OCR — skips full-page layout detection entirely. Used for guided OCR — skips full-page layout detection entirely.
A single synthetic baseline spanning the full crop width is used so that A single synthetic baseline spanning the full crop width is used so that
blla.segment() (which crashes on small crops) is never called. blla.segment() (which crashes on small crops) is never called.
Coordinates are normalized to [0, 1]. Coordinates are normalized to [0, 1].
When sender_model_path is provided, the per-sender fine-tuned model is used.
""" """
from kraken import rpred from kraken import rpred
from kraken.containers import Segmentation, BaselineLine from kraken.containers import Segmentation, BaselineLine
@@ -93,6 +170,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
if _model is None: if _model is None:
raise RuntimeError("Kraken model is not loaded") raise RuntimeError("Kraken model is not loaded")
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
pw, ph = image.size pw, ph = image.size
x1 = max(0, int(x * pw)) x1 = max(0, int(x * pw))
y1 = max(0, int(y * ph)) y1 = max(0, int(y * ph))
@@ -123,11 +202,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
regions={}, regions={},
line_orders=[], line_orders=[],
) )
pred_it = rpred.rpred(_model, crop, synthetic_seg) pred_it = rpred.rpred(active_model, crop, synthetic_seg)
return " ".join(r.prediction for r in pred_it) return " ".join(r.prediction for r in pred_it)
def extract_blocks(images: list, language: str = "de") -> list[dict]: def extract_blocks(images: list, language: str = "de",
sender_model_path: str | None = None) -> list[dict]:
"""Run Kraken segmentation + recognition on a list of PIL images. """Run Kraken segmentation + recognition on a list of PIL images.
Returns block dicts with pageNumber, x, y, width, height, polygon, text. Returns block dicts with pageNumber, x, y, width, height, polygon, text.
@@ -137,7 +217,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]:
all_blocks = [] all_blocks = []
for page_idx, image in enumerate(images, start=1): for page_idx, image in enumerate(images, start=1):
all_blocks.extend(extract_page_blocks(image, page_idx, language)) all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path))
return all_blocks return all_blocks

View File

@@ -17,7 +17,7 @@ from urllib.parse import urlparse
import httpx import httpx
import pypdfium2 as pdfium import pypdfium2 as pdfium
from fastapi import FastAPI, Header, HTTPException, UploadFile from fastapi import FastAPI, Form, Header, HTTPException, UploadFile
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from PIL import Image from PIL import Image
@@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest):
status_code=400, status_code=400,
detail="Kraken model not available — cannot process Kurrent script", detail="Kraken model not available — cannot process Kurrent script",
) )
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language) blocks = await asyncio.to_thread(
kraken_engine.extract_blocks, images, request.language, request.senderModelPath)
else: else:
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya # TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language) blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
@@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest):
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n" yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image) image = await asyncio.to_thread(preprocess_page, image)
blocks = [] blocks = []
sender_path = request.senderModelPath if use_kraken else None
for region in page_regions: for region in page_regions:
text = await asyncio.to_thread( text = await asyncio.to_thread(
engine.extract_region_text, image, engine.extract_region_text, image,
region.x, region.y, region.width, region.height, region.x, region.y, region.width, region.height,
sender_path,
) )
if script_type in _SPELL_CHECK_SCRIPT_TYPES: if script_type in _SPELL_CHECK_SCRIPT_TYPES:
text = correct_text(text) text = correct_text(text)
@@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest):
try: try:
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n" yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
image = await asyncio.to_thread(preprocess_page, image) image = await asyncio.to_thread(preprocess_page, image)
sender_path = request.senderModelPath if use_kraken else None
blocks = await asyncio.to_thread( blocks = await asyncio.to_thread(
engine.extract_page_blocks, image, page_idx, request.language engine.extract_page_blocks, image, page_idx, request.language, sender_path
) )
for block in blocks: for block in blocks:
@@ -443,6 +447,86 @@ async def train_model(
return result return result
@app.post("/train-sender")
async def train_sender_model(
file: UploadFile,
output_model_path: str = Form(...),
x_training_token: str | None = Header(default=None),
):
"""Fine-tune a per-sender Kurrent model and save to output_model_path.
output_model_path must be within /app/models/. The per-sender model cache
is invalidated after training so the next OCR request picks up the new model.
"""
_check_training_token(x_training_token)
if not _models_ready:
raise HTTPException(status_code=503, detail="Models not loaded yet")
models_dir = os.path.realpath("/app/models")
resolved_output = os.path.realpath(output_model_path)
if not resolved_output.startswith(models_dir + os.sep):
raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")
zip_bytes = await file.read()
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
log.info("Starting sender training run %s%s", training_run_id, output_model_path)
def _run_sender_training() -> dict:
with tempfile.TemporaryDirectory() as tmp_dir:
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
for entry in zf.namelist():
_validate_zip_entry(entry, tmp_dir)
zf.extractall(tmp_dir)
ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
if not ground_truth:
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
log.info("Sender training on %d ground-truth pairs", len(ground_truth))
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
cmd = [
"ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
"train",
"-f", "page",
"--weights-format", "coreml",
"-o", checkpoint_dir,
"-q", "fixed",
"-N", "10",
"-B", "1",
]
if os.path.exists(KRAKEN_MODEL_PATH):
cmd += ["-i", KRAKEN_MODEL_PATH]
cmd += ground_truth
log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
proc = subprocess.run(cmd, capture_output=True, text=True)
if proc.returncode != 0:
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
best_model = _find_best_model(checkpoint_dir)
if best_model is None:
raise RuntimeError("Sender training produced no model file")
os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
shutil.copy2(best_model, resolved_output)
log.info("Saved sender model to %s", resolved_output)
kraken_engine._sender_registry.invalidate(output_model_path)
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
result = await asyncio.to_thread(_run_sender_training)
return result
@app.post("/segtrain") @app.post("/segtrain")
async def segtrain_model( async def segtrain_model(
file: UploadFile, file: UploadFile,

View File

@@ -19,6 +19,7 @@ class OcrRequest(BaseModel):
scriptType: str = "UNKNOWN" scriptType: str = "UNKNOWN"
language: str = "de" language: str = "de"
regions: list[OcrRegion] | None = None regions: list[OcrRegion] | None = None
senderModelPath: str | None = None
class OcrBlock(BaseModel): class OcrBlock(BaseModel):

View File

@@ -0,0 +1,75 @@
"""Tests for the per-sender model LRU registry in engines/kraken.py."""
from unittest.mock import MagicMock, call, patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_registry(max_size=5):
from engines.kraken import _SenderModelRegistry
return _SenderModelRegistry(max_size=max_size)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_cache_hit_returns_same_object():
"""Second get_model call with the same path must return the cached object."""
registry = _make_registry()
mock_model = MagicMock(name="model_a")
with patch("engines.kraken._load_sender_model", return_value=mock_model) as loader:
m1 = registry.get_model("/app/models/sender_a.mlmodel")
m2 = registry.get_model("/app/models/sender_a.mlmodel")
assert m1 is m2
loader.assert_called_once() # only loaded once despite two gets
def test_lru_eviction_removes_least_recently_used():
"""When the cache exceeds max_size, the oldest-accessed entry is evicted."""
registry = _make_registry(max_size=2)
def _side_effect(path):
return MagicMock(name=path)
with patch("engines.kraken._load_sender_model", side_effect=_side_effect):
registry.get_model("/app/models/sender_a.mlmodel")
registry.get_model("/app/models/sender_b.mlmodel")
registry.get_model("/app/models/sender_c.mlmodel") # should evict 'a'
assert registry.size() == 2
# 'a' was the least-recently-used and should be gone
assert not registry._contains("/app/models/sender_a.mlmodel")
def test_invalidate_removes_entry_from_cache():
"""invalidate() must evict the entry so the next get reloads from disk."""
registry = _make_registry()
mock_model = MagicMock(name="model_x")
with patch("engines.kraken._load_sender_model", return_value=mock_model):
registry.get_model("/app/models/sender_x.mlmodel")
assert registry.size() == 1
registry.invalidate("/app/models/sender_x.mlmodel")
assert registry.size() == 0
def test_path_outside_models_dir_raises():
"""get_model must reject paths outside /app/models/ (path traversal guard)."""
registry = _make_registry()
with pytest.raises(ValueError, match="not allowed"):
registry.get_model("/etc/passwd")
def test_load_failure_does_not_cache_broken_entry():
"""A failed load must not leave a broken entry in the cache."""
registry = _make_registry()
with patch("engines.kraken._load_sender_model", side_effect=RuntimeError("corrupt model")):
with pytest.raises(RuntimeError, match="corrupt model"):
registry.get_model("/app/models/sender_bad.mlmodel")
assert registry.size() == 0

View File

@@ -67,3 +67,42 @@ async def test_train_returns_403_when_token_wrong():
) )
assert response.status_code == 403 assert response.status_code == 403
# ─── /train-sender authentication ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_train_sender_returns_503_when_training_token_not_configured():
"""POST /train-sender must return 503 when TRAINING_TOKEN env var is empty.
An empty token means the service was started without training configured.
Allowing requests through would grant unauthenticated access to the
training endpoint, contradicting the principle of failing closed.
"""
with patch("main.TRAINING_TOKEN", ""), \
patch("main._models_ready", True):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/train-sender",
data={"output_model_path": "/app/models/sender_test.mlmodel"},
files={"file": ("training.zip", _minimal_zip(), "application/zip")},
)
assert response.status_code == 503
@pytest.mark.asyncio
async def test_train_sender_returns_403_when_token_wrong():
"""POST /train-sender must return 403 when TRAINING_TOKEN is set but header is wrong."""
with patch("main.TRAINING_TOKEN", "secret-token"), \
patch("main._models_ready", True):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/train-sender",
data={"output_model_path": "/app/models/sender_test.mlmodel"},
files={"file": ("training.zip", _minimal_zip(), "application/zip")},
headers={"X-Training-Token": "wrong-token"},
)
assert response.status_code == 403