feat(training): add segmentation training pipeline and complete Part 6
- Add /segtrain endpoint to OCR service (ZIP upload, ketos.segtrain, backup rotation, in-process model reload) - Add segtrainModel() to OcrClient and RestClientOcrClient (10-min timeout, X-Training-Token header) - Add SegmentationTrainingExportService: PAGE XML export with polygon de-normalization and per-page PNG rendering via PDFBox - Add GET /api/ocr/segmentation-training-data/export endpoint - Make TranscriptionBlock.text nullable for segmentation-only blocks (V31 migration) - Add Paraglide i18n translation keys for all training UI strings (de/en/es) - Pass source prop from TranscriptionEditView to TranscriptionBlock Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -152,6 +152,13 @@
|
||||
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
|
||||
<version>3.0.2</version>
|
||||
</dependency>
|
||||
|
||||
<!-- PDF rendering for training data export -->
|
||||
<dependency>
|
||||
<groupId>org.apache.pdfbox</groupId>
|
||||
<artifactId>pdfbox</artifactId>
|
||||
<version>3.0.4</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import org.raddatz.familienarchiv.service.OcrBatchService;
|
||||
import org.raddatz.familienarchiv.service.OcrProgressService;
|
||||
import org.raddatz.familienarchiv.service.OcrService;
|
||||
import org.raddatz.familienarchiv.service.OcrTrainingService;
|
||||
import org.raddatz.familienarchiv.service.SegmentationTrainingExportService;
|
||||
import org.raddatz.familienarchiv.service.TrainingDataExportService;
|
||||
import org.raddatz.familienarchiv.service.UserService;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -39,6 +40,7 @@ public class OcrController {
|
||||
private final OcrProgressService ocrProgressService;
|
||||
private final UserService userService;
|
||||
private final TrainingDataExportService trainingDataExportService;
|
||||
private final SegmentationTrainingExportService segmentationTrainingExportService;
|
||||
private final OcrTrainingService ocrTrainingService;
|
||||
|
||||
@PostMapping("/api/documents/{documentId}/ocr")
|
||||
@@ -96,6 +98,19 @@ public class OcrController {
|
||||
.body(body);
|
||||
}
|
||||
|
||||
@GetMapping("/api/ocr/segmentation-training-data/export")
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
public ResponseEntity<StreamingResponseBody> exportSegmentationTrainingData() {
|
||||
if (segmentationTrainingExportService.querySegmentationBlocks().isEmpty()) {
|
||||
return ResponseEntity.noContent().build();
|
||||
}
|
||||
StreamingResponseBody body = segmentationTrainingExportService.exportToZip();
|
||||
return ResponseEntity.ok()
|
||||
.contentType(MediaType.parseMediaType("application/zip"))
|
||||
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"segmentation-data.zip\"")
|
||||
.body(body);
|
||||
}
|
||||
|
||||
@PostMapping("/api/ocr/train")
|
||||
@ResponseStatus(HttpStatus.CREATED)
|
||||
@RequirePermission(Permission.ADMIN)
|
||||
|
||||
@@ -30,8 +30,7 @@ public class TranscriptionBlock {
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
private UUID documentId;
|
||||
|
||||
@Column(nullable = false, columnDefinition = "TEXT")
|
||||
@Schema(requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@Column(columnDefinition = "TEXT")
|
||||
private String text;
|
||||
|
||||
@Column(length = 200)
|
||||
|
||||
@@ -20,6 +20,14 @@ public interface OcrClient {
|
||||
|
||||
record TrainingResult(Double loss, Double accuracy, Integer epochs) {}
|
||||
|
||||
/**
|
||||
* Send a segmentation training ZIP to the OCR service for fine-tuning the blla model.
|
||||
*
|
||||
* @param trainingDataZip raw ZIP bytes produced by SegmentationTrainingExportService
|
||||
* @return training result metrics
|
||||
*/
|
||||
TrainingResult segtrainModel(byte[] trainingDataZip);
|
||||
|
||||
/**
|
||||
* Stream OCR results page-by-page via NDJSON. Implementations should override
|
||||
* this method. The default exists only for backward compatibility during migration
|
||||
|
||||
@@ -131,6 +131,35 @@ public class RestClientOcrClient implements OcrClient, OcrHealthClient {
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
public OcrClient.TrainingResult segtrainModel(byte[] trainingDataZip) {
|
||||
ByteArrayResource zipResource = new ByteArrayResource(trainingDataZip) {
|
||||
@Override
|
||||
public String getFilename() { return "segmentation-data.zip"; }
|
||||
};
|
||||
|
||||
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||
HttpHeaders partHeaders = new HttpHeaders();
|
||||
partHeaders.setContentType(MediaType.parseMediaType("application/zip"));
|
||||
body.add("file", new HttpEntity<>(zipResource, partHeaders));
|
||||
|
||||
var spec = trainingRestClient.post()
|
||||
.uri("/segtrain")
|
||||
.contentType(MediaType.MULTIPART_FORM_DATA);
|
||||
|
||||
if (trainingToken != null && !trainingToken.isBlank()) {
|
||||
spec = spec.header("X-Training-Token", trainingToken);
|
||||
}
|
||||
|
||||
TrainingResultJson result = spec
|
||||
.body(body)
|
||||
.retrieve()
|
||||
.body(TrainingResultJson.class);
|
||||
|
||||
if (result == null) return new OcrClient.TrainingResult(null, null, null);
|
||||
return new OcrClient.TrainingResult(result.loss(), result.accuracy(), result.epochs());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isHealthy() {
|
||||
try {
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
package org.raddatz.familienarchiv.service;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.pdfbox.Loader;
|
||||
import org.apache.pdfbox.pdmodel.PDDocument;
|
||||
import org.apache.pdfbox.rendering.PDFRenderer;
|
||||
import org.raddatz.familienarchiv.model.Document;
|
||||
import org.raddatz.familienarchiv.model.DocumentAnnotation;
|
||||
import org.raddatz.familienarchiv.model.TranscriptionBlock;
|
||||
import org.raddatz.familienarchiv.repository.AnnotationRepository;
|
||||
import org.raddatz.familienarchiv.repository.DocumentRepository;
|
||||
import org.raddatz.familienarchiv.repository.TranscriptionBlockRepository;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.*;
|
||||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class SegmentationTrainingExportService {
|
||||
|
||||
private final TranscriptionBlockRepository blockRepository;
|
||||
private final AnnotationRepository annotationRepository;
|
||||
private final DocumentRepository documentRepository;
|
||||
private final FileService fileService;
|
||||
|
||||
public List<TranscriptionBlock> querySegmentationBlocks() {
|
||||
return blockRepository.findSegmentationBlocks();
|
||||
}
|
||||
|
||||
public StreamingResponseBody exportToZip() {
|
||||
List<TranscriptionBlock> blocks = querySegmentationBlocks();
|
||||
if (blocks.isEmpty()) {
|
||||
return out -> {};
|
||||
}
|
||||
|
||||
// Group by documentId so we download each PDF only once
|
||||
Map<UUID, List<TranscriptionBlock>> byDoc = new LinkedHashMap<>();
|
||||
for (TranscriptionBlock b : blocks) {
|
||||
byDoc.computeIfAbsent(b.getDocumentId(), k -> new ArrayList<>()).add(b);
|
||||
}
|
||||
|
||||
// Pre-fetch annotations keyed by id
|
||||
Map<UUID, DocumentAnnotation> annotations = new HashMap<>();
|
||||
for (TranscriptionBlock b : blocks) {
|
||||
annotationRepository.findById(b.getAnnotationId())
|
||||
.ifPresent(a -> annotations.put(a.getId(), a));
|
||||
}
|
||||
|
||||
// Pre-fetch documents keyed by id
|
||||
Map<UUID, Document> documents = new HashMap<>();
|
||||
for (UUID docId : byDoc.keySet()) {
|
||||
documentRepository.findById(docId).ifPresent(d -> documents.put(d.getId(), d));
|
||||
}
|
||||
|
||||
return out -> {
|
||||
try (ZipOutputStream zip = new ZipOutputStream(out)) {
|
||||
for (Map.Entry<UUID, List<TranscriptionBlock>> entry : byDoc.entrySet()) {
|
||||
UUID docId = entry.getKey();
|
||||
Document doc = documents.get(docId);
|
||||
if (doc == null || doc.getFilePath() == null) {
|
||||
log.warn("Skipping document {} — no file path", docId);
|
||||
continue;
|
||||
}
|
||||
|
||||
byte[] pdfBytes;
|
||||
try {
|
||||
pdfBytes = fileService.downloadFileBytes(doc.getFilePath());
|
||||
} catch (FileService.StorageFileNotFoundException | IOException e) {
|
||||
log.warn("Skipping document {} — S3 download failed: {}", docId, e.getMessage());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Group blocks by page number for this document
|
||||
Map<Integer, List<TranscriptionBlock>> byPage = new LinkedHashMap<>();
|
||||
for (TranscriptionBlock b : entry.getValue()) {
|
||||
DocumentAnnotation ann = annotations.get(b.getAnnotationId());
|
||||
if (ann != null) {
|
||||
byPage.computeIfAbsent(ann.getPageNumber(), k -> new ArrayList<>()).add(b);
|
||||
}
|
||||
}
|
||||
|
||||
try (PDDocument pdf = Loader.loadPDF(pdfBytes)) {
|
||||
PDFRenderer renderer = new PDFRenderer(pdf);
|
||||
for (Map.Entry<Integer, List<TranscriptionBlock>> pageEntry : byPage.entrySet()) {
|
||||
int pageNumber = pageEntry.getKey();
|
||||
int pageIdx = pageNumber - 1;
|
||||
if (pageIdx < 0 || pageIdx >= pdf.getNumberOfPages()) continue;
|
||||
|
||||
BufferedImage pageImage = renderer.renderImageWithDPI(pageIdx, 300);
|
||||
String basename = "page-" + docId + "-" + pageNumber;
|
||||
|
||||
// Collect annotations for this page
|
||||
List<DocumentAnnotation> pageAnnotations = new ArrayList<>();
|
||||
for (TranscriptionBlock b : pageEntry.getValue()) {
|
||||
DocumentAnnotation ann = annotations.get(b.getAnnotationId());
|
||||
if (ann != null) pageAnnotations.add(ann);
|
||||
}
|
||||
|
||||
writePngEntry(zip, basename, pageImage);
|
||||
writePageXmlEntry(zip, basename, pageImage, pageAnnotations);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("Skipping document {} — rendering failed: {}", docId, e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void writePngEntry(ZipOutputStream zip, String basename, BufferedImage image) throws IOException {
|
||||
zip.putNextEntry(new ZipEntry(basename + ".png"));
|
||||
ImageIO.write(image, "PNG", zip);
|
||||
zip.closeEntry();
|
||||
}
|
||||
|
||||
private void writePageXmlEntry(ZipOutputStream zip, String basename,
|
||||
BufferedImage pageImage,
|
||||
List<DocumentAnnotation> annotations) throws IOException {
|
||||
int imgW = pageImage.getWidth();
|
||||
int imgH = pageImage.getHeight();
|
||||
|
||||
StringBuilder regions = new StringBuilder();
|
||||
for (DocumentAnnotation ann : annotations) {
|
||||
String coords = buildPolygonCoords(ann, imgW, imgH);
|
||||
String regionId = ann.getId().toString();
|
||||
regions.append(" <TextRegion id=\"").append(regionId).append("\">\n");
|
||||
regions.append(" <Coords points=\"").append(coords).append("\"/>\n");
|
||||
regions.append(" </TextRegion>\n");
|
||||
}
|
||||
|
||||
String xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
|
||||
+ "<PcGts xmlns=\"http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15\">\n"
|
||||
+ " <Page imageFilename=\"" + basename + ".png\""
|
||||
+ " imageWidth=\"" + imgW + "\""
|
||||
+ " imageHeight=\"" + imgH + "\">\n"
|
||||
+ regions
|
||||
+ " </Page>\n"
|
||||
+ "</PcGts>\n";
|
||||
|
||||
zip.putNextEntry(new ZipEntry(basename + ".xml"));
|
||||
zip.write(xml.getBytes(StandardCharsets.UTF_8));
|
||||
zip.closeEntry();
|
||||
}
|
||||
|
||||
String buildPolygonCoords(DocumentAnnotation ann, int imgW, int imgH) {
|
||||
List<List<Double>> polygon = ann.getPolygon();
|
||||
if (polygon != null && !polygon.isEmpty()) {
|
||||
// Use explicit polygon — de-normalize to pixel coordinates
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (List<Double> pt : polygon) {
|
||||
if (sb.length() > 0) sb.append(' ');
|
||||
int px = (int) (pt.get(0) * imgW);
|
||||
int py = (int) (pt.get(1) * imgH);
|
||||
sb.append(px).append(',').append(py);
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
// Fall back to bounding box from x/y/width/height
|
||||
int x = (int) (ann.getX() * imgW);
|
||||
int y = (int) (ann.getY() * imgH);
|
||||
int w = (int) (ann.getWidth() * imgW);
|
||||
int h = (int) (ann.getHeight() * imgH);
|
||||
return x + "," + y + " " + (x + w) + "," + y + " " + (x + w) + "," + (y + h) + " " + x + "," + (y + h);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Intentional: segmentation-only blocks have no text.
|
||||
-- This migration is irreversible without a data cleanup step
|
||||
-- (cannot re-add NOT NULL if null rows exist).
|
||||
ALTER TABLE transcription_blocks ALTER COLUMN text DROP NOT NULL;
|
||||
ALTER TABLE transcription_blocks ALTER COLUMN text SET DEFAULT '';
|
||||
@@ -20,6 +20,7 @@ class OcrClientDefaultStreamTest {
|
||||
new OcrBlockResult(1, 0.1, 0.1, 0.8, 0.04, null, "Line 3"));
|
||||
}
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
@@ -50,6 +51,7 @@ class OcrClientDefaultStreamTest {
|
||||
OcrClient client = new OcrClient() {
|
||||
@Override public List<OcrBlockResult> extractBlocks(String pdfUrl, ScriptType scriptType) { return List.of(); }
|
||||
@Override public TrainingResult trainModel(byte[] zip) { return null; }
|
||||
@Override public TrainingResult segtrainModel(byte[] zip) { return null; }
|
||||
};
|
||||
|
||||
List<OcrStreamEvent> events = new ArrayList<>();
|
||||
|
||||
@@ -533,5 +533,28 @@
|
||||
"ocr_status_error": "OCR fehlgeschlagen",
|
||||
"transcription_block_review": "Als geprüft markieren",
|
||||
"transcription_block_unreview": "Markierung aufheben",
|
||||
"transcription_reviewed_count": "{reviewed} von {total} geprüft"
|
||||
"transcription_reviewed_count": "{reviewed} von {total} geprüft",
|
||||
"training_ocr_heading": "Kurrent-Erkennung trainieren",
|
||||
"training_ocr_description": "Starte ein neues Training mit den bisher geprüften OCR-Blöcken, um die Erkennungsgenauigkeit für Kurrentschrift zu verbessern.",
|
||||
"training_ocr_blocks_ready": "{blocks} geprüfte Blöcke bereit / {docs} Dokumente",
|
||||
"training_ocr_blocks_total": "(von {total} OCR-Blöcken gesamt)",
|
||||
"training_start_btn": "Training starten",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Training wurde gestartet und abgeschlossen.",
|
||||
"training_too_few_blocks": "Mindestens 5 geprüfte Blöcke erforderlich (aktuell: {available}).",
|
||||
"training_service_down": "OCR-Dienst ist nicht erreichbar.",
|
||||
"training_history_heading": "Verlauf",
|
||||
"training_history_empty": "Noch keine Trainings-Läufe.",
|
||||
"training_history_col_date": "Datum",
|
||||
"training_history_col_status": "Status",
|
||||
"training_history_col_blocks": "Blöcke",
|
||||
"training_history_col_docs": "Dokumente",
|
||||
"training_status_done": "Fertig",
|
||||
"training_status_failed": "Fehler",
|
||||
"training_status_running": "Läuft…",
|
||||
"training_seg_heading": "Segmentierung trainieren",
|
||||
"training_seg_description": "Starte ein neues Training mit annotierten Segmentierungsbereichen, um die Texterkennung zu verbessern.",
|
||||
"training_seg_blocks_ready": "{blocks} Segmentierungsblöcke bereit",
|
||||
"training_seg_too_few_blocks": "Mindestens 5 Segmentierungsblöcke erforderlich (aktuell: {available}).",
|
||||
"transcription_block_segmentation_only": "Nur Segmentierung"
|
||||
}
|
||||
|
||||
@@ -533,5 +533,28 @@
|
||||
"ocr_status_error": "OCR failed",
|
||||
"transcription_block_review": "Mark as reviewed",
|
||||
"transcription_block_unreview": "Unmark as reviewed",
|
||||
"transcription_reviewed_count": "{reviewed} of {total} reviewed"
|
||||
"transcription_reviewed_count": "{reviewed} of {total} reviewed",
|
||||
"training_ocr_heading": "Train Kurrent recognition",
|
||||
"training_ocr_description": "Start a new training run using the reviewed OCR blocks to improve recognition accuracy for Kurrent script.",
|
||||
"training_ocr_blocks_ready": "{blocks} reviewed blocks ready / {docs} documents",
|
||||
"training_ocr_blocks_total": "(of {total} OCR blocks total)",
|
||||
"training_start_btn": "Start training",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Training started and completed.",
|
||||
"training_too_few_blocks": "At least 5 reviewed blocks required (currently: {available}).",
|
||||
"training_service_down": "OCR service is unavailable.",
|
||||
"training_history_heading": "History",
|
||||
"training_history_empty": "No training runs yet.",
|
||||
"training_history_col_date": "Date",
|
||||
"training_history_col_status": "Status",
|
||||
"training_history_col_blocks": "Blocks",
|
||||
"training_history_col_docs": "Documents",
|
||||
"training_status_done": "Done",
|
||||
"training_status_failed": "Failed",
|
||||
"training_status_running": "Running…",
|
||||
"training_seg_heading": "Train segmentation",
|
||||
"training_seg_description": "Start a new training run using annotated segmentation regions to improve text detection.",
|
||||
"training_seg_blocks_ready": "{blocks} segmentation blocks ready",
|
||||
"training_seg_too_few_blocks": "At least 5 segmentation blocks required (currently: {available}).",
|
||||
"transcription_block_segmentation_only": "Segmentation only"
|
||||
}
|
||||
|
||||
@@ -533,5 +533,28 @@
|
||||
"ocr_status_error": "OCR fallido",
|
||||
"transcription_block_review": "Marcar como revisado",
|
||||
"transcription_block_unreview": "Desmarcar como revisado",
|
||||
"transcription_reviewed_count": "{reviewed} de {total} revisados"
|
||||
"transcription_reviewed_count": "{reviewed} de {total} revisados",
|
||||
"training_ocr_heading": "Entrenar reconocimiento Kurrent",
|
||||
"training_ocr_description": "Inicia un nuevo entrenamiento con los bloques OCR revisados para mejorar la precisión de reconocimiento del script Kurrent.",
|
||||
"training_ocr_blocks_ready": "{blocks} bloques revisados listos / {docs} documentos",
|
||||
"training_ocr_blocks_total": "(de {total} bloques OCR en total)",
|
||||
"training_start_btn": "Iniciar entrenamiento",
|
||||
"training_in_progress": "…",
|
||||
"training_success": "Entrenamiento iniciado y completado.",
|
||||
"training_too_few_blocks": "Se requieren al menos 5 bloques revisados (actualmente: {available}).",
|
||||
"training_service_down": "El servicio OCR no está disponible.",
|
||||
"training_history_heading": "Historial",
|
||||
"training_history_empty": "Todavía no hay ejecuciones de entrenamiento.",
|
||||
"training_history_col_date": "Fecha",
|
||||
"training_history_col_status": "Estado",
|
||||
"training_history_col_blocks": "Bloques",
|
||||
"training_history_col_docs": "Documentos",
|
||||
"training_status_done": "Listo",
|
||||
"training_status_failed": "Error",
|
||||
"training_status_running": "Ejecutando…",
|
||||
"training_seg_heading": "Entrenar segmentación",
|
||||
"training_seg_description": "Inicia un nuevo entrenamiento con regiones de segmentación anotadas para mejorar la detección de texto.",
|
||||
"training_seg_blocks_ready": "{blocks} bloques de segmentación listos",
|
||||
"training_seg_too_few_blocks": "Se requieren al menos 5 bloques de segmentación (actualmente: {available}).",
|
||||
"transcription_block_segmentation_only": "Solo segmentación"
|
||||
}
|
||||
|
||||
@@ -344,6 +344,7 @@ $effect(() => {
|
||||
onMoveDown={() => handleMoveDown(block.id)}
|
||||
isFirst={i === 0}
|
||||
isLast={i === sortedBlocks.length - 1}
|
||||
source={block.source}
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
|
||||
@@ -278,6 +278,75 @@ async def train_model(
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/segtrain")
|
||||
async def segtrain_model(
|
||||
file: UploadFile,
|
||||
x_training_token: str | None = Header(default=None),
|
||||
):
|
||||
"""Fine-tune the blla segmentation model with uploaded PAGE XML training data.
|
||||
|
||||
Accepts a ZIP archive containing .png/.xml (PAGE XML) training pairs exported
|
||||
by the Java backend. Training mutates in-process model state — not safe
|
||||
if the service is replicated.
|
||||
"""
|
||||
_check_training_token(x_training_token)
|
||||
|
||||
if not _models_ready:
|
||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||
|
||||
zip_bytes = await file.read()
|
||||
|
||||
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
|
||||
log.info("Starting segmentation training run %s", training_run_id)
|
||||
|
||||
blla_model_path = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")
|
||||
|
||||
def _run_segtrain() -> dict:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
for entry in zf.namelist():
|
||||
_validate_zip_entry(entry, tmp_dir)
|
||||
zf.extractall(tmp_dir)
|
||||
|
||||
log.info("Extracted %d ZIP entries for segmentation training", len(os.listdir(tmp_dir)))
|
||||
|
||||
xml_files = glob.glob(os.path.join(tmp_dir, "*.xml"))
|
||||
if not xml_files:
|
||||
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
|
||||
|
||||
log.info("Training on %d PAGE XML files", len(xml_files))
|
||||
output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")
|
||||
|
||||
from kraken import ketos
|
||||
result = ketos.segtrain(
|
||||
ground_truth=xml_files,
|
||||
load=blla_model_path if os.path.exists(blla_model_path) else None,
|
||||
output=output_model_path,
|
||||
format_type="path",
|
||||
)
|
||||
|
||||
epochs = getattr(result, "epochs", None) or 0
|
||||
loss = getattr(result, "best_loss", None)
|
||||
accuracy = getattr(result, "best_accuracy", None)
|
||||
|
||||
log.info("Segmentation training complete — epochs=%s loss=%s", epochs, loss)
|
||||
|
||||
if os.path.exists(blla_model_path):
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
backup_path = f"{blla_model_path}.{timestamp}.bak"
|
||||
shutil.copy2(blla_model_path, backup_path)
|
||||
_rotate_backups(blla_model_path, keep=3)
|
||||
|
||||
shutil.move(output_model_path, blla_model_path)
|
||||
log.info("Replaced blla model at %s", blla_model_path)
|
||||
|
||||
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
return result
|
||||
|
||||
|
||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
||||
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
||||
_validate_url(url)
|
||||
|
||||
Reference in New Issue
Block a user