feat(ocr): German spell-check post-processing to reduce handwriting gibberish #260
@@ -47,6 +47,26 @@ jobs:
          name: unit-test-screenshots
          path: frontend/test-results/screenshots/

  # ─── OCR Service Unit Tests ───────────────────────────────────────────────────
  # Only test_spell_check.py and test_confidence.py — no ML stack required.
  ocr-tests:
    name: OCR Service Tests
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install test dependencies
        run: pip install "pyspellchecker==0.9.0" pytest
        working-directory: ocr-service

      - name: Run OCR unit tests (no ML stack required)
        run: python -m pytest test_spell_check.py test_confidence.py -v
        working-directory: ocr-service

  # ─── Backend Unit & Slice Tests ───────────────────────────────────────────────
  # Pure Mockito + WebMvcTest — no DB or S3 needed.
  backend-unit-tests:
@@ -6,6 +6,21 @@ THRESHOLD_DEFAULT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD", "0.3"))
THRESHOLD_KURRENT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD_KURRENT", "0.5"))

ILLEGIBLE_MARKER = "[unleserlich]"
CORRECTION_MARKER = "[?]"


def collapse_adjacent_markers(tokens: list[str]) -> list[str]:
    collapsed: list[str] = []
    prev_was_marker = False
    for token in tokens:
        if token == ILLEGIBLE_MARKER:
            if not prev_was_marker:
                collapsed.append(token)
            prev_was_marker = True
        else:
            collapsed.append(token)
            prev_was_marker = False
    return collapsed


def get_threshold(script_type: str) -> float:
@@ -30,19 +45,14 @@ def apply_confidence_markers(words: list[dict], threshold: float | None = None)
     if threshold is None:
         threshold = THRESHOLD_DEFAULT

-    result: list[str] = []
-    prev_was_marker = False
-
+    tokens: list[str] = []
     for word in words:
         if word["confidence"] < threshold:
-            if not prev_was_marker:
-                result.append(ILLEGIBLE_MARKER)
-                prev_was_marker = True
+            tokens.append(ILLEGIBLE_MARKER)
         else:
-            result.append(word["text"])
-            prev_was_marker = False
+            tokens.append(word["text"])

-    return " ".join(result)
+    return " ".join(collapse_adjacent_markers(tokens))


 def words_from_characters(prediction: str, confidences: list[float]) -> list[dict]:
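A minimal sketch of the new marker collapsing, illustrative only and not part of the diff, using hypothetical word dicts with the "text" and "confidence" keys shown above:

# Illustrative only: hypothetical inputs, assuming confidence.py as changed above.
from confidence import apply_confidence_markers

words = [
    {"text": "Haus", "confidence": 0.92},
    {"text": "xqz", "confidence": 0.11},   # below threshold
    {"text": "wrp", "confidence": 0.08},   # also below threshold, collapses into the same marker
    {"text": "Garten", "confidence": 0.88},
]
print(apply_confidence_markers(words, threshold=0.3))
# -> "Haus [unleserlich] Garten"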
ocr-service/dictionaries/de_historical.txt (new file, 153551 lines)
File diff suppressed because it is too large
@@ -22,6 +22,7 @@ from fastapi.responses import StreamingResponse
from PIL import Image

from confidence import apply_confidence_markers, get_threshold
from spell_check import correct_text, load_spell_checker
from engines import kraken as kraken_engine
from engines import surya as surya_engine
from models import OcrBlock, OcrRequest
@@ -39,6 +40,8 @@ ALLOWED_PDF_HOSTS = set(
    h.strip() for h in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
)

_SPELL_CHECK_SCRIPT_TYPES = {"HANDWRITING_KURRENT", "HANDWRITING_LATIN"}


def _validate_url(url: str) -> None:
    """Validate that the PDF URL points to an allowed host (SSRF protection)."""
@@ -55,6 +58,7 @@ async def lifespan(app: FastAPI):

    logger.info("Loading Kraken model at startup (Surya loads lazily on first OCR request)...")
    kraken_engine.load_models()
    load_spell_checker()
    _models_ready = True
    logger.info("Startup complete — ready to accept requests")

@@ -109,6 +113,8 @@ async def run_ocr(request: OcrRequest):
        if block.get("words"):
            block["text"] = apply_confidence_markers(block["words"], threshold)
            block.pop("words", None)
        if script_type in _SPELL_CHECK_SCRIPT_TYPES:
            block["text"] = correct_text(block["text"])

    return [OcrBlock(**b) for b in blocks]

@@ -170,6 +176,8 @@ async def run_ocr_stream(request: OcrRequest):
                engine.extract_region_text, image,
                region.x, region.y, region.width, region.height,
            )
            if script_type in _SPELL_CHECK_SCRIPT_TYPES:
                text = correct_text(text)
            blocks.append({
                "pageNumber": page_idx,
                "x": region.x,
@@ -231,6 +239,8 @@ async def run_ocr_stream(request: OcrRequest):
            if block.get("words"):
                block["text"] = apply_confidence_markers(block["words"], threshold)
                block.pop("words", None)
            if script_type in _SPELL_CHECK_SCRIPT_TYPES:
                block["text"] = correct_text(block["text"])

        total_blocks += len(blocks)
        yield json.dumps({
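The ordering above matters: confidence markers are applied first, so correct_text() sees and preserves any [unleserlich] tokens before spell checking. A minimal per-block sketch, illustrative only, with a hypothetical block dict:

# Illustrative only: a hypothetical block dict, showing the post-processing order used above.
from confidence import apply_confidence_markers, get_threshold
from spell_check import correct_text, load_spell_checker

load_spell_checker()
block = {"words": [{"text": "Hauus", "confidence": 0.9},
                   {"text": "xqzwrp", "confidence": 0.1}]}
threshold = get_threshold("HANDWRITING_KURRENT")    # presumably THRESHOLD_KURRENT, 0.5 by default
block["text"] = apply_confidence_markers(block["words"], threshold)
block.pop("words", None)
block["text"] = correct_text(block["text"])         # expected: "Haus[?] [unleserlich]"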
@@ -8,4 +8,5 @@ pillow>=10.2.0,<11.0.0
pypdfium2==4.30.0
pyvips>=2.2.0
httpx==0.28.1
pyspellchecker==0.9.0
opencv-python-headless==4.11.0.86
ocr-service/spell_check.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""German spell-check post-processing for OCR output."""

import logging
import os

from spellchecker import SpellChecker

from confidence import CORRECTION_MARKER, ILLEGIBLE_MARKER, collapse_adjacent_markers

logger = logging.getLogger(__name__)

_MIN_SPELL_CHECK_LEN = 4

_spell: SpellChecker | None = None


def load_spell_checker() -> None:
    """Load German spell checker with supplementary historical wordlist.

    Safe to call multiple times — no-op if already loaded.
    """
    global _spell
    if _spell is not None:
        return

    logger.info("Loading German spell checker...")
    _spell = SpellChecker(language="de")

    historical_path = os.path.join(os.path.dirname(__file__), "dictionaries", "de_historical.txt")
    if os.path.exists(historical_path):
        with open(historical_path, encoding="utf-8") as f:
            words = [
                line.strip()
                for line in f
                if line.strip() and not line.startswith("#")
            ]
        _spell.word_frequency.load_words(words)
        logger.info("Loaded %d historical German words", len(words))
    else:
        logger.warning("Historical German wordlist not found at %s", historical_path)

    logger.info("German spell checker ready")


def _strip_punctuation(token: str) -> tuple[str, str, str]:
    """Split token into (leading_punct, word, trailing_punct).

    'Word' characters are letters (including German umlauts) and digits.
    Everything else is treated as punctuation.
    """
    start = 0
    while start < len(token) and not (token[start].isalpha() or token[start].isdigit()):
        start += 1

    end = len(token)
    while end > start and not (token[end - 1].isalpha() or token[end - 1].isdigit()):
        end -= 1

    return token[:start], token[start:end], token[end:]


def _is_numeric(word: str) -> bool:
    return any(c.isdigit() for c in word)


def correct_text(text: str) -> str:
    """Spell-check OCR text, correcting errors and marking gibberish as [unleserlich].

    Already-present [unleserlich] tokens are preserved unchanged.
    Words of fewer than 4 characters are exempt (particles, abbreviations).
    Tokens containing digits pass through unchanged.
    Adjacent [unleserlich] markers are collapsed into one.
    Corrected tokens are marked with [?] (e.g. "Hauus" → "Haus[?]").

    Args:
        text: OCR output, possibly already containing [unleserlich] from confidence filtering.

    Returns:
        Corrected text with unresolvable words replaced by [unleserlich].
    """
    if _spell is None:
        raise RuntimeError("Spell checker not loaded — call load_spell_checker() first")

    if not text.strip():
        return text

    tokens = text.split()
    checked: list[str] = []

    for token in tokens:
        if token == ILLEGIBLE_MARKER:
            checked.append(token)
            continue

        leading, word, trailing = _strip_punctuation(token)

        if len(word) < _MIN_SPELL_CHECK_LEN:
            checked.append(token)
            continue

        if _is_numeric(word):
            checked.append(token)
            continue

        if _spell.known([word]):
            checked.append(token)
            continue

        correction = _spell.correction(word)
        # strict > avoids non-determinism when candidates tie at the frequency floor
        if correction and _spell.word_frequency[correction] > 50:
            if word[0].isupper() and not correction[0].isupper():
                correction = correction.capitalize()
            checked.append(leading + correction + CORRECTION_MARKER + trailing)
        else:
            checked.append(ILLEGIBLE_MARKER)

    return " ".join(collapse_adjacent_markers(checked))
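A small sketch, illustrative only, of how punctuation survives a correction: _strip_punctuation() splits the token deterministically, and the [?] marker lands between the corrected word and the trailing punctuation. The spell-check output shown is only an expectation, since it depends on the loaded dictionaries:

# Illustrative only: assumes spell_check.py as added above.
from spell_check import _strip_punctuation, correct_text, load_spell_checker

print(_strip_punctuation('"Hauus,"'))   # ('"', 'Hauus', ',"')  deterministic split
load_spell_checker()
print(correct_text('"Hauus,"'))         # expected: '"Haus[?],"' (depends on dictionary frequencies)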
ocr-service/test_spell_check.py (new file, 96 lines)
@@ -0,0 +1,96 @@
"""Tests for OCR spell-check post-processing."""

import pytest
from spell_check import correct_text, load_spell_checker


@pytest.fixture(autouse=True)
def ensure_loaded():
    load_spell_checker()


def test_known_german_word_passes_through():
    assert correct_text("Haus") == "Haus"


def test_obvious_gibberish_replaced_with_marker():
    assert correct_text("xqzwrpvmk") == "[unleserlich]"


def test_short_word_exempt_from_check():
    assert correct_text("im") == "im"
    assert correct_text("der") == "der"
    assert correct_text("zu") == "zu"


def test_unleserlich_marker_preserved():
    assert correct_text("[unleserlich]") == "[unleserlich]"


def test_mixed_text_correct_and_gibberish():
    result = correct_text("Haus xqzwrpvmk Garten")
    assert result == "Haus [unleserlich] Garten"


def test_adjacent_gibberish_words_collapsed_to_one_marker():
    result = correct_text("[unleserlich] xqzwrpvmk Haus")
    assert result == "[unleserlich] Haus"


def test_empty_string_returns_empty():
    assert correct_text("") == ""


def test_whitespace_only_returns_unchanged():
    assert correct_text(" ") == " "


def test_existing_marker_not_doubled():
    result = correct_text("[unleserlich] Haus [unleserlich]")
    assert result == "[unleserlich] Haus [unleserlich]"


def test_historical_word_passes_through():
    assert correct_text("Thür") == "Thür"


def test_correctable_ocr_error_gets_corrected():
    result = correct_text("Hauus")
    assert result != "Hauus"
    assert result != "[unleserlich]"
    assert "[?]" in result
    assert result.startswith("Haus")


def test_sentence_with_multiple_corrections():
    result = correct_text("Thür Hauus xqzwrpvmk Garten")
    tokens = result.split()
    assert tokens[0] == "Thür"
    assert "[?]" in tokens[1] and tokens[1].startswith("Haus")
    assert tokens[2] == "[unleserlich]"
    assert tokens[3] == "Garten"


def test_capitalization_preserved_on_correction():
    # "Freunnd" (doubled n) — OCR duplication error for "Freund"
    result = correct_text("Freunnd")
    assert result != "Freunnd"
    assert result != "[unleserlich]"
    assert result[0].isupper()


def test_raises_runtime_error_when_not_loaded(monkeypatch):
    import spell_check as sc
    monkeypatch.setattr(sc, "_spell", None)
    with pytest.raises(RuntimeError, match="not loaded"):
        correct_text("test")


def test_punctuation_attached_token_passes_through():
    assert correct_text("Haus,") == "Haus,"
    assert correct_text("Garten.") == "Garten."


def test_numeric_token_passes_through():
    assert correct_text("1870er") == "1870er"
    assert correct_text("18.") == "18."
scripts/prepare_historical_dict.py (new file, 90 lines)
@@ -0,0 +1,90 @@
"""Prepare historical German wordlist from Deutsches Textarchiv (DTA) corpora.

Downloads two 19th-century DTA ZIPs (original spelling, not normalised),
tokenises all texts, counts frequencies, filters to alphabetic tokens of
length > 3 with at least --min-freq occurrences, and writes the result to
ocr-service/dictionaries/de_historical.txt sorted by frequency descending.

Usage:
    python scripts/prepare_historical_dict.py
    python scripts/prepare_historical_dict.py --min-freq 10 --out path/to/output.txt
"""

import argparse
import io
import re
import urllib.request
import zipfile
from collections import Counter
from pathlib import Path

DTA_URLS = [
    "https://www.deutschestextarchiv.de/media/download/dtak/2020-10-23/original/1800-1899.zip",
    "https://www.deutschestextarchiv.de/media/download/dtae/2020-10-23/original/1800-1899.zip",
]

_ALPHA_RE = re.compile(r"[a-zA-ZäöüÄÖÜß]+")
_MIN_WORD_LEN = 4


def _download(url: str) -> bytes:
    print(f"Downloading {url} ...")
    with urllib.request.urlopen(url, timeout=300) as resp:
        data = resp.read()
    print(f" Downloaded {len(data) // 1024 // 1024} MB")
    return data


def _tokenise_zip(data: bytes) -> Counter:
    counts: Counter = Counter()
    with zipfile.ZipFile(io.BytesIO(data)) as zf:
        names = [n for n in zf.namelist() if n.endswith(".txt")]
        print(f" Tokenising {len(names)} text files ...")
        for name in names:
            try:
                text = zf.read(name).decode("utf-8", errors="replace")
            except Exception:
                continue
            for token in _ALPHA_RE.findall(text):
                if len(token) >= _MIN_WORD_LEN:
                    counts[token] += 1
    return counts


def main(min_freq: int, out_path: Path) -> None:
    out_path.parent.mkdir(parents=True, exist_ok=True)
    total: Counter = Counter()

    for url in DTA_URLS:
        try:
            data = _download(url)
            total += _tokenise_zip(data)
        except Exception as exc:
            print(f" WARNING: could not process {url}: {exc}")

    filtered = {word: freq for word, freq in total.items() if freq >= min_freq}
    print(f"Vocabulary after filtering (min_freq={min_freq}): {len(filtered):,} words")

    header = (
        "# Historical German wordlist derived from Deutsches Textarchiv (DTA)\n"
        "# Corpora: dtak + dtae 1800-1899 original spelling\n"
        f"# Filter: alphabetic, length >= {_MIN_WORD_LEN}, min_freq >= {min_freq}\n"
        "# Sorted by frequency descending\n"
    )
    lines = [word for word, _ in sorted(filtered.items(), key=lambda x: -x[1])]
    out_path.write_text(header + "\n".join(lines) + "\n", encoding="utf-8")
    print(f"Written {len(lines):,} words to {out_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare DTA historical German wordlist")
    parser.add_argument("--min-freq", type=int, default=5, help="Minimum token frequency")
    parser.add_argument(
        "--out",
        type=Path,
        default=Path(__file__).parent.parent / "ocr-service" / "dictionaries" / "de_historical.txt",
        help="Output path",
    )
    args = parser.parse_args()
    main(args.min_freq, args.out)
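A tiny worked example, illustrative only, of the tokenisation and filtering rule the script applies (alphabetic tokens, length >= 4, minimum frequency), run against an in-memory sample instead of the DTA ZIPs:

# Illustrative only: same regex and filters as prepare_historical_dict.py,
# applied to a made-up sample instead of the downloaded DTA corpora.
import re
from collections import Counter

_ALPHA_RE = re.compile(r"[a-zA-ZäöüÄÖÜß]+")
sample = "Die Thür der Thür war offen, die Thür blieb offen."
counts = Counter(t for t in _ALPHA_RE.findall(sample) if len(t) >= 4)
min_freq = 2
print({w: c for w, c in counts.items() if c >= min_freq})
# -> {'Thür': 3, 'offen': 2}  ("Die"/"der"/"war"/"die" drop out by length, "blieb" by frequency)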