feat(training): add segmentation training pipeline and complete Part 6
- Add /segtrain endpoint to OCR service (ZIP upload, ketos.segtrain, backup rotation, in-process model reload) - Add segtrainModel() to OcrClient and RestClientOcrClient (10-min timeout, X-Training-Token header) - Add SegmentationTrainingExportService: PAGE XML export with polygon de-normalization and per-page PNG rendering via PDFBox - Add GET /api/ocr/segmentation-training-data/export endpoint - Make TranscriptionBlock.text nullable for segmentation-only blocks (V31 migration) - Add Paraglide i18n translation keys for all training UI strings (de/en/es) - Pass source prop from TranscriptionEditView to TranscriptionBlock Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -278,6 +278,75 @@ async def train_model(
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/segtrain")
|
||||
async def segtrain_model(
|
||||
file: UploadFile,
|
||||
x_training_token: str | None = Header(default=None),
|
||||
):
|
||||
"""Fine-tune the blla segmentation model with uploaded PAGE XML training data.
|
||||
|
||||
Accepts a ZIP archive containing .png/.xml (PAGE XML) training pairs exported
|
||||
by the Java backend. Training mutates in-process model state — not safe
|
||||
if the service is replicated.
|
||||
"""
|
||||
_check_training_token(x_training_token)
|
||||
|
||||
if not _models_ready:
|
||||
raise HTTPException(status_code=503, detail="Models not loaded yet")
|
||||
|
||||
zip_bytes = await file.read()
|
||||
|
||||
training_run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
log = logging.LoggerAdapter(logger, {"training_run_id": training_run_id})
|
||||
log.info("Starting segmentation training run %s", training_run_id)
|
||||
|
||||
blla_model_path = os.environ.get("BLLA_MODEL_PATH", "/app/models/blla.mlmodel")
|
||||
|
||||
def _run_segtrain() -> dict:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
|
||||
for entry in zf.namelist():
|
||||
_validate_zip_entry(entry, tmp_dir)
|
||||
zf.extractall(tmp_dir)
|
||||
|
||||
log.info("Extracted %d ZIP entries for segmentation training", len(os.listdir(tmp_dir)))
|
||||
|
||||
xml_files = glob.glob(os.path.join(tmp_dir, "*.xml"))
|
||||
if not xml_files:
|
||||
raise HTTPException(status_code=422, detail="No PAGE XML files found in ZIP")
|
||||
|
||||
log.info("Training on %d PAGE XML files", len(xml_files))
|
||||
output_model_path = os.path.join(tmp_dir, "fine_tuned_blla.mlmodel")
|
||||
|
||||
from kraken import ketos
|
||||
result = ketos.segtrain(
|
||||
ground_truth=xml_files,
|
||||
load=blla_model_path if os.path.exists(blla_model_path) else None,
|
||||
output=output_model_path,
|
||||
format_type="path",
|
||||
)
|
||||
|
||||
epochs = getattr(result, "epochs", None) or 0
|
||||
loss = getattr(result, "best_loss", None)
|
||||
accuracy = getattr(result, "best_accuracy", None)
|
||||
|
||||
log.info("Segmentation training complete — epochs=%s loss=%s", epochs, loss)
|
||||
|
||||
if os.path.exists(blla_model_path):
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
backup_path = f"{blla_model_path}.{timestamp}.bak"
|
||||
shutil.copy2(blla_model_path, backup_path)
|
||||
_rotate_backups(blla_model_path, keep=3)
|
||||
|
||||
shutil.move(output_model_path, blla_model_path)
|
||||
log.info("Replaced blla model at %s", blla_model_path)
|
||||
|
||||
return {"loss": loss, "accuracy": accuracy, "epochs": epochs}
|
||||
|
||||
result = await asyncio.to_thread(_run_segtrain)
|
||||
return result
|
||||
|
||||
|
||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
||||
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
||||
_validate_url(url)
|
||||
|
||||
Reference in New Issue
Block a user