feat(training): add document-level training enrollment
- V29 migration: document_training_labels join table
- TrainingLabel enum: KURRENT_RECOGNITION, KURRENT_SEGMENTATION
- Document.trainingLabels @ElementCollection
- DocumentService.addTrainingLabel / removeTrainingLabel
- PATCH /api/documents/{id}/training-labels (WRITE_ALL)
- Auto-enroll on Kurrent OCR trigger (OcrService.startOcr)
- TranscriptionEditView: enrollment chips in panel footer
- JPQL queries updated to use MEMBER OF trainingLabels
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -21,6 +21,7 @@ import org.raddatz.familienarchiv.exception.ErrorCode;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.dto.DocumentSort;
|
||||
import org.raddatz.familienarchiv.model.DocumentStatus;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.DocumentVersion;
|
||||
import org.raddatz.familienarchiv.security.Permission;
|
||||
import org.raddatz.familienarchiv.security.RequirePermission;
|
||||
@@ -35,9 +36,11 @@ import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.ModelAttribute;
|
||||
import org.springframework.web.bind.annotation.PatchMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RequestPart;
|
||||
@@ -208,6 +211,29 @@ public class DocumentController {
|
||||
return ResponseEntity.ok(DocumentSearchResult.of(results));
|
||||
}
|
||||
|
||||
// --- TRAINING LABELS ---
|
||||
|
||||
public record TrainingLabelRequest(String label, boolean enrolled) {}
|
||||
|
||||
@PatchMapping("/{id}/training-labels")
|
||||
@RequirePermission(Permission.WRITE_ALL)
|
||||
public ResponseEntity<Void> patchTrainingLabel(
|
||||
@PathVariable UUID id,
|
||||
@RequestBody TrainingLabelRequest req) {
|
||||
TrainingLabel label;
|
||||
try {
|
||||
label = TrainingLabel.valueOf(req.label());
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Unknown training label: " + req.label());
|
||||
}
|
||||
if (req.enrolled()) {
|
||||
documentService.addTrainingLabel(id, label);
|
||||
} else {
|
||||
documentService.removeTrainingLabel(id, label);
|
||||
}
|
||||
return ResponseEntity.noContent().build();
|
||||
}
|
||||
|
||||
// --- VERSIONS ---
|
||||
|
||||
@GetMapping("/{id}/versions")
|
||||
|
||||
@@ -110,4 +110,11 @@ public class Document {
|
||||
@JoinTable(name = "document_tags", joinColumns = @JoinColumn(name = "document_id"), inverseJoinColumns = @JoinColumn(name = "tag_id"))
|
||||
@Builder.Default
|
||||
private Set<Tag> tags = new HashSet<>();
|
||||
|
||||
@ElementCollection(fetch = FetchType.EAGER)
|
||||
@CollectionTable(name = "document_training_labels", joinColumns = @JoinColumn(name = "document_id"))
|
||||
@Column(name = "label")
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Builder.Default
|
||||
private Set<TrainingLabel> trainingLabels = new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
package org.raddatz.familienarchiv.model;
|
||||
|
||||
public enum TrainingLabel {
|
||||
KURRENT_RECOGNITION,
|
||||
KURRENT_SEGMENTATION
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.data.jpa.repository.Query;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@@ -14,4 +15,22 @@ public interface TranscriptionBlockRepository extends JpaRepository<Transcriptio
|
||||
Optional<TranscriptionBlock> findByIdAndDocumentId(UUID id, UUID documentId);
|
||||
|
||||
int countByDocumentId(UUID documentId);
|
||||
|
||||
@Query("""
|
||||
SELECT b FROM TranscriptionBlock b
|
||||
JOIN DocumentAnnotation a ON a.id = b.annotationId
|
||||
JOIN Document d ON d.id = b.documentId
|
||||
WHERE (b.source = 'MANUAL' OR (b.source = 'OCR' AND b.reviewed = true))
|
||||
AND 'KURRENT_RECOGNITION' MEMBER OF d.trainingLabels
|
||||
""")
|
||||
List<TranscriptionBlock> findEligibleKurrentBlocks();
|
||||
|
||||
@Query("""
|
||||
SELECT b FROM TranscriptionBlock b
|
||||
JOIN DocumentAnnotation a ON a.id = b.annotationId
|
||||
JOIN Document d ON d.id = b.documentId
|
||||
WHERE b.source = 'MANUAL' AND (b.text IS NULL OR b.text = '')
|
||||
AND 'KURRENT_SEGMENTATION' MEMBER OF d.trainingLabels
|
||||
""")
|
||||
List<TranscriptionBlock> findSegmentationBlocks();
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.dto.DocumentSort;
|
||||
import org.raddatz.familienarchiv.model.DocumentStatus;
|
||||
import org.raddatz.familienarchiv.model.ScriptType;
|
||||
import org.raddatz.familienarchiv.model.TrainingLabel;
|
||||
import org.raddatz.familienarchiv.model.Person;
|
||||
import org.raddatz.familienarchiv.model.Tag;
|
||||
import org.raddatz.familienarchiv.repository.DocumentRepository;
|
||||
@@ -385,6 +386,20 @@ public class DocumentService {
|
||||
documentRepository.save(doc);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void addTrainingLabel(UUID documentId, TrainingLabel label) {
|
||||
Document doc = getDocumentById(documentId);
|
||||
doc.getTrainingLabels().add(label);
|
||||
documentRepository.save(doc);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void removeTrainingLabel(UUID documentId, TrainingLabel label) {
|
||||
Document doc = getDocumentById(documentId);
|
||||
doc.getTrainingLabels().remove(label);
|
||||
documentRepository.save(doc);
|
||||
}
|
||||
|
||||
public Document getDocumentById(UUID id) {
|
||||
return documentRepository.findById(id)
|
||||
.orElseThrow(() -> DomainException.notFound(ErrorCode.DOCUMENT_NOT_FOUND, "Document not found: " + id));
|
||||
|
||||
@@ -66,6 +66,9 @@ public class OcrService {
|
||||
|
||||
if (scriptTypeOverride != null) {
|
||||
documentService.updateScriptType(documentId, scriptTypeOverride);
|
||||
if (scriptTypeOverride == ScriptType.HANDWRITING_KURRENT) {
|
||||
documentService.addTrainingLabel(documentId, TrainingLabel.KURRENT_RECOGNITION);
|
||||
}
|
||||
}
|
||||
|
||||
OcrJob job = OcrJob.builder()
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE document_training_labels (
|
||||
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
|
||||
label VARCHAR(50) NOT NULL,
|
||||
PRIMARY KEY (document_id, label)
|
||||
);
|
||||
@@ -14,6 +14,7 @@ import org.raddatz.familienarchiv.service.FileService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.webmvc.test.autoconfigure.WebMvcTest;
|
||||
import org.raddatz.familienarchiv.config.SecurityConfig;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.boot.autoconfigure.aop.AopAutoConfiguration;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.security.test.context.support.WithMockUser;
|
||||
@@ -31,8 +32,10 @@ import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.multipart;
|
||||
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch;
|
||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
|
||||
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
|
||||
|
||||
@@ -485,6 +488,58 @@ class DocumentControllerTest {
|
||||
.andExpect(jsonPath("$[0].editorName").value("Emma Müller"));
|
||||
}
|
||||
|
||||
// ─── PATCH /api/documents/{id}/training-labels ───────────────────────────
|
||||
|
||||
@Test
|
||||
void patchTrainingLabels_returns401_whenUnauthenticated() throws Exception {
|
||||
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
|
||||
.andExpect(status().isUnauthorized());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser
|
||||
void patchTrainingLabels_returns403_whenMissingWritePermission() throws Exception {
|
||||
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
|
||||
.andExpect(status().isForbidden());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void patchTrainingLabels_returns204_whenAddingLabel() throws Exception {
|
||||
UUID id = UUID.randomUUID();
|
||||
mockMvc.perform(patch("/api/documents/" + id + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_RECOGNITION\",\"enrolled\":true}"))
|
||||
.andExpect(status().isNoContent());
|
||||
|
||||
verify(documentService).addTrainingLabel(eq(id), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void patchTrainingLabels_returns204_whenRemovingLabel() throws Exception {
|
||||
UUID id = UUID.randomUUID();
|
||||
mockMvc.perform(patch("/api/documents/" + id + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"KURRENT_SEGMENTATION\",\"enrolled\":false}"))
|
||||
.andExpect(status().isNoContent());
|
||||
|
||||
verify(documentService).removeTrainingLabel(eq(id), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@WithMockUser(authorities = "WRITE_ALL")
|
||||
void patchTrainingLabels_returns400_whenUnknownLabel() throws Exception {
|
||||
mockMvc.perform(patch("/api/documents/" + UUID.randomUUID() + "/training-labels")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.content("{\"label\":\"UNKNOWN_GARBAGE\",\"enrolled\":true}"))
|
||||
.andExpect(status().isBadRequest());
|
||||
}
|
||||
|
||||
// ─── GET /api/documents/{id}/versions/{versionId} ────────────────────────
|
||||
|
||||
@Test
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
package org.raddatz.familienarchiv.repository;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.raddatz.familienarchiv.PostgresContainerConfig;
|
||||
import org.raddatz.familienarchiv.config.FlywayConfig;
|
||||
import org.raddatz.familienarchiv.model.*;
|
||||
import java.util.Set;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.jdbc.test.autoconfigure.AutoConfigureTestDatabase;
|
||||
import org.springframework.boot.data.jpa.test.autoconfigure.DataJpaTest;
|
||||
import org.springframework.context.annotation.Import;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@DataJpaTest
|
||||
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
|
||||
@Import({PostgresContainerConfig.class, FlywayConfig.class})
|
||||
class TrainingBlockQueryTest {
|
||||
|
||||
@Autowired TranscriptionBlockRepository blockRepository;
|
||||
@Autowired DocumentRepository documentRepository;
|
||||
@Autowired AnnotationRepository annotationRepository;
|
||||
|
||||
private UUID kurrentDocId;
|
||||
private UUID typewriterDocId;
|
||||
private UUID kurrentAnnotationId;
|
||||
private UUID typewriterAnnotationId;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
Document kurrentDoc = documentRepository.save(Document.builder()
|
||||
.title("Kurrent Brief")
|
||||
.originalFilename("kurrent.pdf")
|
||||
.status(DocumentStatus.UPLOADED)
|
||||
.trainingLabels(new java.util.HashSet<>(Set.of(TrainingLabel.KURRENT_RECOGNITION)))
|
||||
.build());
|
||||
kurrentDocId = kurrentDoc.getId();
|
||||
|
||||
Document typewriterDoc = documentRepository.save(Document.builder()
|
||||
.title("Getippter Brief")
|
||||
.originalFilename("typed.pdf")
|
||||
.status(DocumentStatus.UPLOADED)
|
||||
.build());
|
||||
typewriterDocId = typewriterDoc.getId();
|
||||
|
||||
kurrentAnnotationId = annotationRepository.save(annotation(kurrentDocId)).getId();
|
||||
typewriterAnnotationId = annotationRepository.save(annotation(typewriterDocId)).getId();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_includesManualBlock() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, false));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).hasSize(1);
|
||||
assertThat(result.get(0).getSource()).isEqualTo(BlockSource.MANUAL);
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_includesReviewedOcrBlock() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, true));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).hasSize(1);
|
||||
assertThat(result.get(0).isReviewed()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_excludesUnreviewedOcrBlock() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, false));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_excludesNonEnrolledDocument() {
|
||||
blockRepository.save(block(typewriterDocId, typewriterAnnotationId, BlockSource.MANUAL, false));
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void findEligibleKurrentBlocks_returnsAllEligibleAcrossBothSources() {
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.MANUAL, false));
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, true));
|
||||
blockRepository.save(block(kurrentDocId, kurrentAnnotationId, BlockSource.OCR, false)); // excluded
|
||||
|
||||
List<TranscriptionBlock> result = blockRepository.findEligibleKurrentBlocks();
|
||||
|
||||
assertThat(result).hasSize(2);
|
||||
}
|
||||
|
||||
// ─── helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
private DocumentAnnotation annotation(UUID docId) {
|
||||
return DocumentAnnotation.builder()
|
||||
.documentId(docId)
|
||||
.pageNumber(1)
|
||||
.x(0.1).y(0.2).width(0.3).height(0.4)
|
||||
.color("#00C7B1")
|
||||
.build();
|
||||
}
|
||||
|
||||
private TranscriptionBlock block(UUID docId, UUID annotId, BlockSource source, boolean reviewed) {
|
||||
return TranscriptionBlock.builder()
|
||||
.annotationId(annotId)
|
||||
.documentId(docId)
|
||||
.text("Liebe Tante")
|
||||
.sortOrder(0)
|
||||
.source(source)
|
||||
.reviewed(reviewed)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user