feat(ocr): German spell-check post-processing to reduce handwriting gibberish #260

Merged
marcel merged 10 commits from feat/issue-254-german-spell-check into main 2026-04-17 17:28:41 +02:00
8 changed files with 153904 additions and 9 deletions

View File

@@ -47,6 +47,26 @@ jobs:
name: unit-test-screenshots
path: frontend/test-results/screenshots/
# ─── OCR Service Unit Tests ───────────────────────────────────────────────────
# Only spell_check.py and test_confidence.py — no ML stack required.
ocr-tests:
name: OCR Service Tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install test dependencies
run: pip install "pyspellchecker==0.9.0" pytest
working-directory: ocr-service
- name: Run OCR unit tests (no ML stack required)
run: python -m pytest test_spell_check.py test_confidence.py -v
working-directory: ocr-service
# ─── Backend Unit & Slice Tests ───────────────────────────────────────────────
# Pure Mockito + WebMvcTest — no DB or S3 needed.
backend-unit-tests:

View File

@@ -6,6 +6,21 @@ THRESHOLD_DEFAULT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD", "0.3"))
THRESHOLD_KURRENT = float(os.environ.get("OCR_CONFIDENCE_THRESHOLD_KURRENT", "0.5"))
ILLEGIBLE_MARKER = "[unleserlich]"
CORRECTION_MARKER = "[?]"
def collapse_adjacent_markers(tokens: list[str]) -> list[str]:
collapsed: list[str] = []
prev_was_marker = False
for token in tokens:
if token == ILLEGIBLE_MARKER:
if not prev_was_marker:
collapsed.append(token)
prev_was_marker = True
else:
collapsed.append(token)
prev_was_marker = False
return collapsed
def get_threshold(script_type: str) -> float:
@@ -30,19 +45,14 @@ def apply_confidence_markers(words: list[dict], threshold: float | None = None)
if threshold is None:
threshold = THRESHOLD_DEFAULT
result: list[str] = []
prev_was_marker = False
tokens: list[str] = []
for word in words:
if word["confidence"] < threshold:
if not prev_was_marker:
result.append(ILLEGIBLE_MARKER)
prev_was_marker = True
tokens.append(ILLEGIBLE_MARKER)
else:
result.append(word["text"])
prev_was_marker = False
tokens.append(word["text"])
return " ".join(result)
return " ".join(collapse_adjacent_markers(tokens))
def words_from_characters(prediction: str, confidences: list[float]) -> list[dict]:

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,7 @@ from fastapi.responses import StreamingResponse
from PIL import Image
from confidence import apply_confidence_markers, get_threshold
from spell_check import correct_text, load_spell_checker
from engines import kraken as kraken_engine
from engines import surya as surya_engine
from models import OcrBlock, OcrRequest
@@ -39,6 +40,8 @@ ALLOWED_PDF_HOSTS = set(
h.strip() for h in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
)
_SPELL_CHECK_SCRIPT_TYPES = {"HANDWRITING_KURRENT", "HANDWRITING_LATIN"}
def _validate_url(url: str) -> None:
"""Validate that the PDF URL points to an allowed host (SSRF protection)."""
@@ -55,6 +58,7 @@ async def lifespan(app: FastAPI):
logger.info("Loading Kraken model at startup (Surya loads lazily on first OCR request)...")
kraken_engine.load_models()
load_spell_checker()
_models_ready = True
logger.info("Startup complete — ready to accept requests")
@@ -109,6 +113,8 @@ async def run_ocr(request: OcrRequest):
if block.get("words"):
block["text"] = apply_confidence_markers(block["words"], threshold)
block.pop("words", None)
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
block["text"] = correct_text(block["text"])
return [OcrBlock(**b) for b in blocks]
@@ -170,6 +176,8 @@ async def run_ocr_stream(request: OcrRequest):
engine.extract_region_text, image,
region.x, region.y, region.width, region.height,
)
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
text = correct_text(text)
blocks.append({
"pageNumber": page_idx,
"x": region.x,
@@ -231,6 +239,8 @@ async def run_ocr_stream(request: OcrRequest):
if block.get("words"):
block["text"] = apply_confidence_markers(block["words"], threshold)
block.pop("words", None)
if script_type in _SPELL_CHECK_SCRIPT_TYPES:
block["text"] = correct_text(block["text"])
total_blocks += len(blocks)
yield json.dumps({

View File

@@ -8,4 +8,5 @@ pillow>=10.2.0,<11.0.0
pypdfium2==4.30.0
pyvips>=2.2.0
httpx==0.28.1
pyspellchecker==0.9.0
opencv-python-headless==4.11.0.86

117
ocr-service/spell_check.py Normal file
View File

@@ -0,0 +1,117 @@
"""German spell-check post-processing for OCR output."""
import logging
import os
from spellchecker import SpellChecker
from confidence import CORRECTION_MARKER, ILLEGIBLE_MARKER, collapse_adjacent_markers
logger = logging.getLogger(__name__)
_MIN_SPELL_CHECK_LEN = 4
_spell: SpellChecker | None = None
def load_spell_checker() -> None:
"""Load German spell checker with supplementary historical wordlist.
Safe to call multiple times — no-op if already loaded.
"""
global _spell
if _spell is not None:
return
logger.info("Loading German spell checker...")
_spell = SpellChecker(language="de")
historical_path = os.path.join(os.path.dirname(__file__), "dictionaries", "de_historical.txt")
if os.path.exists(historical_path):
with open(historical_path, encoding="utf-8") as f:
words = [
line.strip()
for line in f
if line.strip() and not line.startswith("#")
]
_spell.word_frequency.load_words(words)
logger.info("Loaded %d historical German words", len(words))
else:
logger.warning("Historical German wordlist not found at %s", historical_path)
logger.info("German spell checker ready")
def _strip_punctuation(token: str) -> tuple[str, str, str]:
"""Split token into (leading_punct, word, trailing_punct).
'Word' characters are letters (including German umlauts) and digits.
Everything else is treated as punctuation.
"""
start = 0
while start < len(token) and not (token[start].isalpha() or token[start].isdigit()):
start += 1
end = len(token)
while end > start and not (token[end - 1].isalpha() or token[end - 1].isdigit()):
end -= 1
return token[:start], token[start:end], token[end:]
def _is_numeric(word: str) -> bool:
return any(c.isdigit() for c in word)
def correct_text(text: str) -> str:
"""Spell-check OCR text, correcting errors and marking gibberish as [unleserlich].
Already-present [unleserlich] tokens are preserved unchanged.
Words of fewer than 4 characters are exempt (particles, abbreviations).
Tokens containing digits pass through unchanged.
Adjacent [unleserlich] markers are collapsed into one.
Corrected tokens are marked with [?] (e.g. "Hauus""Haus[?]").
Args:
text: OCR output, possibly already containing [unleserlich] from confidence filtering.
Returns:
Corrected text with unresolvable words replaced by [unleserlich].
"""
if _spell is None:
raise RuntimeError("Spell checker not loaded — call load_spell_checker() first")
if not text.strip():
return text
tokens = text.split()
checked: list[str] = []
for token in tokens:
if token == ILLEGIBLE_MARKER:
checked.append(token)
continue
leading, word, trailing = _strip_punctuation(token)
if len(word) < _MIN_SPELL_CHECK_LEN:
checked.append(token)
continue
if _is_numeric(word):
checked.append(token)
continue
if _spell.known([word]):
checked.append(token)
continue
correction = _spell.correction(word)
if correction and _spell.word_frequency[correction] > 50: # strict > avoids non-determinism when candidates tie at the frequency floor
if word[0].isupper() and not correction[0].isupper():
correction = correction.capitalize()
checked.append(leading + correction + CORRECTION_MARKER + trailing)
else:
checked.append(ILLEGIBLE_MARKER)
return " ".join(collapse_adjacent_markers(checked))

View File

@@ -0,0 +1,96 @@
"""Tests for OCR spell-check post-processing."""
import pytest
from spell_check import correct_text, load_spell_checker
@pytest.fixture(autouse=True)
def ensure_loaded():
load_spell_checker()
def test_known_german_word_passes_through():
assert correct_text("Haus") == "Haus"
def test_obvious_gibberish_replaced_with_marker():
assert correct_text("xqzwrpvmk") == "[unleserlich]"
def test_short_word_exempt_from_check():
assert correct_text("im") == "im"
assert correct_text("der") == "der"
assert correct_text("zu") == "zu"
def test_unleserlich_marker_preserved():
assert correct_text("[unleserlich]") == "[unleserlich]"
def test_mixed_text_correct_and_gibberish():
result = correct_text("Haus xqzwrpvmk Garten")
assert result == "Haus [unleserlich] Garten"
def test_adjacent_gibberish_words_collapsed_to_one_marker():
result = correct_text("[unleserlich] xqzwrpvmk Haus")
assert result == "[unleserlich] Haus"
def test_empty_string_returns_empty():
assert correct_text("") == ""
def test_whitespace_only_returns_unchanged():
assert correct_text(" ") == " "
def test_existing_marker_not_doubled():
result = correct_text("[unleserlich] Haus [unleserlich]")
assert result == "[unleserlich] Haus [unleserlich]"
def test_historical_word_passes_through():
assert correct_text("Thür") == "Thür"
def test_correctable_ocr_error_gets_corrected():
result = correct_text("Hauus")
assert result != "Hauus"
assert result != "[unleserlich]"
assert "[?]" in result
assert result.startswith("Haus")
def test_sentence_with_multiple_corrections():
result = correct_text("Thür Hauus xqzwrpvmk Garten")
tokens = result.split()
assert tokens[0] == "Thür"
assert "[?]" in tokens[1] and tokens[1].startswith("Haus")
assert tokens[2] == "[unleserlich]"
assert tokens[3] == "Garten"
def test_capitalization_preserved_on_correction():
# "Freunnd" (doubled n) — OCR duplication error for "Freund"
result = correct_text("Freunnd")
assert result != "Freunnd"
assert result != "[unleserlich]"
assert result[0].isupper()
def test_raises_runtime_error_when_not_loaded(monkeypatch):
import spell_check as sc
monkeypatch.setattr(sc, "_spell", None)
with pytest.raises(RuntimeError, match="not loaded"):
correct_text("test")
def test_punctuation_attached_token_passes_through():
assert correct_text("Haus,") == "Haus,"
assert correct_text("Garten.") == "Garten."
def test_numeric_token_passes_through():
assert correct_text("1870er") == "1870er"
assert correct_text("18.") == "18."

View File

@@ -0,0 +1,90 @@
"""Prepare historical German wordlist from Deutsches Textarchiv (DTA) corpora.
Downloads two 19th-century DTA ZIPs (original spelling, not normalised),
tokenises all texts, counts frequencies, filters to alphabetic tokens of
length > 3 with at least --min-freq occurrences, and writes the result to
ocr-service/dictionaries/de_historical.txt sorted by frequency descending.
Usage:
python scripts/prepare_historical_dict.py
python scripts/prepare_historical_dict.py --min-freq 10 --out path/to/output.txt
"""
import argparse
import io
import re
import zipfile
from collections import Counter
from pathlib import Path
import urllib.request
DTA_URLS = [
"https://www.deutschestextarchiv.de/media/download/dtak/2020-10-23/original/1800-1899.zip",
"https://www.deutschestextarchiv.de/media/download/dtae/2020-10-23/original/1800-1899.zip",
]
_ALPHA_RE = re.compile(r"[a-zA-ZäöüÄÖÜß]+")
_MIN_WORD_LEN = 4
def _download(url: str) -> bytes:
print(f"Downloading {url} ...")
with urllib.request.urlopen(url, timeout=300) as resp:
data = resp.read()
print(f" Downloaded {len(data) // 1024 // 1024} MB")
return data
def _tokenise_zip(data: bytes) -> Counter:
counts: Counter = Counter()
with zipfile.ZipFile(io.BytesIO(data)) as zf:
names = [n for n in zf.namelist() if n.endswith(".txt")]
print(f" Tokenising {len(names)} text files ...")
for name in names:
try:
text = zf.read(name).decode("utf-8", errors="replace")
except Exception:
continue
for token in _ALPHA_RE.findall(text):
if len(token) >= _MIN_WORD_LEN:
counts[token] += 1
return counts
def main(min_freq: int, out_path: Path) -> None:
out_path.parent.mkdir(parents=True, exist_ok=True)
total: Counter = Counter()
for url in DTA_URLS:
try:
data = _download(url)
total += _tokenise_zip(data)
except Exception as exc:
print(f" WARNING: could not process {url}: {exc}")
filtered = {word: freq for word, freq in total.items() if freq >= min_freq}
print(f"Vocabulary after filtering (min_freq={min_freq}): {len(filtered):,} words")
header = (
"# Historical German wordlist derived from Deutsches Textarchiv (DTA)\n"
"# Corpora: dtak + dtae 1800-1899 original spelling\n"
f"# Filter: alphabetic, length >= {_MIN_WORD_LEN}, min_freq >= {min_freq}\n"
"# Sorted by frequency descending\n"
)
lines = [f"{word}" for word, _ in sorted(filtered.items(), key=lambda x: -x[1])]
out_path.write_text(header + "\n".join(lines) + "\n", encoding="utf-8")
print(f"Written {len(lines):,} words to {out_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Prepare DTA historical German wordlist")
parser.add_argument("--min-freq", type=int, default=5, help="Minimum token frequency")
parser.add_argument(
"--out",
type=Path,
default=Path(__file__).parent.parent / "ocr-service" / "dictionaries" / "de_historical.txt",
help="Output path",
)
args = parser.parse_args()
main(args.min_freq, args.out)