diff --git a/docker-compose.yml b/docker-compose.yml index 5e88f381..6d68e1aa 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -85,6 +85,7 @@ services: environment: KRAKEN_MODEL_PATH: /app/models/german_kurrent.mlmodel OCR_CONFIDENCE_THRESHOLD: "0.3" + OCR_CONFIDENCE_THRESHOLD_KURRENT: "0.5" networks: - archive-net healthcheck: diff --git a/ocr-service/confidence.py b/ocr-service/confidence.py index 092c2892..e331443f 100644 --- a/ocr-service/confidence.py +++ b/ocr-service/confidence.py @@ -2,16 +2,24 @@ import os -CONFIDENCE_THRESHOLD = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD", "0.3")) +THRESHOLD_DEFAULT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD", "0.3")) +THRESHOLD_KURRENT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD_KURRENT", "0.5")) ILLEGIBLE_MARKER = "[unleserlich]" -def apply_confidence_markers(words: list[dict]) -> str: +def get_threshold(script_type: str) -> float: + if script_type and script_type.upper() == "HANDWRITING_KURRENT": + return THRESHOLD_KURRENT + return THRESHOLD_DEFAULT + + +def apply_confidence_markers(words: list[dict], threshold: float | None = None) -> str: """Replace low-confidence words with [unleserlich], collapsing adjacent markers. Args: words: list of {"text": str, "confidence": float} dicts + threshold: confidence threshold (uses THRESHOLD_DEFAULT if None) Returns: Reconstructed text string with [unleserlich] substitutions. @@ -19,11 +27,14 @@ def apply_confidence_markers(words: list[dict]) -> str: if not words: return "" + if threshold is None: + threshold = THRESHOLD_DEFAULT + result: list[str] = [] prev_was_marker = False for word in words: - if word["confidence"] < CONFIDENCE_THRESHOLD: + if word["confidence"] < threshold: if not prev_was_marker: result.append(ILLEGIBLE_MARKER) prev_was_marker = True diff --git a/ocr-service/main.py b/ocr-service/main.py index f87985e6..b1766516 100644 --- a/ocr-service/main.py +++ b/ocr-service/main.py @@ -9,7 +9,7 @@ import pypdfium2 as pdfium from fastapi import FastAPI, HTTPException from PIL import Image -from confidence import apply_confidence_markers +from confidence import apply_confidence_markers, get_threshold from engines import kraken as kraken_engine from engines import surya as surya_engine from models import OcrBlock, OcrRequest @@ -72,9 +72,10 @@ async def run_ocr(request: OcrRequest): # TYPEWRITER, HANDWRITING_LATIN, UNKNOWN — all use Surya blocks = surya_engine.extract_blocks(images, request.language) + threshold = get_threshold(script_type) for block in blocks: if block.get("words"): - block["text"] = apply_confidence_markers(block["words"]) + block["text"] = apply_confidence_markers(block["words"], threshold) block.pop("words", None) return [OcrBlock(**b) for b in blocks] diff --git a/ocr-service/test_confidence.py b/ocr-service/test_confidence.py index e1359eb1..c89ade9b 100644 --- a/ocr-service/test_confidence.py +++ b/ocr-service/test_confidence.py @@ -2,7 +2,7 @@ import os import pytest -from confidence import apply_confidence_markers, words_from_characters +from confidence import apply_confidence_markers, words_from_characters, get_threshold # ─── apply_confidence_markers ───────────────────────────────────────────────── @@ -75,22 +75,29 @@ def test_just_below_threshold_replaced(): assert apply_confidence_markers(words) == "[unleserlich]" -def test_custom_threshold_via_env(monkeypatch): - monkeypatch.setenv("OCR_CONFIDENCE_THRESHOLD", "0.8") - # Need to reload the module to pick up the new env var - import importlib - import confidence - importlib.reload(confidence) - +def test_custom_threshold_via_parameter(): words = [ {"text": "Lieber", "confidence": 0.95}, {"text": "Freund", "confidence": 0.5}, ] - assert confidence.apply_confidence_markers(words) == "Lieber [unleserlich]" + assert apply_confidence_markers(words, threshold=0.8) == "Lieber [unleserlich]" + assert apply_confidence_markers(words, threshold=0.3) == "Lieber Freund" - # Reset - monkeypatch.setenv("OCR_CONFIDENCE_THRESHOLD", "0.3") - importlib.reload(confidence) + +def test_kurrent_threshold_is_higher_than_default(): + default = get_threshold("TYPEWRITER") + kurrent = get_threshold("HANDWRITING_KURRENT") + assert kurrent > default + + +def test_get_threshold_kurrent(): + assert get_threshold("HANDWRITING_KURRENT") == 0.5 + + +def test_get_threshold_default(): + assert get_threshold("TYPEWRITER") == 0.3 + assert get_threshold("HANDWRITING_LATIN") == 0.3 + assert get_threshold("UNKNOWN") == 0.3 def test_low_confidence_at_start():