feat(ocr): per-sender specialized Kurrent models with automatic active-learning retraining #263
@@ -48,7 +48,7 @@ jobs:
|
||||
path: frontend/test-results/screenshots/
|
||||
|
||||
# ─── OCR Service Unit Tests ───────────────────────────────────────────────────
|
||||
# Only spell_check.py and test_confidence.py — no ML stack required.
|
||||
# Only spell_check.py, test_confidence.py, test_sender_registry.py — no ML stack required.
|
||||
ocr-tests:
|
||||
name: OCR Service Tests
|
||||
runs-on: ubuntu-latest
|
||||
@@ -60,11 +60,11 @@ jobs:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install test dependencies
|
||||
run: pip install "pyspellchecker==0.9.0" pytest
|
||||
run: pip install "pyspellchecker==0.9.0" pytest pytest-asyncio
|
||||
working-directory: ocr-service
|
||||
|
||||
- name: Run OCR unit tests (no ML stack required)
|
||||
run: python -m pytest test_spell_check.py test_confidence.py -v
|
||||
run: python -m pytest test_spell_check.py test_confidence.py test_sender_registry.py -v
|
||||
working-directory: ocr-service
|
||||
|
||||
# ─── Backend Unit & Slice Tests ───────────────────────────────────────────────
|
||||
|
||||
@@ -59,6 +59,9 @@ public class OcrTrainingRun {
|
||||
@Column(name = "triggered_by")
|
||||
private UUID triggeredBy;
|
||||
|
||||
@Column(name = "person_id")
|
||||
private UUID personId;
|
||||
|
||||
@CreationTimestamp
|
||||
@Column(name = "created_at", nullable = false, updatable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.persistence.*;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.hibernate.annotations.CreationTimestamp;
|
||||
import org.hibernate.annotations.UpdateTimestamp;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.UUID;
|
||||
|
||||
@Entity
|
||||
@Table(name = "sender_models")
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class SenderModel {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.UUID)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private UUID id;
|
||||
|
||||
@Column(name = "person_id", nullable = false, unique = true)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private UUID personId;
|
||||
|
||||
@JsonIgnore
|
||||
@Column(name = "model_path", nullable = false)
|
||||
private String modelPath;
|
||||
|
||||
@Column
|
||||
private Double accuracy;
|
||||
|
||||
@Column
|
||||
private Double cer;
|
||||
|
||||
@Column(name = "corrected_lines_at_training", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private int correctedLinesAtTraining;
|
||||
|
||||
@CreationTimestamp
|
||||
@Column(name = "created_at", nullable = false, updatable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private Instant createdAt;
|
||||
|
||||
@UpdateTimestamp
|
||||
@Column(name = "updated_at", nullable = false)
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private Instant updatedAt;
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
public enum TrainingStatus {
|
||||
QUEUED,
|
||||
RUNNING,
|
||||
DONE,
|
||||
FAILED
|
||||
|
||||
@@ -12,5 +12,11 @@ public interface OcrTrainingRunRepository extends JpaRepository<OcrTrainingRun,
|
||||
|
||||
Optional<OcrTrainingRun> findFirstByStatus(TrainingStatus status);
|
||||
|
||||
List<OcrTrainingRun> findTop10ByOrderByCreatedAtDesc();
|
||||
Optional<OcrTrainingRun> findFirstByStatusOrderByCreatedAtAsc(TrainingStatus status);
|
||||
|
||||
Optional<OcrTrainingRun> findFirstByPersonIdAndStatus(UUID personId, TrainingStatus status);
|
||||
|
||||
boolean existsByPersonIdAndStatus(UUID personId, TrainingStatus status);
|
||||
|
||||
List<OcrTrainingRun> findTop20ByOrderByCreatedAtDesc();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.raddatz.familienarchiv.model.SenderModel;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
public interface SenderModelRepository extends JpaRepository<SenderModel, UUID> {
|
||||
|
||||
Optional<SenderModel> findByPersonId(UUID personId);
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package org.raddatz.familienarchiv.repository;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.data.jpa.repository.Query;
|
||||
import org.springframework.data.repository.query.Param;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -37,4 +38,22 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
|
||||
AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels
|
||||
""")
|
||||
List<TranscriptionBlock> findSegmentationBlocks();
|
||||
|
||||
@Query("""
|
||||
SELECT COUNT(b) FROM TranscriptionBlock b
|
||||
JOIN Document d ON d.id = b.documentId
|
||||
WHERE b.source = 'MANUAL'
|
||||
AND d.sender.id = :personId
|
||||
AND d.scriptType = 'HANDWRITING_KURRENT'
|
||||
""")
|
||||
long countManualKurrentBlocksByPerson(@Param("personId") UUID personId);
|
||||
|
||||
@Query("""
|
||||
SELECT b FROM TranscriptionBlock b
|
||||
JOIN Document d ON d.id = b.documentId
|
||||
WHERE b.source = 'MANUAL'
|
||||
AND d.sender.id = :personId
|
||||
AND d.scriptType = 'HANDWRITING_KURRENT'
|
||||
""")
|
||||
List<TranscriptionBlock> findManualKurrentBlocksByPerson(@Param("personId") UUID personId);
|
||||
}
|
||||
|
||||
@@ -9,10 +9,12 @@ import org.raddatz.familienarchiv.repository.OcrJobRepository;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
@@ -29,6 +31,7 @@ public class OcrAsyncRunner {
|
||||
private final OcrJobRepository ocrJobRepository;
|
||||
private final OcrJobDocumentRepository ocrJobDocumentRepository;
|
||||
private final OcrProgressService ocrProgressService;
|
||||
private final SenderModelService senderModelService;
|
||||
|
||||
@Async
|
||||
public void runSingleDocument(UUID jobId, UUID documentId, UUID userId) {
|
||||
@@ -68,12 +71,18 @@ public class OcrAsyncRunner {
|
||||
|
||||
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
||||
|
||||
String senderModelPath = null;
|
||||
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
|
||||
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
|
||||
}
|
||||
|
||||
AtomicInteger blockCounter = new AtomicInteger(0);
|
||||
AtomicInteger currentPage = new AtomicInteger(0);
|
||||
AtomicInteger skippedPages = new AtomicInteger(0);
|
||||
AtomicInteger totalPages = new AtomicInteger(0);
|
||||
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, event -> {
|
||||
final String finalSenderModelPath = senderModelPath;
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), regions, finalSenderModelPath, event -> {
|
||||
switch (event) {
|
||||
case OcrStreamEvent.Start start -> {
|
||||
totalPages.set(start.totalPages());
|
||||
@@ -207,7 +216,25 @@ public class OcrAsyncRunner {
|
||||
clearExistingBlocks(documentId);
|
||||
|
||||
String pdfUrl = fileService.generatePresignedUrl(doc.getFilePath());
|
||||
List<OcrBlockResult> blocks = ocrClient.extractBlocks(pdfUrl, doc.getScriptType());
|
||||
|
||||
String senderModelPath = null;
|
||||
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
|
||||
senderModelPath = senderModelService.maybeGetModelPath(doc.getSender().getId()).orElse(null);
|
||||
}
|
||||
|
||||
final AtomicReference<List<OcrBlockResult>> blocksRef = new AtomicReference<>();
|
||||
final String finalSenderModelPath = senderModelPath;
|
||||
ocrClient.streamBlocks(pdfUrl, doc.getScriptType(), null, finalSenderModelPath, event -> {
|
||||
switch (event) {
|
||||
case OcrStreamEvent.Page page -> {
|
||||
blocksRef.compareAndSet(null, new ArrayList<>());
|
||||
blocksRef.get().addAll(page.blocks());
|
||||
}
|
||||
default -> {}
|
||||
}
|
||||
});
|
||||
|
||||
List<OcrBlockResult> blocks = blocksRef.get() != null ? blocksRef.get() : List.of();
|
||||
createTranscriptionBlocks(documentId, blocks, userId, doc.getFileHash());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.springframework.lang.Nullable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
@@ -37,15 +38,27 @@ public interface OcrClient {
|
||||
TrainingResult segtrainModel(byte[] trainingDataZip);
|
||||
|
||||
/**
|
||||
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
||||
* this method. The default exists only for backward compatibility during migration
|
||||
* — it calls extractBlocks() and synthesizes events from the collected result.
|
||||
* Fine-tune the Kurrent model for a specific sender.
|
||||
*
|
||||
* @param regions optional list of pre-drawn annotation regions; when non-null,
|
||||
* the OCR service runs in guided mode (crop + recognize per region)
|
||||
* @param trainingDataZip raw ZIP bytes produced by TrainingDataExportService.exportForSender()
|
||||
* @param outputModelPath where to save the trained model (e.g. /app/models/sender_{uuid}.mlmodel)
|
||||
* @return training result metrics
|
||||
*/
|
||||
TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath);
|
||||
|
||||
/**
|
||||
* Stream OCR results page-by-page via NDJSON, optionally using a sender-specific model.
|
||||
* The default implementation synthesizes events from extractBlocks() for backward compatibility.
|
||||
* Implementations that support real streaming (e.g. RestClientOcrClient) override this.
|
||||
*
|
||||
* @param regions optional list of pre-drawn annotation regions; when non-null,
|
||||
* the OCR service runs in guided mode (crop + recognize per region)
|
||||
* @param senderModelPath optional path to a per-sender model file; null means use base model
|
||||
*/
|
||||
default void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||
List<OcrRegion> regions,
|
||||
@Nullable String senderModelPath,
|
||||
Consumer<OcrStreamEvent> handler) {
|
||||
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
|
||||
|
||||
LinkedHashMap<Integer, List<OcrBlockResult>> byPage = new LinkedHashMap<>();
|
||||
@@ -62,4 +75,9 @@ public interface OcrClient {
|
||||
|
||||
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
|
||||
}
|
||||
|
||||
default void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||
streamBlocks(pdfUrl, scriptType, regions, null, handler);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,9 +17,11 @@ import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@@ -34,6 +36,7 @@ public class OcrTrainingService {
|
||||
private final OcrHealthClient ocrHealthClient;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final TransactionTemplate txTemplate;
|
||||
private final PersonService personService;
|
||||
|
||||
public record TrainingInfoResponse(
|
||||
int availableBlocks,
|
||||
@@ -42,7 +45,8 @@ public class OcrTrainingService {
|
||||
int availableSegBlocks,
|
||||
boolean ocrServiceAvailable,
|
||||
OcrTrainingRun lastRun,
|
||||
List<OcrTrainingRun> runs
|
||||
List<OcrTrainingRun> runs,
|
||||
Map<String, String> personNames
|
||||
) {}
|
||||
|
||||
private void assertNoRunningTraining() {
|
||||
@@ -195,9 +199,20 @@ public class OcrTrainingService {
|
||||
int totalOcrBlocks = (int) blockRepository.count();
|
||||
int availableSegBlocks = segmentationTrainingExportService.querySegmentationBlocks().size();
|
||||
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop10ByOrderByCreatedAtDesc();
|
||||
List<OcrTrainingRun> recentRuns = trainingRunRepository.findTop20ByOrderByCreatedAtDesc();
|
||||
OcrTrainingRun lastRun = recentRuns.isEmpty() ? null : recentRuns.get(0);
|
||||
|
||||
List<UUID> distinctPersonIds = recentRuns.stream()
|
||||
.map(OcrTrainingRun::getPersonId)
|
||||
.filter(Objects::nonNull)
|
||||
.distinct()
|
||||
.toList();
|
||||
Map<String, String> personNames = new HashMap<>();
|
||||
if (!distinctPersonIds.isEmpty()) {
|
||||
personService.getAllById(distinctPersonIds)
|
||||
.forEach(p -> personNames.put(p.getId().toString(), p.getDisplayName()));
|
||||
}
|
||||
|
||||
return new TrainingInfoResponse(
|
||||
eligibleBlocks.size(),
|
||||
totalOcrBlocks,
|
||||
@@ -205,7 +220,8 @@ public class OcrTrainingService {
|
||||
availableSegBlocks,
|
||||
ocrHealthClient.isHealthy(),
|
||||
lastRun,
|
||||
recentRuns
|
||||
recentRuns,
|
||||
personNames
|
||||
);
|
||||
}
|
||||
|
||||
@@ -224,15 +240,4 @@ public class OcrTrainingService {
|
||||
});
|
||||
}
|
||||
|
||||
public Map<String, Object> buildTrainingInfoMap(TrainingInfoResponse info) {
|
||||
return Map.of(
|
||||
"availableBlocks", info.availableBlocks(),
|
||||
"totalOcrBlocks", info.totalOcrBlocks(),
|
||||
"availableDocuments", info.availableDocuments(),
|
||||
"availableSegBlocks", info.availableSegBlocks(),
|
||||
"ocrServiceAvailable", info.ocrServiceAvailable(),
|
||||
"lastRun", info.lastRun() != null ? info.lastRun() : Map.of(),
|
||||
"runs", info.runs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.client.JdkClientHttpRequestFactory;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
@@ -102,6 +103,13 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
.toList();
|
||||
}
|
||||
|
||||
private RestClient.RequestBodySpec addTrainingAuth(RestClient.RequestBodySpec spec) {
|
||||
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||
return spec.header("X-Training-Token", trainingToken);
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OcrClient.TrainingResult trainModel(byte[] trainingDataZip) {
|
||||
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
|
||||
@@ -114,15 +122,10 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||
|
||||
var spec = trainingRestClient.post()
|
||||
.uri("/train")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||
|
||||
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||
spec = spec.header("X-Training-Token", trainingToken);
|
||||
}
|
||||
|
||||
TrainingResultJson result = spec
|
||||
TrainingResultJson result = addTrainingAuth(
|
||||
trainingRestClient.post()
|
||||
.uri("/train")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA))
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
@@ -143,15 +146,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||
|
||||
var spec = trainingRestClient.post()
|
||||
.uri("/segtrain")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||
TrainingResultJson result = addTrainingAuth(
|
||||
trainingRestClient.post()
|
||||
.uri("/segtrain")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA))
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||
spec = spec.header("X-Training-Token", trainingToken);
|
||||
}
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.cer(), result.epochs());
|
||||
}
|
||||
|
||||
TrainingResultJson result = spec
|
||||
@Override
|
||||
public OcrClient.TrainingResult trainSenderModel(byte[] trainingDataZip, String outputModelPath) {
|
||||
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
|
||||
@Override
|
||||
public String getFilename() { return "sender-training-data.zip"; }
|
||||
};
|
||||
|
||||
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||
HttpHeaders partHeaders = new HttpHeaders();
|
||||
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||
body.add("output_model_path", outputModelPath);
|
||||
|
||||
TrainingResultJson result = addTrainingAuth(
|
||||
trainingRestClient.post()
|
||||
.uri("/train-sender")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA))
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
@@ -176,7 +199,8 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
|
||||
@Override
|
||||
public void streamBlocks(String pdfUrl, ScriptType scriptType,
|
||||
List<OcrRegion> regions, Consumer<OcrStreamEvent> handler) {
|
||||
List<OcrRegion> regions, @Nullable String senderModelPath,
|
||||
Consumer<OcrStreamEvent> handler) {
|
||||
String body;
|
||||
try {
|
||||
var requestMap = new java.util.LinkedHashMap<String, Object>();
|
||||
@@ -186,6 +210,9 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
if (regions != null && !regions.isEmpty()) {
|
||||
requestMap.put("regions", regions);
|
||||
}
|
||||
if (senderModelPath != null) {
|
||||
requestMap.put("senderModelPath", senderModelPath);
|
||||
}
|
||||
body = NDJSON_MAPPER.writeValueAsString(requestMap);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to serialize OCR request", e);
|
||||
@@ -204,7 +231,12 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
|
||||
if (response.statusCode() == 404) {
|
||||
log.info("OCR service does not support /ocr/stream (404), falling back to /ocr");
|
||||
OcrClient.super.streamBlocks(pdfUrl, scriptType, regions, handler);
|
||||
List<OcrBlockResult> allBlocks = extractBlocks(pdfUrl, scriptType);
|
||||
handler.accept(new OcrStreamEvent.Start(0));
|
||||
for (OcrBlockResult block : allBlocks) {
|
||||
handler.accept(new OcrStreamEvent.Page(block.pageNumber(), List.of(block)));
|
||||
}
|
||||
handler.accept(new OcrStreamEvent.Done(allBlocks.size(), 0));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.SenderModel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.slf4j.MDC;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class SenderModelService {
|
||||
|
||||
private final SenderModelRepository senderModelRepository;
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final OcrTrainingRunRepository trainingRunRepository;
|
||||
private final OcrClient ocrClient;
|
||||
private final TransactionTemplate txTemplate;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
|
||||
@Value("${ocr.sender-model.activation-threshold:100}")
|
||||
private int activationThreshold;
|
||||
|
||||
@Value("${ocr.sender-model.retrain-delta:50}")
|
||||
private int retrainDelta;
|
||||
|
||||
/** Returns the model path if a trained sender model exists for this person. */
|
||||
public Optional<String> maybeGetModelPath(UUID personId) {
|
||||
return senderModelRepository.findByPersonId(personId)
|
||||
.map(SenderModel::getModelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Called after every MANUAL block save for HANDWRITING_KURRENT documents.
|
||||
* Checks activation and retrain thresholds; enqueues or starts sender training when met.
|
||||
*/
|
||||
@Async
|
||||
public void checkAndTriggerTraining(UUID personId) {
|
||||
long correctedLines = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
Optional<SenderModel> existing = senderModelRepository.findByPersonId(personId);
|
||||
|
||||
boolean shouldActivate = existing.isEmpty() && correctedLines >= activationThreshold;
|
||||
boolean shouldRetrain = existing.isPresent()
|
||||
&& (correctedLines - existing.get().getCorrectedLinesAtTraining()) >= retrainDelta;
|
||||
|
||||
if (!shouldActivate && !shouldRetrain) {
|
||||
return;
|
||||
}
|
||||
|
||||
log.info("Sender training threshold met for person {} (correctedLines={}, activate={}, retrain={})",
|
||||
personId, correctedLines, shouldActivate, shouldRetrain);
|
||||
|
||||
boolean runNow = runOrQueueSenderTraining(personId, (int) correctedLines);
|
||||
if (runNow) {
|
||||
triggerSenderTraining(personId, (int) correctedLines);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Atomically checks the queue state and either creates a RUNNING row (returns true) or a
|
||||
* QUEUED row (returns false). All three operations — idle check, duplicate-queue guard, and
|
||||
* RUNNING row creation — happen in one transaction, eliminating the race window that would
|
||||
* otherwise exist between the check and a separate RUNNING row creation.
|
||||
*/
|
||||
@Transactional
|
||||
public boolean runOrQueueSenderTraining(UUID personId, int correctedLines) {
|
||||
if (trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)) {
|
||||
log.info("Sender training already queued for person {} — skipping duplicate trigger", personId);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING).isPresent()) {
|
||||
int blockCount = (int) blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
trainingRunRepository.save(OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.QUEUED)
|
||||
.personId(personId)
|
||||
.blockCount(blockCount)
|
||||
.documentCount(0)
|
||||
.modelName("sender_" + personId)
|
||||
.build());
|
||||
log.info("Queued sender training for person {} — training already running", personId);
|
||||
return false;
|
||||
}
|
||||
|
||||
long blockCount = blockRepository.countManualKurrentBlocksByPerson(personId);
|
||||
trainingRunRepository.save(OcrTrainingRun.builder()
|
||||
.status(TrainingStatus.RUNNING)
|
||||
.personId(personId)
|
||||
.blockCount((int) blockCount)
|
||||
.documentCount(0)
|
||||
.modelName("sender_" + personId)
|
||||
.build());
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes sender training synchronously. Caller must run this on a background thread.
|
||||
* The RUNNING row is expected to already exist — created atomically by
|
||||
* runOrQueueSenderTraining (for new runs) or by promoteNextQueuedRun (for promoted runs).
|
||||
*/
|
||||
public void triggerSenderTraining(UUID personId, int correctedLines) {
|
||||
String outputModelPath = "/app/models/sender_" + personId + ".mlmodel";
|
||||
|
||||
OcrTrainingRun run = Objects.requireNonNull(txTemplate.execute(status ->
|
||||
trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING)
|
||||
.orElseThrow(() -> new IllegalStateException(
|
||||
"Expected RUNNING row for person " + personId + " but none found"))));
|
||||
|
||||
String runId = run.getId().toString();
|
||||
MDC.put("trainingRunId", runId);
|
||||
log.info("Started sender training run {} for person {}", runId, personId);
|
||||
|
||||
try {
|
||||
byte[] zipBytes = exportSenderData(personId);
|
||||
log.info("[trainingRun={}] Sending {} bytes to OCR service for sender training", runId, zipBytes.length);
|
||||
OcrClient.TrainingResult result = ocrClient.trainSenderModel(zipBytes, outputModelPath);
|
||||
|
||||
txTemplate.execute(status -> {
|
||||
SenderModel model = senderModelRepository.findByPersonId(personId)
|
||||
.orElseGet(() -> SenderModel.builder().personId(personId).build());
|
||||
model.setModelPath(outputModelPath);
|
||||
model.setCer(result.cer());
|
||||
model.setAccuracy(result.accuracy());
|
||||
model.setCorrectedLinesAtTraining(correctedLines);
|
||||
senderModelRepository.save(model);
|
||||
|
||||
run.setStatus(TrainingStatus.DONE);
|
||||
run.setCompletedAt(Instant.now());
|
||||
run.setCer(result.cer());
|
||||
run.setAccuracy(result.accuracy());
|
||||
run.setEpochs(result.epochs());
|
||||
trainingRunRepository.save(run);
|
||||
log.info("[trainingRun={}] Sender training completed — cer={}", runId, result.cer());
|
||||
return null;
|
||||
});
|
||||
} catch (Exception e) {
|
||||
txTemplate.execute(status -> {
|
||||
run.setStatus(TrainingStatus.FAILED);
|
||||
run.setErrorMessage(e.getMessage());
|
||||
run.setCompletedAt(Instant.now());
|
||||
trainingRunRepository.save(run);
|
||||
log.error("[trainingRun={}] Sender training failed: {}", runId, e.getMessage(), e);
|
||||
return null;
|
||||
});
|
||||
} finally {
|
||||
MDC.remove("trainingRunId");
|
||||
promoteNextQueuedRun();
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] exportSenderData(UUID personId) throws java.io.IOException {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
trainingDataExportService.exportForSender(personId).writeTo(baos);
|
||||
return baos.toByteArray();
|
||||
}
|
||||
|
||||
/**
|
||||
* Promotes the oldest QUEUED sender run to RUNNING and triggers its training.
|
||||
* Called in the finally block of triggerSenderTraining, creating a sequential chain:
|
||||
* each run promotes the next only after it fully completes (success or failure).
|
||||
*
|
||||
* This is intentionally tail-recursive via the @Async thread: the same thread holds the
|
||||
* full queue drain, serialising all sender training runs naturally without an external
|
||||
* scheduler. With N queued runs the thread stays occupied for N sequential trainings —
|
||||
* acceptable because the @Async executor is dedicated to long-running background work.
|
||||
*/
|
||||
private void promoteNextQueuedRun() {
|
||||
Optional<OcrTrainingRun> queuedOpt = txTemplate.execute(status ->
|
||||
trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED)
|
||||
.map(queued -> {
|
||||
queued.setStatus(TrainingStatus.RUNNING);
|
||||
return trainingRunRepository.save(queued);
|
||||
}));
|
||||
|
||||
if (queuedOpt != null && queuedOpt.isPresent()) {
|
||||
OcrTrainingRun promoted = queuedOpt.get();
|
||||
log.info("Promoting queued sender training run {} for person {}", promoted.getId(), promoted.getPersonId());
|
||||
long freshCount = blockRepository.countManualKurrentBlocksByPerson(promoted.getPersonId());
|
||||
triggerSenderTraining(promoted.getPersonId(), (int) freshCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -38,10 +38,20 @@ public class TrainingDataExportService {
|
||||
}
|
||||
|
||||
public StreamingResponseBody exportToZip() {
|
||||
// Collect all data before entering the lambda — no open DB txn during streaming
|
||||
List<TranscriptionBlock> blocks = queryEligibleBlocks();
|
||||
return exportBlocksToZip(queryEligibleBlocks());
|
||||
}
|
||||
|
||||
public List<TranscriptionBlock> queryBlocksForSender(UUID personId) {
|
||||
return blockRepository.findManualKurrentBlocksByPerson(personId);
|
||||
}
|
||||
|
||||
public StreamingResponseBody exportForSender(UUID personId) {
|
||||
return exportBlocksToZip(queryBlocksForSender(personId));
|
||||
}
|
||||
|
||||
private StreamingResponseBody exportBlocksToZip(List<TranscriptionBlock> blocks) {
|
||||
if (blocks.isEmpty()) {
|
||||
return out -> {}; // caller checks isEmpty() for 204 response
|
||||
return out -> {};
|
||||
}
|
||||
|
||||
// Group blocks by documentId so we only download each PDF once
|
||||
|
||||
@@ -11,6 +11,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.BlockSource;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
@@ -35,6 +36,7 @@ public class TranscriptionService {
|
||||
private final AnnotationRepository annotationRepository;
|
||||
private final AnnotationService annotationService;
|
||||
private final DocumentService documentService;
|
||||
private final SenderModelService senderModelService;
|
||||
|
||||
public List<TranscriptionBlock> listBlocks(UUID documentId) {
|
||||
return blockRepository.findByDocumentIdOrderBySortOrderAsc(documentId);
|
||||
@@ -122,6 +124,7 @@ public class TranscriptionService {
|
||||
|
||||
String text = sanitizeText(dto.getText());
|
||||
block.setText(text);
|
||||
block.setSource(BlockSource.MANUAL);
|
||||
if (dto.getLabel() != null) {
|
||||
block.setLabel(dto.getLabel());
|
||||
}
|
||||
@@ -129,6 +132,12 @@ public class TranscriptionService {
|
||||
|
||||
TranscriptionBlock saved = blockRepository.save(block);
|
||||
saveVersion(saved, userId);
|
||||
|
||||
Document doc = documentService.getDocumentById(documentId);
|
||||
if (doc.getSender() != null && doc.getScriptType() == ScriptType.HANDWRITING_KURRENT) {
|
||||
senderModelService.checkAndTriggerTraining(doc.getSender().getId());
|
||||
}
|
||||
|
||||
return saved;
|
||||
}
|
||||
|
||||
|
||||
@@ -77,3 +77,8 @@ app:
|
||||
tags: 10
|
||||
summary: 11
|
||||
transcription: 13
|
||||
|
||||
ocr:
|
||||
sender-model:
|
||||
activation-threshold: 100
|
||||
retrain-delta: 50
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
CREATE TABLE sender_models (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
person_id UUID NOT NULL UNIQUE REFERENCES persons(id) ON DELETE CASCADE,
|
||||
model_path TEXT NOT NULL,
|
||||
accuracy DOUBLE PRECISION,
|
||||
cer DOUBLE PRECISION,
|
||||
corrected_lines_at_training INT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE ocr_training_runs
|
||||
ADD COLUMN person_id UUID REFERENCES persons(id) ON DELETE SET NULL;
|
||||
@@ -0,0 +1,3 @@
|
||||
CREATE UNIQUE INDEX idx_training_runs_queued_per_person
|
||||
ON ocr_training_runs(person_id)
|
||||
WHERE status = 'QUEUED';
|
||||
@@ -23,6 +23,8 @@ import org.springframework.test.web.servlet.MockMvc;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
@@ -219,7 +221,7 @@ class OcrControllerTest {
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void getTrainingInfo_returns200_withInfo() throws Exception {
|
||||
OcrTrainingService.TrainingInfoResponse info =
|
||||
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of());
|
||||
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null, List.of(), Map.of());
|
||||
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
@@ -228,6 +230,52 @@ class OcrControllerTest {
|
||||
.andExpect(jsonPath("$.ocrServiceAvailable").value(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void getTrainingInfo_returns200_and_omits_personName_when_service_provides_empty_map() throws Exception {
|
||||
UUID personId = UUID.randomUUID();
|
||||
OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.DONE)
|
||||
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
|
||||
OcrTrainingService.TrainingInfoResponse info =
|
||||
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null,
|
||||
List.of(runWithPerson), Map.of());
|
||||
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(jsonPath("$.personNames").isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void getTrainingInfo_includesPersonName_whenPersonIdResolves() throws Exception {
|
||||
UUID personId = UUID.randomUUID();
|
||||
OcrTrainingRun runWithPerson = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.DONE)
|
||||
.personId(personId).blockCount(5).documentCount(1).modelName("sender_x").build();
|
||||
OcrTrainingService.TrainingInfoResponse info =
|
||||
new OcrTrainingService.TrainingInfoResponse(5, 20, 2, 3, true, null,
|
||||
List.of(runWithPerson), Map.of(personId.toString(), "Max Mustermann"));
|
||||
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(jsonPath("$.personNames." + personId).value("Max Mustermann"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "ADMIN")
|
||||
void getTrainingInfo_serializes_null_lastRun_as_json_null() throws Exception {
|
||||
OcrTrainingService.TrainingInfoResponse info =
|
||||
new OcrTrainingService.TrainingInfoResponse(0, 0, 0, 0, false, null, List.of(), Map.of());
|
||||
when(ocrTrainingService.getTrainingInfo()).thenReturn(info);
|
||||
|
||||
mockMvc.perform(get("/api/ocr/training-info"))
|
||||
.andExpect(status().isOk())
|
||||
.andExpect(jsonPath("$.lastRun").doesNotExist());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "READ_ALL")
|
||||
void getDocumentOcrStatus_returnsNone_whenNoOcrJobExists() throws Exception {
|
||||
|
||||
@@ -217,8 +217,71 @@ class MigrationIntegrationTest {
|
||||
assertThat(stored).isEqualTo("sage");
|
||||
}
|
||||
|
||||
// ─── V42: idx_training_runs_queued_per_person partial unique index ────────
|
||||
|
||||
@Test
|
||||
@Transactional(propagation = Propagation.NOT_SUPPORTED)
|
||||
void v42_partialUniqueIndex_preventsTwoQueuedRunsForSamePerson() {
|
||||
UUID personId = createPerson("V42Test", "SamePerson");
|
||||
jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
|
||||
VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_x', ?)
|
||||
""", personId);
|
||||
|
||||
assertThatThrownBy(() ->
|
||||
jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
|
||||
VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_x', ?)
|
||||
""", personId)
|
||||
).isInstanceOf(DataIntegrityViolationException.class);
|
||||
|
||||
jdbc.update("DELETE FROM ocr_training_runs WHERE person_id = ?", personId);
|
||||
jdbc.update("DELETE FROM persons WHERE id = ?", personId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void v42_partialUniqueIndex_allowsQueuedRunsForDifferentPersons() {
|
||||
UUID person1 = createPerson("V42Test", "PersonA");
|
||||
UUID person2 = createPerson("V42Test", "PersonB");
|
||||
|
||||
int rows1 = jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
|
||||
VALUES (gen_random_uuid(), 'QUEUED', 10, 2, 'sender_a', ?)
|
||||
""", person1);
|
||||
int rows2 = jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
|
||||
VALUES (gen_random_uuid(), 'QUEUED', 5, 1, 'sender_b', ?)
|
||||
""", person2);
|
||||
|
||||
assertThat(rows1).isEqualTo(1);
|
||||
assertThat(rows2).isEqualTo(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void v42_partialUniqueIndex_allowsMultipleDoneRunsForSamePerson() {
|
||||
UUID personId = createPerson("V42Test", "DonePerson");
|
||||
|
||||
int rows1 = jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
|
||||
VALUES (gen_random_uuid(), 'DONE', 10, 2, 'sender_x', ?)
|
||||
""", personId);
|
||||
int rows2 = jdbc.update("""
|
||||
INSERT INTO ocr_training_runs (id, status, block_count, document_count, model_name, person_id)
|
||||
VALUES (gen_random_uuid(), 'DONE', 15, 3, 'sender_x', ?)
|
||||
""", personId);
|
||||
|
||||
assertThat(rows1).isEqualTo(1);
|
||||
assertThat(rows2).isEqualTo(1);
|
||||
}
|
||||
|
||||
// ─── helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
private UUID createPerson(String firstName, String lastName) {
|
||||
UUID id = UUID.randomUUID();
|
||||
jdbc.update("INSERT INTO persons (id, first_name, last_name) VALUES (?, ?, ?)", id, firstName, lastName);
|
||||
return id;
|
||||
}
|
||||
|
||||
private UUID createTag(String name) {
|
||||
UUID id = UUID.randomUUID();
|
||||
jdbc.update("INSERT INTO tag (id, name) VALUES (?, ?)", id, name);
|
||||
|
||||
@@ -31,6 +31,7 @@ class OcrAsyncRunnerTest {
|
||||
@Mock OcrJobRepository ocrJobRepository;
|
||||
@Mock OcrJobDocumentRepository ocrJobDocumentRepository;
|
||||
@Mock OcrProgressService ocrProgressService;
|
||||
@Mock SenderModelService senderModelService;
|
||||
|
||||
@InjectMocks OcrAsyncRunner ocrAsyncRunner;
|
||||
|
||||
@@ -42,7 +43,12 @@ class OcrAsyncRunnerTest {
|
||||
.fileHash("hash").scriptType(ScriptType.TYPEWRITER).build();
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of());
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(0));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.processDocument(docId, doc, userId);
|
||||
|
||||
@@ -59,9 +65,15 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null)));
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(1));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Line 1", null),
|
||||
new OcrBlockResult(0, 0.1, 0.2, 0.8, 0.04, null, "Line 2", null))));
|
||||
handler.accept(new OcrStreamEvent.Done(2, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -83,8 +95,14 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
when(ocrClient.extractBlocks(any(), any())).thenReturn(List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null)));
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(1));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "Test", null))));
|
||||
handler.accept(new OcrStreamEvent.Done(1, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
DocumentAnnotation ann = DocumentAnnotation.builder().id(annId).build();
|
||||
when(annotationService.createOcrAnnotation(any(), any(), any(), any(), any())).thenReturn(ann);
|
||||
|
||||
@@ -112,12 +130,12 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(1));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -142,7 +160,7 @@ class OcrAsyncRunnerTest {
|
||||
when(documentService.getDocumentById(docId)).thenReturn(doc);
|
||||
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
doThrow(new RuntimeException("OCR failed")).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -174,7 +192,7 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
List<String> progressMessages = new ArrayList<>();
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of(
|
||||
new OcrBlockResult(0, 0.1, 0.1, 0.8, 0.04, null, "L1", null),
|
||||
@@ -185,7 +203,7 @@ class OcrAsyncRunnerTest {
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Done(3, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -215,14 +233,14 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(3));
|
||||
handler.accept(new OcrStreamEvent.Page(0, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Error(1, "failed"));
|
||||
handler.accept(new OcrStreamEvent.Page(2, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -251,14 +269,14 @@ class OcrAsyncRunnerTest {
|
||||
|
||||
List<String> progressMessages = new ArrayList<>();
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(5));
|
||||
handler.accept(new OcrStreamEvent.Preprocessing(1));
|
||||
progressMessages.add(job.getProgressMessage());
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 0));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
@@ -287,13 +305,13 @@ class OcrAsyncRunnerTest {
|
||||
when(fileService.generatePresignedUrl(any())).thenReturn("http://presigned");
|
||||
|
||||
doAnswer(inv -> {
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(3);
|
||||
Consumer<OcrStreamEvent> handler = inv.getArgument(4);
|
||||
handler.accept(new OcrStreamEvent.Start(2));
|
||||
handler.accept(new OcrStreamEvent.Error(0, "some python traceback details"));
|
||||
handler.accept(new OcrStreamEvent.Page(1, List.of()));
|
||||
handler.accept(new OcrStreamEvent.Done(0, 1));
|
||||
return null;
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any());
|
||||
}).when(ocrClient).streamBlocks(any(), any(), any(), any(), any());
|
||||
|
||||
ocrAsyncRunner.runSingleDocument(jobId, docId, userId);
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ class OcrClientDefaultStreamTest {
|
||||
}
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
@@ -52,6 +53,7 @@ class OcrClientDefaultStreamTest {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult trainSenderModel(byte[] zip, String path) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
|
||||
@@ -4,11 +4,12 @@ import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.Person;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.raddatz.familienarchiv.service.PersonService;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
@@ -32,6 +33,7 @@ class OcrTrainingServiceTest {
|
||||
OcrHealthClient healthClient;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
TransactionTemplate txTemplate;
|
||||
PersonService personService;
|
||||
OcrTrainingService service;
|
||||
|
||||
@BeforeEach
|
||||
@@ -43,6 +45,7 @@ class OcrTrainingServiceTest {
|
||||
healthClient = mock(OcrHealthClient.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
personService = mock(PersonService.class);
|
||||
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource
|
||||
when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||
@@ -50,10 +53,10 @@ class OcrTrainingServiceTest {
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate);
|
||||
service = new OcrTrainingService(runRepository, exportService, segExportService, ocrClient, healthClient, blockRepository, txTemplate, personService);
|
||||
|
||||
when(blockRepository.count()).thenReturn(0L);
|
||||
when(runRepository.findTop10ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(List.of());
|
||||
when(segExportService.querySegmentationBlocks()).thenReturn(List.of());
|
||||
}
|
||||
|
||||
@@ -230,6 +233,37 @@ class OcrTrainingServiceTest {
|
||||
run.getStatus() == TrainingStatus.FAILED && run.getErrorMessage() != null));
|
||||
}
|
||||
|
||||
// ─── getTrainingInfo: batch person name resolution ────────────────────────
|
||||
|
||||
@Test
|
||||
void getTrainingInfo_resolves_person_names_in_single_batch_call() {
|
||||
UUID personA = UUID.randomUUID();
|
||||
UUID personB = UUID.randomUUID();
|
||||
List<OcrTrainingRun> runs = List.of(
|
||||
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
|
||||
.personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build(),
|
||||
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
|
||||
.personId(personB).blockCount(5).documentCount(1).modelName("sender_b").build(),
|
||||
OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.DONE)
|
||||
.personId(personA).blockCount(5).documentCount(1).modelName("sender_a").build()
|
||||
);
|
||||
when(runRepository.findTop20ByOrderByCreatedAtDesc()).thenReturn(runs);
|
||||
when(exportService.queryEligibleBlocks()).thenReturn(List.of());
|
||||
|
||||
Person pa = Person.builder().id(personA).firstName("Anna").lastName("Müller").build();
|
||||
Person pb = Person.builder().id(personB).firstName("Karl").lastName("Bauer").build();
|
||||
when(personService.getAllById(argThat(ids -> ids.size() == 2 && ids.containsAll(List.of(personA, personB)))))
|
||||
.thenReturn(List.of(pa, pb));
|
||||
when(healthClient.isHealthy()).thenReturn(true);
|
||||
|
||||
OcrTrainingService.TrainingInfoResponse info = service.getTrainingInfo();
|
||||
|
||||
verify(personService, never()).getById(any());
|
||||
verify(personService, times(1)).getAllById(any());
|
||||
assertThat(info.personNames()).containsKey(personA.toString());
|
||||
assertThat(info.personNames()).containsKey(personB.toString());
|
||||
}
|
||||
|
||||
// ─── Orphan recovery ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import com.sun.net.httpserver.HttpServer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.OutputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class RestClientOcrClientTrainingTest {
|
||||
|
||||
@Test
|
||||
void trainSenderModel_postsToTrainSenderEndpoint() throws Exception {
|
||||
List<String> capturedPaths = new CopyOnWriteArrayList<>();
|
||||
HttpServer server = HttpServer.create(new InetSocketAddress(0), 0);
|
||||
server.createContext("/", exchange -> {
|
||||
capturedPaths.add(exchange.getRequestURI().getPath());
|
||||
byte[] response = "{\"loss\":0.1,\"accuracy\":0.95,\"cer\":0.05,\"epochs\":3}"
|
||||
.getBytes(StandardCharsets.UTF_8);
|
||||
exchange.getResponseHeaders().set("Content-Type", "application/json");
|
||||
exchange.sendResponseHeaders(200, response.length);
|
||||
try (OutputStream os = exchange.getResponseBody()) {
|
||||
os.write(response);
|
||||
}
|
||||
});
|
||||
server.start();
|
||||
|
||||
try {
|
||||
int port = server.getAddress().getPort();
|
||||
RestClientOcrClient client = new RestClientOcrClient("http://localhost:" + port, "");
|
||||
client.trainSenderModel(new byte[]{1, 2, 3}, "/app/models/sender_test.mlmodel");
|
||||
assertThat(capturedPaths).contains("/train-sender");
|
||||
} finally {
|
||||
server.stop(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.raddatz.familienarchiv.model.OcrTrainingRun;
|
||||
import org.raddatz.familienarchiv.model.SenderModel;
|
||||
import org.raddatz.familienarchiv.model.TrainingStatus;
|
||||
import org.raddatz.familienarchiv.repository.OcrTrainingRunRepository;
|
||||
import org.raddatz.familienarchiv.repository.SenderModelRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
import org.springframework.transaction.support.TransactionCallback;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class SenderModelServiceTest {
|
||||
|
||||
SenderModelRepository senderModelRepository;
|
||||
TranscriptionBlockRepository blockRepository;
|
||||
OcrTrainingRunRepository trainingRunRepository;
|
||||
OcrClient ocrClient;
|
||||
TransactionTemplate txTemplate;
|
||||
TrainingDataExportService trainingDataExportService;
|
||||
SenderModelService service;
|
||||
|
||||
UUID personId = UUID.randomUUID();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
senderModelRepository = mock(SenderModelRepository.class);
|
||||
blockRepository = mock(TranscriptionBlockRepository.class);
|
||||
trainingRunRepository = mock(OcrTrainingRunRepository.class);
|
||||
ocrClient = mock(OcrClient.class);
|
||||
txTemplate = mock(TransactionTemplate.class);
|
||||
trainingDataExportService = mock(TrainingDataExportService.class);
|
||||
|
||||
// Execute transaction callbacks inline so unit tests run without a real DataSource.
|
||||
// lenient: not every test hits the txTemplate path, but the setup is shared.
|
||||
lenient().when(txTemplate.execute(any())).thenAnswer(inv -> {
|
||||
TransactionCallback<?> callback = inv.getArgument(0);
|
||||
return callback.doInTransaction(null);
|
||||
});
|
||||
|
||||
service = new SenderModelService(senderModelRepository, blockRepository,
|
||||
trainingRunRepository, ocrClient, txTemplate, trainingDataExportService);
|
||||
ReflectionTestUtils.setField(service, "activationThreshold", 100);
|
||||
ReflectionTestUtils.setField(service, "retrainDelta", 50);
|
||||
}
|
||||
|
||||
// ─── Activation threshold ─────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_doesNothing_belowActivationThreshold() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(99L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_triggersTraining_atActivationThreshold() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).runOrQueueSenderTraining(personId, 100);
|
||||
}
|
||||
|
||||
// ─── Retrain delta ────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_doesNothing_belowRetrainDelta() {
|
||||
SenderModel existing = SenderModel.builder().personId(personId)
|
||||
.correctedLinesAtTraining(100).build();
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(149L);
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy, never()).runOrQueueSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_triggersRetrain_atRetrainDelta() {
|
||||
SenderModel existing = SenderModel.builder().personId(personId)
|
||||
.correctedLinesAtTraining(100).build();
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.of(existing));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(150L);
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 150);
|
||||
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).runOrQueueSenderTraining(personId, 150);
|
||||
}
|
||||
|
||||
// ─── runNow flag ──────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_callsTrigger_whenRunNow() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(true).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
doNothing().when(spy).triggerSenderTraining(personId, 100);
|
||||
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy).triggerSenderTraining(personId, 100);
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkAndTriggerTraining_doesNotCallTrigger_whenQueued() {
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(100L);
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
doReturn(false).when(spy).runOrQueueSenderTraining(personId, 100);
|
||||
|
||||
spy.checkAndTriggerTraining(personId);
|
||||
|
||||
verify(spy, never()).triggerSenderTraining(any(), anyInt());
|
||||
}
|
||||
|
||||
// ─── runOrQueueSenderTraining ─────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsTrue_whenIdle() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isTrue();
|
||||
verify(trainingRunRepository, never()).save(argThat(r -> r.getStatus() == TrainingStatus.QUEUED));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_savesQueued_whenRunning() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(
|
||||
Optional.of(OcrTrainingRun.builder().id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.blockCount(5).documentCount(1).modelName("german_kurrent").build()));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(trainingRunRepository).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.QUEUED && personId.equals(r.getPersonId())));
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_returnsFalse_whenAlreadyQueued() {
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(true);
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isFalse();
|
||||
verify(trainingRunRepository, never()).save(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void runOrQueueSenderTraining_createsRunningRow_atomically() {
|
||||
// RUNNING row must be created inside the same @Transactional boundary as the idle check,
|
||||
// eliminating the race window between the check and a separate triggerSenderTraining call.
|
||||
when(trainingRunRepository.existsByPersonIdAndStatus(personId, TrainingStatus.QUEUED)).thenReturn(false);
|
||||
when(trainingRunRepository.findFirstByStatus(TrainingStatus.RUNNING)).thenReturn(Optional.empty());
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(personId)).thenReturn(120L);
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrTrainingRun r = inv.getArgument(0);
|
||||
if (r.getId() == null) r.setId(UUID.randomUUID());
|
||||
return r;
|
||||
});
|
||||
|
||||
boolean result = service.runOrQueueSenderTraining(personId, 120);
|
||||
|
||||
assertThat(result).isTrue();
|
||||
verify(trainingRunRepository).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.RUNNING && personId.equals(r.getPersonId())));
|
||||
}
|
||||
|
||||
// ─── triggerSenderTraining — happy path ──────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerSenderTraining_savesModelRecord_onSuccess() throws Exception {
|
||||
OcrTrainingRun runningRun = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.personId(personId).blockCount(10).documentCount(0)
|
||||
.modelName("sender_" + personId).build();
|
||||
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
|
||||
.thenReturn(Optional.of(runningRun));
|
||||
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
|
||||
when(ocrClient.trainSenderModel(any(), any()))
|
||||
.thenReturn(new OcrClient.TrainingResult(0.08, 0.92, 0.08, 5));
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
when(senderModelRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
|
||||
.thenReturn(Optional.empty());
|
||||
|
||||
service.triggerSenderTraining(personId, 120);
|
||||
|
||||
verify(senderModelRepository).save(argThat(m ->
|
||||
("/app/models/sender_" + personId + ".mlmodel").equals(m.getModelPath())
|
||||
&& Double.valueOf(0.08).equals(m.getCer())
|
||||
&& personId.equals(m.getPersonId())));
|
||||
verify(trainingRunRepository, atLeastOnce()).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.DONE
|
||||
&& Double.valueOf(0.08).equals(r.getCer())
|
||||
&& r.getCompletedAt() != null));
|
||||
}
|
||||
|
||||
// ─── triggerSenderTraining — failure path ────────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerSenderTraining_marksRunFailed_whenOcrClientThrows() throws Exception {
|
||||
OcrTrainingRun runningRun = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.personId(personId).blockCount(10).documentCount(0)
|
||||
.modelName("sender_" + personId).build();
|
||||
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
|
||||
.thenReturn(Optional.of(runningRun));
|
||||
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
|
||||
when(ocrClient.trainSenderModel(any(), any()))
|
||||
.thenThrow(new RuntimeException("OCR service timeout"));
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
|
||||
.thenReturn(Optional.empty());
|
||||
|
||||
service.triggerSenderTraining(personId, 120);
|
||||
|
||||
verify(trainingRunRepository, atLeastOnce()).save(argThat(r ->
|
||||
r.getStatus() == TrainingStatus.FAILED
|
||||
&& r.getErrorMessage() != null
|
||||
&& r.getCompletedAt() != null));
|
||||
verify(senderModelRepository, never()).save(any());
|
||||
}
|
||||
|
||||
// ─── triggerSenderTraining — queue promotion ──────────────────────────────
|
||||
|
||||
@Test
|
||||
void triggerSenderTraining_promotesNextQueued_afterCompletion() throws Exception {
|
||||
UUID nextPersonId = UUID.randomUUID();
|
||||
|
||||
OcrTrainingRun runningRun = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.RUNNING)
|
||||
.personId(personId).blockCount(10).documentCount(0)
|
||||
.modelName("sender_" + personId).build();
|
||||
when(trainingRunRepository.findFirstByPersonIdAndStatus(personId, TrainingStatus.RUNNING))
|
||||
.thenReturn(Optional.of(runningRun));
|
||||
when(trainingRunRepository.save(any())).thenAnswer(inv -> {
|
||||
OcrTrainingRun r = inv.getArgument(0);
|
||||
if (r.getId() == null) r.setId(UUID.randomUUID());
|
||||
return r;
|
||||
});
|
||||
when(trainingDataExportService.exportForSender(personId)).thenReturn(out -> {});
|
||||
when(ocrClient.trainSenderModel(any(), any()))
|
||||
.thenReturn(new OcrClient.TrainingResult(0.05, 0.95, 0.05, 3));
|
||||
when(senderModelRepository.findByPersonId(personId)).thenReturn(Optional.empty());
|
||||
|
||||
OcrTrainingRun queued = OcrTrainingRun.builder()
|
||||
.id(UUID.randomUUID()).status(TrainingStatus.QUEUED)
|
||||
.personId(nextPersonId).blockCount(5).documentCount(0)
|
||||
.modelName("sender_" + nextPersonId).build();
|
||||
when(trainingRunRepository.findFirstByStatusOrderByCreatedAtAsc(TrainingStatus.QUEUED))
|
||||
.thenReturn(Optional.of(queued));
|
||||
when(blockRepository.countManualKurrentBlocksByPerson(nextPersonId)).thenReturn(5L);
|
||||
|
||||
SenderModelService spy = spy(service);
|
||||
// Stub the recursive call to stop the chain after one promotion
|
||||
doNothing().when(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
|
||||
|
||||
spy.triggerSenderTraining(personId, 10);
|
||||
|
||||
// Queued run promoted to RUNNING
|
||||
verify(trainingRunRepository).save(argThat(r ->
|
||||
nextPersonId.equals(r.getPersonId()) && r.getStatus() == TrainingStatus.RUNNING));
|
||||
// Next training was kicked off
|
||||
verify(spy).triggerSenderTraining(eq(nextPersonId), anyInt());
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ class TranscriptionServiceGuidedTest {
|
||||
AnnotationRepository annotationRepository;
|
||||
AnnotationService annotationService;
|
||||
DocumentService documentService;
|
||||
SenderModelService senderModelService;
|
||||
TranscriptionService service;
|
||||
|
||||
UUID docId = UUID.randomUUID();
|
||||
@@ -35,9 +36,10 @@ class TranscriptionServiceGuidedTest {
|
||||
annotationRepository = mock(AnnotationRepository.class);
|
||||
annotationService = mock(AnnotationService.class);
|
||||
documentService = mock(DocumentService.class);
|
||||
senderModelService = mock(SenderModelService.class);
|
||||
|
||||
service = new TranscriptionService(blockRepository, versionRepository,
|
||||
annotationRepository, annotationService, documentService);
|
||||
annotationRepository, annotationService, documentService, senderModelService);
|
||||
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(versionRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
|
||||
@@ -13,6 +13,8 @@ import org.raddatz.familienarchiv.exception.DomainException;
|
||||
import org.raddatz.familienarchiv.model.BlockSource;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.model.Person;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlockVersion;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
@@ -39,6 +41,7 @@ class TranscriptionServiceTest {
|
||||
@Mock AnnotationRepository annotationRepository;
|
||||
@Mock AnnotationService annotationService;
|
||||
@Mock DocumentService documentService;
|
||||
@Mock SenderModelService senderModelService;
|
||||
@InjectMocks TranscriptionService transcriptionService;
|
||||
|
||||
// ─── getBlock ────────────────────────────────────────────────────────────────
|
||||
@@ -156,6 +159,8 @@ class TranscriptionServiceTest {
|
||||
.id(blockId).documentId(docId).text("old").build();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(documentService.getDocumentById(any())).thenReturn(
|
||||
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
|
||||
|
||||
UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("new text", null);
|
||||
|
||||
@@ -175,6 +180,8 @@ class TranscriptionServiceTest {
|
||||
.id(blockId).documentId(docId).text("text").label("old label").build();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(documentService.getDocumentById(any())).thenReturn(
|
||||
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
|
||||
|
||||
UpdateTranscriptionBlockDTO dto = new UpdateTranscriptionBlockDTO("text", "Anrede");
|
||||
|
||||
@@ -183,6 +190,57 @@ class TranscriptionServiceTest {
|
||||
assertThat(result.getLabel()).isEqualTo("Anrede");
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateBlock_setsSourceToManual() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID blockId = UUID.randomUUID();
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.id(blockId).documentId(docId).text("old").source(BlockSource.OCR).build();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(documentService.getDocumentById(any())).thenReturn(
|
||||
Document.builder().scriptType(ScriptType.TYPEWRITER).build());
|
||||
|
||||
TranscriptionBlock result = transcriptionService.updateBlock(
|
||||
docId, blockId, new UpdateTranscriptionBlockDTO("new", null), UUID.randomUUID());
|
||||
|
||||
assertThat(result.getSource()).isEqualTo(BlockSource.MANUAL);
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateBlock_triggersTraining_whenKurrentSenderPresent() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID blockId = UUID.randomUUID();
|
||||
UUID senderId = UUID.randomUUID();
|
||||
Person sender = Person.builder().id(senderId).build();
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.id(blockId).documentId(docId).text("old").build();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(documentService.getDocumentById(any())).thenReturn(
|
||||
Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).sender(sender).build());
|
||||
|
||||
transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());
|
||||
|
||||
verify(senderModelService).checkAndTriggerTraining(senderId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateBlock_doesNotTriggerTraining_whenNoSender() {
|
||||
UUID docId = UUID.randomUUID();
|
||||
UUID blockId = UUID.randomUUID();
|
||||
TranscriptionBlock block = TranscriptionBlock.builder()
|
||||
.id(blockId).documentId(docId).text("old").build();
|
||||
when(blockRepository.findByIdAndDocumentId(blockId, docId)).thenReturn(Optional.of(block));
|
||||
when(blockRepository.save(any())).thenAnswer(inv -> inv.getArgument(0));
|
||||
when(documentService.getDocumentById(any())).thenReturn(
|
||||
Document.builder().scriptType(ScriptType.HANDWRITING_KURRENT).build());
|
||||
|
||||
transcriptionService.updateBlock(docId, blockId, new UpdateTranscriptionBlockDTO("text", null), UUID.randomUUID());
|
||||
|
||||
verify(senderModelService, never()).checkAndTriggerTraining(any());
|
||||
}
|
||||
|
||||
// ─── deleteBlock ─────────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
|
||||
@@ -97,6 +97,7 @@ services:
|
||||
DETECTOR_BATCH_SIZE: "8"
|
||||
OCR_CLAHE_CLIP_LIMIT: "2.0" # CLAHE contrast limit (multiplier of average histogram frequency)
|
||||
OCR_CLAHE_TILE_SIZE: "8" # CLAHE tile grid size (NxN tiles per page)
|
||||
OCR_MAX_CACHED_MODELS: "2" # LRU cache; each model ~500 MB, so 2 = ~1 GB resident
|
||||
networks:
|
||||
- archive-net
|
||||
healthcheck:
|
||||
|
||||
@@ -330,6 +330,7 @@
|
||||
"comment_time_days": "vor {count} Tag(en)",
|
||||
"comment_panel_title": "Kommentare",
|
||||
"comment_panel_close": "Schließen",
|
||||
"comp_dismiss": "Schließen",
|
||||
"doc_panel_tab_metadata": "Metadaten",
|
||||
"doc_panel_tab_transcription": "Transkription",
|
||||
"doc_panel_tab_discussion": "Diskussion",
|
||||
@@ -541,6 +542,7 @@
|
||||
"training_start_btn": "Training starten",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Training wurde gestartet und abgeschlossen.",
|
||||
"training_start_failed": "Training konnte nicht gestartet werden. Bitte versuche es erneut.",
|
||||
"training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).",
|
||||
"training_service_down": "OCR-Dienst ist nicht erreichbar.",
|
||||
"training_history_heading": "Verlauf",
|
||||
@@ -561,6 +563,11 @@
|
||||
"transcription_block_segmentation_only": "Nur Segmentierung",
|
||||
"training_chip_kurrent": "Kurrent-Erkennung",
|
||||
"training_chip_segmentation": "Segmentierung",
|
||||
"training_col_type": "Typ",
|
||||
"training_type_base": "Basis",
|
||||
"training_type_personalized": "Personalisiert",
|
||||
"training_col_person": "Absender",
|
||||
"training_status_queued": "Warteschlange",
|
||||
"mission_control_heading": "Was braucht Aufmerksamkeit?",
|
||||
"mission_control_segmentation_heading": "Text markieren",
|
||||
"mission_control_segmentation_description": "Textbereiche markieren — keine Vorkenntnisse nötig",
|
||||
|
||||
@@ -330,6 +330,7 @@
|
||||
"comment_time_days": "{count} day(s) ago",
|
||||
"comment_panel_title": "Comments",
|
||||
"comment_panel_close": "Close",
|
||||
"comp_dismiss": "Dismiss",
|
||||
"doc_panel_tab_metadata": "Metadata",
|
||||
"doc_panel_tab_transcription": "Transcription",
|
||||
"doc_panel_tab_discussion": "Discussion",
|
||||
@@ -541,6 +542,7 @@
|
||||
"training_start_btn": "Start training",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Training started and completed.",
|
||||
"training_start_failed": "Failed to start training. Please try again.",
|
||||
"training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).",
|
||||
"training_service_down": "OCR service is unavailable.",
|
||||
"training_history_heading": "History",
|
||||
@@ -561,6 +563,11 @@
|
||||
"transcription_block_segmentation_only": "Segmentation only",
|
||||
"training_chip_kurrent": "Kurrent recognition",
|
||||
"training_chip_segmentation": "Segmentation",
|
||||
"training_col_type": "Type",
|
||||
"training_type_base": "Base",
|
||||
"training_type_personalized": "Personalized",
|
||||
"training_col_person": "Sender",
|
||||
"training_status_queued": "Queued",
|
||||
"mission_control_heading": "What needs attention?",
|
||||
"mission_control_segmentation_heading": "Mark text",
|
||||
"mission_control_segmentation_description": "Mark text areas — no prior knowledge needed",
|
||||
|
||||
@@ -330,6 +330,7 @@
|
||||
"comment_time_days": "hace {count} día(s)",
|
||||
"comment_panel_title": "Comentarios",
|
||||
"comment_panel_close": "Cerrar",
|
||||
"comp_dismiss": "Cerrar",
|
||||
"doc_panel_tab_metadata": "Metadatos",
|
||||
"doc_panel_tab_transcription": "Transcripción",
|
||||
"doc_panel_tab_discussion": "Discusión",
|
||||
@@ -541,6 +542,7 @@
|
||||
"training_start_btn": "Iniciar entrenamiento",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Entrenamiento iniciado y completado.",
|
||||
"training_start_failed": "No se pudo iniciar el entrenamiento. Por favor, inténtalo de nuevo.",
|
||||
"training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).",
|
||||
"training_service_down": "El servicio OCR no está disponible.",
|
||||
"training_history_heading": "Historial",
|
||||
@@ -561,6 +563,11 @@
|
||||
"transcription_block_segmentation_only": "Solo segmentación",
|
||||
"training_chip_kurrent": "Reconocimiento Kurrent",
|
||||
"training_chip_segmentation": "Segmentación",
|
||||
"training_col_type": "Tipo",
|
||||
"training_type_base": "Base",
|
||||
"training_type_personalized": "Personalizado",
|
||||
"training_col_person": "Remitente",
|
||||
"training_status_queued": "En cola",
|
||||
"mission_control_heading": "¿Qué necesita atención?",
|
||||
"mission_control_segmentation_heading": "Marcar texto",
|
||||
"mission_control_segmentation_description": "Marcar áreas de texto — sin conocimientos previos",
|
||||
|
||||
@@ -1,26 +1,16 @@
|
||||
<script lang="ts">
|
||||
import TrainingHistory from './TrainingHistory.svelte';
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
interface Run {
|
||||
id: string;
|
||||
status: 'RUNNING' | 'DONE' | 'FAILED';
|
||||
blockCount: number;
|
||||
documentCount: number;
|
||||
modelName: string;
|
||||
errorMessage?: string;
|
||||
triggeredBy?: string;
|
||||
createdAt: string;
|
||||
completedAt?: string;
|
||||
}
|
||||
import type { TrainingRun } from '$lib/types/training.js';
|
||||
|
||||
interface TrainingInfo {
|
||||
availableBlocks?: number;
|
||||
totalOcrBlocks?: number;
|
||||
availableDocuments?: number;
|
||||
ocrServiceAvailable?: boolean;
|
||||
lastRun?: Run | null;
|
||||
runs?: Run[];
|
||||
lastRun?: TrainingRun | null;
|
||||
runs?: TrainingRun[];
|
||||
personNames?: Record<string, string>;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
@@ -31,6 +21,7 @@ let { trainingInfo }: Props = $props();
|
||||
|
||||
let training = $state(false);
|
||||
let successMessage = $state<string | null>(null);
|
||||
let errorMessage = $state<string | null>(null);
|
||||
|
||||
const available = $derived(trainingInfo?.availableBlocks ?? 0);
|
||||
const tooFewBlocks = $derived(available < 5);
|
||||
@@ -40,6 +31,7 @@ const disabled = $derived(training || tooFewBlocks || serviceDown);
|
||||
async function startTraining() {
|
||||
training = true;
|
||||
successMessage = null;
|
||||
errorMessage = null;
|
||||
try {
|
||||
const res = await fetch('/api/ocr/train', { method: 'POST' });
|
||||
if (res.ok) {
|
||||
@@ -47,7 +39,11 @@ async function startTraining() {
|
||||
setTimeout(() => {
|
||||
successMessage = null;
|
||||
}, 5000);
|
||||
} else {
|
||||
errorMessage = m.training_start_failed();
|
||||
}
|
||||
} catch {
|
||||
errorMessage = m.training_start_failed();
|
||||
} finally {
|
||||
training = false;
|
||||
}
|
||||
@@ -82,11 +78,26 @@ async function startTraining() {
|
||||
{/if}
|
||||
|
||||
{#if successMessage}
|
||||
<p class="mt-2 text-xs text-green-700">{successMessage}</p>
|
||||
<p class="mt-2 flex items-center gap-2 text-xs text-green-700" aria-live="polite">
|
||||
{successMessage}
|
||||
<button
|
||||
type="button"
|
||||
class="-my-2 ml-1 inline-flex h-11 w-11 items-center justify-center rounded-sm underline hover:no-underline focus-visible:ring-2 focus-visible:ring-brand-navy focus-visible:outline-none"
|
||||
onclick={() => (successMessage = null)}
|
||||
aria-label={m.comp_dismiss()}>×</button
|
||||
>
|
||||
</p>
|
||||
{/if}
|
||||
|
||||
{#if errorMessage}
|
||||
<p class="mt-2 text-xs text-red-600" aria-live="assertive">{errorMessage}</p>
|
||||
{/if}
|
||||
|
||||
<h3 class="mt-6 mb-3 text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||
{m.training_history_heading()}
|
||||
</h3>
|
||||
<TrainingHistory runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')} />
|
||||
<TrainingHistory
|
||||
runs={(trainingInfo?.runs ?? []).filter((r) => r.modelName !== 'blla')}
|
||||
personNames={trainingInfo?.personNames ?? {}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -73,6 +73,24 @@ describe('OcrTrainingCard — enabled state', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('OcrTrainingCard — success dismiss button', () => {
|
||||
it('dismiss button has 44×44px touch target (h-11 w-11)', async () => {
|
||||
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: true }));
|
||||
|
||||
render(OcrTrainingCard, { trainingInfo: baseInfo });
|
||||
|
||||
const btn = page.getByRole('button', { name: /Training starten/i });
|
||||
await btn.click();
|
||||
|
||||
const dismissBtn = page.getByRole('button', { name: /Schließen/i });
|
||||
await expect.element(dismissBtn).toBeInTheDocument();
|
||||
|
||||
const el = await dismissBtn.element();
|
||||
expect(el.classList.contains('h-11')).toBe(true);
|
||||
expect(el.classList.contains('w-11')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('OcrTrainingCard — in-flight state', () => {
|
||||
it('shows "…" while POST is in-flight', async () => {
|
||||
let resolveFetch!: (v: unknown) => void;
|
||||
|
||||
@@ -1,24 +1,13 @@
|
||||
<script lang="ts">
|
||||
import { m } from '$lib/paraglide/messages.js';
|
||||
|
||||
interface Run {
|
||||
id: string;
|
||||
status: 'RUNNING' | 'DONE' | 'FAILED';
|
||||
blockCount: number;
|
||||
documentCount: number;
|
||||
modelName: string;
|
||||
errorMessage?: string;
|
||||
triggeredBy?: string;
|
||||
createdAt: string;
|
||||
completedAt?: string;
|
||||
cer?: number;
|
||||
}
|
||||
import type { TrainingRun } from '$lib/types/training.js';
|
||||
|
||||
interface Props {
|
||||
runs: Run[];
|
||||
runs: TrainingRun[];
|
||||
personNames?: Record<string, string>;
|
||||
}
|
||||
|
||||
let { runs }: Props = $props();
|
||||
let { runs, personNames }: Props = $props();
|
||||
|
||||
const COLLAPSED_COUNT = 3;
|
||||
let expanded = $state(false);
|
||||
@@ -47,6 +36,8 @@ function formatCer(cer: number | undefined | null): string {
|
||||
<tr class="border-b border-line text-xs font-bold tracking-widest text-ink-3 uppercase">
|
||||
<th class="pb-2 text-left">{m.training_history_col_date()}</th>
|
||||
<th class="pb-2 text-left">{m.training_history_col_status()}</th>
|
||||
<th class="hidden pb-2 text-left md:table-cell">{m.training_col_type()}</th>
|
||||
<th class="hidden pb-2 text-left md:table-cell">{m.training_col_person()}</th>
|
||||
<th class="pb-2 text-right">{m.training_history_col_blocks()}</th>
|
||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_docs()}</th>
|
||||
<th class="hidden pb-2 text-right md:table-cell">{m.training_history_col_cer()}</th>
|
||||
@@ -55,7 +46,7 @@ function formatCer(cer: number | undefined | null): string {
|
||||
<tbody id="training-history-rows">
|
||||
{#if runs.length === 0}
|
||||
<tr>
|
||||
<td colspan="5" class="py-4 text-center text-sm text-ink-2">
|
||||
<td colspan="7" class="py-4 text-center text-sm text-ink-2">
|
||||
{m.training_history_empty()}
|
||||
</td>
|
||||
</tr>
|
||||
@@ -64,7 +55,14 @@ function formatCer(cer: number | undefined | null): string {
|
||||
<tr class="border-b border-line/50 last:border-0">
|
||||
<td class="py-2 text-ink-2">{formatDate(run.createdAt)}</td>
|
||||
<td class="py-2">
|
||||
{#if run.status === 'DONE'}
|
||||
{#if run.status === 'QUEUED'}
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-amber-100 px-1.5 py-0.5 text-xs font-medium text-amber-700"
|
||||
>
|
||||
<span aria-hidden="true" class="h-1.5 w-1.5 rounded-full bg-amber-500"></span>
|
||||
{m.training_status_queued()}
|
||||
</span>
|
||||
{:else if run.status === 'DONE'}
|
||||
<span
|
||||
class="inline-flex items-center gap-1 rounded-sm bg-green-100 px-1.5 py-0.5 text-xs font-medium text-green-700"
|
||||
>
|
||||
@@ -119,6 +117,17 @@ function formatCer(cer: number | undefined | null): string {
|
||||
{m.training_status_running()}
|
||||
</span>
|
||||
{/if}
|
||||
{#if run.personId && personNames?.[run.personId]}
|
||||
<span class="mt-0.5 block text-xs text-ink-3 md:hidden"
|
||||
>{personNames[run.personId]}</span
|
||||
>
|
||||
{/if}
|
||||
</td>
|
||||
<td class="hidden py-2 text-left text-ink-2 md:table-cell">
|
||||
{run.personId ? m.training_type_personalized() : m.training_type_base()}
|
||||
</td>
|
||||
<td class="hidden py-2 text-left text-ink-2 md:table-cell">
|
||||
{run.personId && personNames?.[run.personId] ? personNames[run.personId] : '—'}
|
||||
</td>
|
||||
<td class="py-2 text-right text-ink-2">{run.blockCount}</td>
|
||||
<td class="hidden py-2 text-right text-ink-2 md:table-cell">{run.documentCount}</td>
|
||||
@@ -137,7 +146,7 @@ function formatCer(cer: number | undefined | null): string {
|
||||
type="button"
|
||||
aria-expanded={expanded}
|
||||
aria-controls="training-history-rows"
|
||||
class="text-xs font-medium text-ink-3 transition-colors hover:text-ink"
|
||||
class="rounded-sm py-2 text-xs font-medium text-ink-3 transition-colors hover:text-ink focus-visible:ring-2 focus-visible:ring-brand-navy focus-visible:ring-offset-1 focus-visible:outline-none"
|
||||
onclick={() => (expanded = !expanded)}
|
||||
>
|
||||
{expanded ? m.comp_expandable_show_less() : m.comp_expandable_show_more()}
|
||||
|
||||
@@ -50,3 +50,36 @@ describe('TrainingHistory — expand/collapse', () => {
|
||||
.not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('TrainingHistory — status badges', () => {
|
||||
it('shows amber QUEUED badge for queued runs', async () => {
|
||||
const run = { ...makeRun(0), status: 'QUEUED' as const };
|
||||
render(TrainingHistory, { runs: [run] });
|
||||
|
||||
await expect.element(page.getByText(/Warteschlange/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('TrainingHistory — type and person columns', () => {
|
||||
it('shows "Basis" for runs without personId', async () => {
|
||||
render(TrainingHistory, { runs: [makeRun(0)] });
|
||||
|
||||
await expect.element(page.getByText(/Basis/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows "Personalisiert" for runs with personId', async () => {
|
||||
const run = { ...makeRun(0), personId: 'person-1' };
|
||||
render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } });
|
||||
|
||||
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows person name from personNames for sender runs', async () => {
|
||||
const run = { ...makeRun(0), personId: 'person-1' };
|
||||
render(TrainingHistory, { runs: [run], personNames: { 'person-1': 'Karl Müller' } });
|
||||
|
||||
await expect.element(page.getByText(/Personalisiert/i)).toBeInTheDocument();
|
||||
// Name appears in both the mobile inline span and the desktop table cell
|
||||
await expect.element(page.getByText('Karl Müller').first()).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1374,7 +1374,7 @@ export interface components {
|
||||
};
|
||||
MergeTagDTO: {
|
||||
/** Format: uuid */
|
||||
targetId?: string;
|
||||
targetId: string;
|
||||
};
|
||||
PersonNameAliasDTO: {
|
||||
lastName: string;
|
||||
@@ -1398,7 +1398,7 @@ export interface components {
|
||||
/** Format: uuid */
|
||||
id: string;
|
||||
/** @enum {string} */
|
||||
status: "RUNNING" | "DONE" | "FAILED";
|
||||
status: "QUEUED" | "RUNNING" | "DONE" | "FAILED";
|
||||
/** Format: int32 */
|
||||
blockCount: number;
|
||||
/** Format: int32 */
|
||||
@@ -1415,6 +1415,8 @@ export interface components {
|
||||
errorMessage?: string;
|
||||
/** Format: uuid */
|
||||
triggeredBy?: string;
|
||||
/** Format: uuid */
|
||||
personId?: string;
|
||||
/** Format: date-time */
|
||||
createdAt: string;
|
||||
/** Format: date-time */
|
||||
@@ -1596,11 +1598,11 @@ export interface components {
|
||||
};
|
||||
TagTreeNodeDTO: {
|
||||
/** Format: uuid */
|
||||
id?: string;
|
||||
name?: string;
|
||||
id: string;
|
||||
name: string;
|
||||
color?: string;
|
||||
/** Format: int32 */
|
||||
documentCount?: number;
|
||||
documentCount: number;
|
||||
children?: components["schemas"]["TagTreeNodeDTO"][];
|
||||
/**
|
||||
* Format: uuid
|
||||
@@ -1619,7 +1621,6 @@ export interface components {
|
||||
/** Format: uuid */
|
||||
id?: string;
|
||||
displayName?: string;
|
||||
personType?: string;
|
||||
firstName?: string;
|
||||
lastName?: string;
|
||||
/** Format: int64 */
|
||||
@@ -1630,19 +1631,7 @@ export interface components {
|
||||
deathYear?: number;
|
||||
alias?: string;
|
||||
notes?: string;
|
||||
};
|
||||
TrainingInfoResponse: {
|
||||
/** Format: int32 */
|
||||
availableBlocks?: number;
|
||||
/** Format: int32 */
|
||||
totalOcrBlocks?: number;
|
||||
/** Format: int32 */
|
||||
availableDocuments?: number;
|
||||
/** Format: int32 */
|
||||
availableSegBlocks?: number;
|
||||
ocrServiceAvailable?: boolean;
|
||||
lastRun?: components["schemas"]["OcrTrainingRun"];
|
||||
runs?: components["schemas"]["OcrTrainingRun"][];
|
||||
personType?: string;
|
||||
};
|
||||
StreamingResponseBody: unknown;
|
||||
OcrJob: {
|
||||
@@ -1671,13 +1660,11 @@ export interface components {
|
||||
timeout?: number;
|
||||
};
|
||||
PageNotificationDTO: {
|
||||
/** Format: int64 */
|
||||
totalElements?: number;
|
||||
/** Format: int32 */
|
||||
totalPages?: number;
|
||||
/** Format: int64 */
|
||||
totalElements?: number;
|
||||
pageable?: components["schemas"]["PageableObject"];
|
||||
first?: boolean;
|
||||
last?: boolean;
|
||||
/** Format: int32 */
|
||||
size?: number;
|
||||
content?: components["schemas"]["NotificationDTO"][];
|
||||
@@ -1686,6 +1673,8 @@ export interface components {
|
||||
sort?: components["schemas"]["SortObject"];
|
||||
/** Format: int32 */
|
||||
numberOfElements?: number;
|
||||
first?: boolean;
|
||||
last?: boolean;
|
||||
empty?: boolean;
|
||||
};
|
||||
PageableObject: {
|
||||
@@ -3501,7 +3490,9 @@ export interface operations {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"*/*": components["schemas"]["TrainingInfoResponse"];
|
||||
"*/*": {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
13
frontend/src/lib/types/training.ts
Normal file
13
frontend/src/lib/types/training.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
export interface TrainingRun {
|
||||
id: string;
|
||||
status: 'QUEUED' | 'RUNNING' | 'DONE' | 'FAILED';
|
||||
personId?: string;
|
||||
blockCount: number;
|
||||
documentCount: number;
|
||||
modelName: string;
|
||||
errorMessage?: string;
|
||||
triggeredBy?: string;
|
||||
createdAt: string;
|
||||
completedAt?: string;
|
||||
cer?: number;
|
||||
}
|
||||
@@ -1,13 +1,84 @@
|
||||
"""Kraken OCR engine wrapper — historical HTR model support for Kurrent/Suetterlin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_model = None
|
||||
_model_path = os.environ.get("KRAKEN_MODEL_PATH", "/app/models/german_kurrent.mlmodel")
|
||||
|
||||
_MODELS_DIR = os.path.realpath("/app/models")
|
||||
_MAX_CACHED_SENDER_MODELS = int(os.environ.get("OCR_MAX_CACHED_MODELS", "5"))
|
||||
|
||||
|
||||
def _load_sender_model(path: str):
|
||||
"""Load a Kraken model from disk. Extracted for testability."""
|
||||
from kraken.lib import models as kraken_models
|
||||
return kraken_models.load_any(path)
|
||||
|
||||
|
||||
class _SenderModelRegistry:
|
||||
"""Thread-safe LRU cache for per-sender Kraken models.
|
||||
|
||||
Uses double-checked locking: model loading happens outside the lock to
|
||||
avoid blocking concurrent OCR requests. At most one entry per path is
|
||||
stored even under concurrent load.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int):
|
||||
self._max_size = max_size
|
||||
self._cache: collections.OrderedDict = collections.OrderedDict()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_model(self, model_path: str):
|
||||
"""Return the cached model or load it. Validates path is within /app/models/."""
|
||||
resolved = os.path.realpath(model_path)
|
||||
if not resolved.startswith(_MODELS_DIR + os.sep) and resolved != _MODELS_DIR:
|
||||
raise ValueError(f"Sender model path not allowed: {model_path}")
|
||||
|
||||
with self._lock:
|
||||
if model_path in self._cache:
|
||||
self._cache.move_to_end(model_path)
|
||||
return self._cache[model_path]
|
||||
|
||||
new_model = _load_sender_model(model_path)
|
||||
|
||||
with self._lock:
|
||||
if model_path in self._cache:
|
||||
self._cache.move_to_end(model_path)
|
||||
return self._cache[model_path]
|
||||
self._cache[model_path] = new_model
|
||||
self._cache.move_to_end(model_path)
|
||||
while len(self._cache) > self._max_size:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
return new_model
|
||||
|
||||
def invalidate(self, model_path: str) -> None:
|
||||
"""Remove model from cache so the next request reloads from disk."""
|
||||
with self._lock:
|
||||
self._cache.pop(model_path, None)
|
||||
|
||||
def size(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
def _contains(self, model_path: str) -> bool:
|
||||
with self._lock:
|
||||
return model_path in self._cache
|
||||
|
||||
|
||||
_sender_registry = _SenderModelRegistry(_MAX_CACHED_SENDER_MODELS)
|
||||
|
||||
|
||||
def load_models():
|
||||
"""Load the Kraken model at startup. Skips if model file is not present."""
|
||||
@@ -29,10 +100,12 @@ def is_available() -> bool:
|
||||
return _model is not None
|
||||
|
||||
|
||||
def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict]:
|
||||
def extract_page_blocks(image: Image, page_idx: int, language: str = "de",
|
||||
sender_model_path: str | None = None) -> list[dict]:
|
||||
"""Run Kraken segmentation + recognition on a single PIL image.
|
||||
|
||||
Returns block dicts for that page. Coordinates are normalized to [0, 1].
|
||||
When sender_model_path is provided, the per-sender fine-tuned model is used.
|
||||
"""
|
||||
from kraken import blla, rpred
|
||||
from confidence import words_from_characters
|
||||
@@ -40,11 +113,13 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
if _model is None:
|
||||
raise RuntimeError("Kraken model is not loaded")
|
||||
|
||||
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
|
||||
|
||||
page_w, page_h = image.size
|
||||
blocks = []
|
||||
|
||||
baseline_seg = blla.segment(image)
|
||||
pred_it = rpred.rpred(_model, image, baseline_seg)
|
||||
pred_it = rpred.rpred(active_model, image, baseline_seg)
|
||||
|
||||
for record in pred_it:
|
||||
polygon_pts = record.boundary if hasattr(record, "boundary") and record.boundary else []
|
||||
@@ -79,13 +154,15 @@ def extract_page_blocks(image, page_idx: int, language: str = "de") -> list[dict
|
||||
return blocks
|
||||
|
||||
|
||||
def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
def extract_region_text(image: Image, x: float, y: float, w: float, h: float,
|
||||
sender_model_path: str | None = None) -> str:
|
||||
"""Crop image to a normalized region and run Kraken recognition on the crop.
|
||||
|
||||
Used for guided OCR — skips full-page layout detection entirely.
|
||||
A single synthetic baseline spanning the full crop width is used so that
|
||||
blla.segment() (which crashes on small crops) is never called.
|
||||
Coordinates are normalized to [0, 1].
|
||||
When sender_model_path is provided, the per-sender fine-tuned model is used.
|
||||
"""
|
||||
from kraken import rpred
|
||||
from kraken.containers import Segmentation, BaselineLine
|
||||
@@ -93,6 +170,8 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
if _model is None:
|
||||
raise RuntimeError("Kraken model is not loaded")
|
||||
|
||||
active_model = _sender_registry.get_model(sender_model_path) if sender_model_path else _model
|
||||
|
||||
pw, ph = image.size
|
||||
x1 = max(0, int(x * pw))
|
||||
y1 = max(0, int(y * ph))
|
||||
@@ -123,11 +202,12 @@ def extract_region_text(image, x: float, y: float, w: float, h: float) -> str:
|
||||
regions={},
|
||||
line_orders=[],
|
||||
)
|
||||
pred_it = rpred.rpred(_model, crop, synthetic_seg)
|
||||
pred_it = rpred.rpred(active_model, crop, synthetic_seg)
|
||||
return " ".join(r.prediction for r in pred_it)
|
||||
|
||||
|
||||
def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
def extract_blocks(images: list, language: str = "de",
|
||||
sender_model_path: str | None = None) -> list[dict]:
|
||||
"""Run Kraken segmentation + recognition on a list of PIL images.
|
||||
|
||||
Returns block dicts with pageNumber, x, y, width, height, polygon, text.
|
||||
@@ -137,7 +217,7 @@ def extract_blocks(images: list, language: str = "de") -> list[dict]:
|
||||
all_blocks = []
|
||||
|
||||
for page_idx, image in enumerate(images, start=1):
|
||||
all_blocks.extend(extract_page_blocks(image, page_idx, language))
|
||||
all_blocks.extend(extract_page_blocks(image, page_idx, language, sender_model_path))
|
||||
|
||||
return all_blocks
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import pypdfium2 as pdfium
|
||||
from fastapi import FastAPI, Header, HTTPException, UploadFile
|
||||
from fastapi import FastAPI, Form, Header, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from PIL import Image
|
||||
|
||||
@@ -103,7 +103,8 @@ async def run_ocr(request: OcrRequest):
|
||||
status_code=400,
|
||||
detail="Kraken model not available — cannot process Kurrent script",
|
||||
)
|
||||
blocks = await asyncio.to_thread(kraken_engine.extract_blocks, images, request.language)
|
||||
blocks = await asyncio.to_thread(
|
||||
kraken_engine.extract_blocks, images, request.language, request.senderModelPath)
|
||||
else:
|
||||
# TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya
|
||||
blocks = await asyncio.to_thread(surya_engine.extract_blocks, images, request.language)
|
||||
@@ -171,10 +172,12 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||
image = await asyncio.to_thread(preprocess_page, image)
|
||||
blocks = []
|
||||
sender_path = request.senderModelPath if use_kraken else None
|
||||
for region in page_regions:
|
||||
text = await asyncio.to_thread(
|
||||
engine.extract_region_text, image,
|
||||
region.x, region.y, region.width, region.height,
|
||||
sender_path,
|
||||
)
|
||||
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
|
||||
text = correct_text(text)
|
||||
@@ -231,8 +234,9 @@ async def run_ocr_stream(request: OcrRequest):
|
||||
try:
|
||||
yield json.dumps({"type": "preprocessing", "pageNumber": page_idx}) + "\n"
|
||||
image = await asyncio.to_thread(preprocess_page, image)
|
||||
sender_path = request.senderModelPath if use_kraken else None
|
||||
blocks = await asyncio.to_thread(
|
||||
engine.extract_page_blocks, image, page_idx, request.language
|
||||
engine.extract_page_blocks, image, page_idx, request.language, sender_path
|
||||
)
|
||||
|
||||
for block in blocks:
|
||||
@@ -443,6 +447,86 @@ async def train_model(
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/train-sender")
|
||||
async def train_sender_model(
|
||||
file: UploadFile,
|
||||
output_model_path: str = Form(...),
|
||||
x_training_token: str | None = Header(default=None),
|
||||
):
|
||||
"""Fine-tune a per-sender Kurrent model and save to output_model_path.
|
||||
|
||||
output_model_path must be within /app/models/. The per-sender model cache
|
||||
is invalidated after training so the next OCR request picks up the new model.
|
||||
"""
|
||||
_check_training_token(x_training_token)
|
||||
|
||||
if not _models_ready:
|
||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||
|
||||
models_dir = os.path.realpath("/app/models")
|
||||
resolved_output = os.path.realpath(output_model_path)
|
||||
if not resolved_output.startswith(models_dir + os.sep):
|
||||
raise HTTPException(status_code=400, detail="output_model_path must be within /app/models/")
|
||||
|
||||
zip_bytes = await file.read()
|
||||
|
||||
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
|
||||
log.info("Starting sender training run %s → %s", training_run_id, output_model_path)
|
||||
|
||||
def _run_sender_training() -> dict:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
for entry in zf.namelist():
|
||||
_validate_zip_entry(entry, tmp_dir)
|
||||
zf.extractall(tmp_dir)
|
||||
|
||||
ground_truth = glob.glob(os.path.join(tmp_dir, "*.xml"))
|
||||
if not ground_truth:
|
||||
raise HTTPException(status_code=422, detail="No ground-truth files found in ZIP")
|
||||
|
||||
log.info("Sender training on %d ground-truth pairs", len(ground_truth))
|
||||
checkpoint_dir = os.path.join(tmp_dir, "checkpoints")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
cmd = [
|
||||
"ketos", "--workers", "0", "--device", "cpu", "--threads", "2",
|
||||
"train",
|
||||
"-f", "page",
|
||||
"--weights-format", "coreml",
|
||||
"-o", checkpoint_dir,
|
||||
"-q", "fixed",
|
||||
"-N", "10",
|
||||
"-B", "1",
|
||||
]
|
||||
if os.path.exists(KRAKEN_MODEL_PATH):
|
||||
cmd += ["-i", KRAKEN_MODEL_PATH]
|
||||
cmd += ground_truth
|
||||
|
||||
log.info("Running sender training: %s", " ".join(cmd[:6]) + " ...")
|
||||
proc = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"ketos train failed (exit {proc.returncode}): {proc.stderr[-500:]}")
|
||||
|
||||
accuracy, epochs = _parse_best_checkpoint(checkpoint_dir)
|
||||
cer = round(1.0 - accuracy, 4) if accuracy is not None else None
|
||||
|
||||
best_model = _find_best_model(checkpoint_dir)
|
||||
if best_model is None:
|
||||
raise RuntimeError("Sender training produced no model file")
|
||||
|
||||
os.makedirs(os.path.dirname(resolved_output), exist_ok=True)
|
||||
shutil.copy2(best_model, resolved_output)
|
||||
log.info("Saved sender model to %s", resolved_output)
|
||||
|
||||
kraken_engine._sender_registry.invalidate(output_model_path)
|
||||
|
||||
return {"loss": None, "accuracy": accuracy, "cer": cer, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_sender_training)
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/segtrain")
|
||||
async def segtrain_model(
|
||||
file: UploadFile,
|
||||
|
||||
@@ -19,6 +19,7 @@ class OcrRequest(BaseModel):
|
||||
scriptType: str = "UNKNOWN"
|
||||
language: str = "de"
|
||||
regions: list[OcrRegion] | None = None
|
||||
senderModelPath: str | None = None
|
||||
|
||||
|
||||
class OcrBlock(BaseModel):
|
||||
|
||||
75
ocr-service/test_sender_registry.py
Normal file
75
ocr-service/test_sender_registry.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Tests for the per-sender model LRU registry in engines/kraken.py."""
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_registry(max_size=5):
    """Construct a fresh _SenderModelRegistry for a single test."""
    # Deferred import: collection must not require the ML stack at module load.
    from engines.kraken import _SenderModelRegistry as Registry
    return Registry(max_size=max_size)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_cache_hit_returns_same_object():
    """Second get_model call with the same path must return the cached object."""
    registry = _make_registry()
    fake_model = MagicMock(name="model_a")
    path = "/app/models/sender_a.mlmodel"
    with patch("engines.kraken._load_sender_model", return_value=fake_model) as loader:
        first = registry.get_model(path)
        second = registry.get_model(path)

    assert first is second
    assert loader.call_count == 1  # only loaded once despite two gets
|
||||
|
||||
|
||||
def test_lru_eviction_removes_least_recently_used():
    """When the cache exceeds max_size, the oldest-accessed entry is evicted."""
    registry = _make_registry(max_size=2)

    with patch("engines.kraken._load_sender_model",
               side_effect=lambda path: MagicMock(name=path)):
        # Loading 'c' as the third entry should evict 'a', the oldest.
        for sender in ("a", "b", "c"):
            registry.get_model(f"/app/models/sender_{sender}.mlmodel")

    assert registry.size() == 2
    # 'a' was the least-recently-used and should be gone
    assert not registry._contains("/app/models/sender_a.mlmodel")
|
||||
|
||||
|
||||
def test_invalidate_removes_entry_from_cache():
    """invalidate() must evict the entry so the next get reloads from disk."""
    registry = _make_registry()
    path = "/app/models/sender_x.mlmodel"
    with patch("engines.kraken._load_sender_model",
               return_value=MagicMock(name="model_x")):
        registry.get_model(path)
        assert registry.size() == 1

    registry.invalidate(path)
    assert registry.size() == 0
|
||||
|
||||
|
||||
def test_path_outside_models_dir_raises():
    """get_model must reject paths outside /app/models/ (path traversal guard)."""
    # The guard fires before any load attempt, so no patching is needed.
    with pytest.raises(ValueError, match="not allowed"):
        _make_registry().get_model("/etc/passwd")
|
||||
|
||||
|
||||
def test_load_failure_does_not_cache_broken_entry():
    """A failed load must not leave a broken entry in the cache."""
    registry = _make_registry()
    failing_loader = patch("engines.kraken._load_sender_model",
                           side_effect=RuntimeError("corrupt model"))
    with failing_loader, pytest.raises(RuntimeError, match="corrupt model"):
        registry.get_model("/app/models/sender_bad.mlmodel")

    assert registry.size() == 0
|
||||
@@ -67,3 +67,42 @@ async def test_train_returns_403_when_token_wrong():
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# ─── /train-sender authentication ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_train_sender_returns_503_when_training_token_not_configured():
    """POST /train-sender must return 503 when TRAINING_TOKEN env var is empty.

    An empty token means the service was started without training configured.
    Allowing requests through would grant unauthenticated access to the
    training endpoint, contradicting the principle of failing closed.
    """
    with patch("main.TRAINING_TOKEN", ""), \
            patch("main._models_ready", True):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/train-sender",
                data={"output_model_path": "/app/models/sender_test.mlmodel"},
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
            )

    assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_train_sender_returns_403_when_token_wrong():
    """POST /train-sender must return 403 when TRAINING_TOKEN is set but header is wrong."""
    with patch("main.TRAINING_TOKEN", "secret-token"), \
            patch("main._models_ready", True):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/train-sender",
                data={"output_model_path": "/app/models/sender_test.mlmodel"},
                files={"file": ("training.zip", _minimal_zip(), "application/zip")},
                headers={"X-Training-Token": "wrong-token"},
            )

    assert response.status_code == 403
|
||||
|
||||
Reference in New Issue
Block a user