Also adds regex year-fallback in extract_dates() for de/es spaCy small models that don't tag bare 4-digit years as DATE entities, and widens the direction-token window to 2 tokens back to handle Spanish "antes de". Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
246 lines
7.7 KiB
Python
246 lines
7.7 KiB
Python
from __future__ import annotations
|
||
|
||
import re
|
||
from datetime import date
|
||
|
||
import dateparser
|
||
import spacy
|
||
from spacy.language import Language
|
||
|
||
from models import ParseResponse
|
||
|
||
# ── Language model registry ──────────────────────────────────────────────────
|
||
|
||
# Maps each supported two-letter language code to the name of the spaCy
# pipeline package loaded for it. The keys are also the set of values that
# get_nlp() accepts as `lang`.
_MODEL_NAMES: dict[str, str] = {
    "de": "de_core_news_sm",
    "en": "en_core_web_sm",
    "es": "es_core_news_sm",
}

# Pipelines that have already been loaded, keyed by language code. Filled by
# get_nlp() the first time a language is requested.
_nlp_cache: dict[str, Language] = {}
|
||
|
||
|
||
def get_nlp(lang: str) -> Language:
    """Return the spaCy pipeline for *lang*, loading it on first request.

    Args:
        lang: Language code; must be a key of ``_MODEL_NAMES``.

    Returns:
        The cached pipeline for *lang*.

    Raises:
        ValueError: If *lang* is not a key of ``_MODEL_NAMES``.
    """
    model_name = _MODEL_NAMES.get(lang)
    if model_name is None:
        raise ValueError(f"Unsupported language: {lang!r}. Valid: {list(_MODEL_NAMES)}")

    # Serve from the cache when this language has been requested before.
    if lang in _nlp_cache:
        return _nlp_cache[lang]

    pipeline = spacy.load(model_name)
    _nlp_cache[lang] = pipeline
    return pipeline
|
||
|
||
|
||
def load_all_models() -> None:
    """Load the pipeline of every language registered in ``_MODEL_NAMES``.

    Each language is requested once through get_nlp(), which stores the
    loaded pipeline in the module cache.
    """
    for language_code in _MODEL_NAMES:
        get_nlp(language_code)
|
||
|
||
|
||
# ── Step 1: Person name extraction ──────────────────────────────────────────
|
||
|
||
def extract_person_names(doc) -> list[str]:
    """Return the text of every person entity in *doc*, in entity order.

    The German and Spanish pipelines used by this module label people
    ``PER``, whereas the English pipeline labels them ``PERSON``. Both labels
    are accepted so that English queries yield names as well.

    Args:
        doc: Parsed document exposing ``ents``; each entity must provide
            ``text`` and ``label_``.

    Returns:
        The entity texts in the order ``doc.ents`` yields them (left to
        right for a spaCy Doc).
    """
    person_labels = ("PER", "PERSON")
    return [ent.text for ent in doc.ents if ent.label_ in person_labels]
|
||
|
||
|
||
# ── Step 2: Role detection ───────────────────────────────────────────────────
|
||
|
||
# Lower-cased prepositions, per language, that mark a person entity as the
# sender when they are attached to it in the dependency tree or stand
# directly in front of it (see detect_person_role).
_SENDER_PREPS: dict[str, frozenset[str]] = {
    "de": frozenset({"von", "vom"}),
    "en": frozenset({"from", "by"}),
    "es": frozenset({"de", "por"}),
}

# Lower-cased prepositions, per language, that mark a person entity as the
# receiver under the same conditions.
_RECEIVER_PREPS: dict[str, frozenset[str]] = {
    "de": frozenset({"an", "nach", "für"}),
    "en": frozenset({"to", "for"}),
    "es": frozenset({"para", "a"}),
}
|
||
|
||
|
||
def detect_person_role(doc, per_spans: list, lang: str) -> str:
    """Classify the single person of a query as 'sender', 'receiver' or 'any'.

    A role is only derived when exactly one person span is supplied. With
    zero or several spans the result is always 'any'; for two-person queries
    Java derives direction from list position.

    Args:
        doc: The parsed document the span belongs to.
        per_spans: Person entity spans found in *doc*.
        lang: Language code used to pick the preposition tables.

    Returns:
        'sender' or 'receiver' when a matching preposition is found,
        otherwise 'any'.
    """
    if len(per_spans) != 1:
        return "any"

    person = per_spans[0]
    head = person.root
    sender_words = _SENDER_PREPS[lang]
    receiver_words = _RECEIVER_PREPS[lang]

    def candidate_tokens():
        # Primary evidence: preposition-like dependents of the span's root.
        for dependent in head.children:
            if dependent.dep_ in ("case", "prep", "mo"):
                yield dependent
        # Fallback evidence: the token directly in front of the span, if any.
        if person.start > 0:
            yield doc[person.start - 1]

    # The first candidate that appears in either table decides the role.
    for token in candidate_tokens():
        word = token.lower_
        if word in sender_words:
            return "sender"
        if word in receiver_words:
            return "receiver"

    return "any"
|
||
|
||
|
||
# ── Step 3: Date parsing ─────────────────────────────────────────────────────
|
||
|
||
# Matches text that consists of exactly four digits, i.e. a bare-year candidate.
_YEAR_RE = re.compile(r"^\d{4}$")

# Lower-cased words, per language, that turn the following date into an
# upper bound ("before 1920"). Spanish lists only "antes"; the "de" of
# "antes de" is covered by the two-token look-back in extract_dates.
_DATE_BEFORE: dict[str, frozenset[str]] = {
    "de": frozenset({"vor"}),
    "en": frozenset({"before"}),
    "es": frozenset({"antes"}),
}

# Lower-cased words, per language, that turn the following date into a
# lower bound ("after 1920"). Spanish is listed with and without the accent.
_DATE_AFTER: dict[str, frozenset[str]] = {
    "de": frozenset({"nach"}),
    "en": frozenset({"after"}),
    "es": frozenset({"después", "despues"}),
}

# Lower-cased words, per language, that signal a two-date range
# ("between X and Y").
_DATE_BETWEEN: dict[str, frozenset[str]] = {
    "de": frozenset({"zwischen"}),
    "en": frozenset({"between"}),
    "es": frozenset({"entre"}),
}
|
||
|
||
|
||
def _parse_date_text(text: str, lang: str) -> date | None:
    """Parse a date expression into a ``date``; return None if it cannot be parsed.

    A bare four-digit year in the range 1000–2999 (inclusive) is mapped
    directly to 1 January of that year. Any other text is handed to
    dateparser, restricted to *lang*, with the first day of the month
    preferred when the day is missing.

    Args:
        text: Date expression; surrounding whitespace is ignored.
        lang: Language code passed to dateparser.

    Returns:
        The parsed date, or None when dateparser yields no result.
    """
    text = text.strip()
    if _YEAR_RE.match(text):
        year = int(text)
        # Inclusive lower bound: the year 1000 is handled here like every
        # other bare year instead of falling through to dateparser.
        if 1000 <= year < 3000:
            return date(year, 1, 1)

    parsed = dateparser.parse(
        text,
        languages=[lang],
        settings={"PREFER_DAY_OF_MONTH": "first", "RETURN_AS_TIMEZONE_AWARE": False},
    )
    return parsed.date() if parsed else None
|
||
|
||
|
||
def _year_end(d: date) -> date:
    """Map 1 January to 31 December of the same year; leave other dates as-is.

    1 January is treated as a year-only boundary, so it is widened to the
    last day of that year.
    """
    is_year_start = (d.month, d.day) == (1, 1)
    return date(d.year, 12, 31) if is_year_start else d
|
||
|
||
|
||
def _find_year_spans(doc) -> list:
    """Return a one-token slice of *doc* for every token that is a bare year.

    A token qualifies when its text is exactly four digits and its value lies
    in 1000–2999 inclusive. extract_dates uses this as a fallback when NER
    produced no DATE entities.

    Args:
        doc: Parsed document; iterated token by token and sliced by index.

    Returns:
        ``doc[i : i + 1]`` slices in document order. They are plain slices:
        no entity label is set on them.
    """
    spans = []
    for token in doc:
        if not _YEAR_RE.match(token.text):
            continue
        # Inclusive lower bound so that 1000 itself is accepted, matching the
        # documented 1000–2999 range.
        if 1000 <= int(token.text) < 3000:
            spans.append(doc[token.i : token.i + 1])
    return spans
|
||
|
||
|
||
def extract_dates(doc, lang: str) -> tuple[str | None, str | None]:
    """Derive a (date_from, date_to) pair of ISO strings from *doc*.

    Outcomes, in the order they are tried:

    * no date span found, or the first span cannot be parsed: (None, None)
    * a "between" word anywhere in the query plus two parseable spans:
      (earlier, later), with the later one widened by _year_end
    * a "before" word within two tokens in front of the first span:
      (None, date widened by _year_end)
    * an "after" word within two tokens in front of the first span:
      (date, None)
    * otherwise: (date, date widened by _year_end)

    Args:
        doc: Parsed document.
        lang: Language code selecting the keyword tables and date parser.
    """
    # Prefer NER DATE entities; some small models (de, es) do not tag bare
    # years as DATE, so fall back to four-digit year tokens.
    candidates = [ent for ent in doc.ents if ent.label_ == "DATE"] or _find_year_spans(doc)
    if not candidates:
        return None, None

    range_words = _DATE_BETWEEN[lang]
    upper_bound_words = _DATE_BEFORE[lang]
    lower_bound_words = _DATE_AFTER[lang]

    # "zwischen X und Y" / "between X and Y": the first two spans form a range.
    mentions_range = any(tok.lower_ in range_words for tok in doc)
    if mentions_range and len(candidates) >= 2:
        endpoints = [
            parsed
            for parsed in (_parse_date_text(item.text, lang) for item in candidates[:2])
            if parsed
        ]
        if len(endpoints) == 2:
            start, end = sorted(endpoints)
            return start.isoformat(), _year_end(end).isoformat()

    # Single-date handling: only the first span is considered.
    first = candidates[0]
    anchor = _parse_date_text(first.text, lang)
    if not anchor:
        return None, None

    # Look at up to two tokens in front of the span so that multi-word
    # prepositions such as Spanish "antes de 1920" are recognised.
    lookback = min(2, first.start)
    preceding = {doc[first.start - offset].lower_ for offset in range(1, lookback + 1)}

    if not upper_bound_words.isdisjoint(preceding):
        return None, _year_end(anchor).isoformat()
    if not lower_bound_words.isdisjoint(preceding):
        return anchor.isoformat(), None

    # Bare year/date: closed range, widened to the whole year for year-only input.
    return anchor.isoformat(), _year_end(anchor).isoformat()
|
||
|
||
|
||
# ── Step 4: Keyword extraction ───────────────────────────────────────────────
|
||
|
||
def extract_keywords(doc, excluded_spans: list) -> list[str]:
    """Collect unique lower-cased lemmas of nouns and proper nouns in *doc*.

    A token is skipped when it lies inside one of *excluded_spans*, is not
    tagged NOUN or PROPN, is a stop word, or has a lemma shorter than three
    characters.

    Args:
        doc: Parsed document, iterated token by token.
        excluded_spans: Spans (``start`` inclusive, ``end`` exclusive) whose
            tokens must not become keywords.

    Returns:
        Lemmas in order of first appearance, without duplicates.
    """
    blocked = {index for span in excluded_spans for index in range(span.start, span.end)}

    def content_lemmas():
        for tok in doc:
            if tok.i in blocked or tok.pos_ not in ("NOUN", "PROPN") or tok.is_stop:
                continue
            lemma = tok.lemma_.lower()
            if len(lemma) >= 3:
                yield lemma

    # dict.fromkeys drops repeats while keeping first-seen order.
    return list(dict.fromkeys(content_lemmas()))
|
||
|
||
|
||
# ── Step 5: Assembly ─────────────────────────────────────────────────────────
|
||
|
||
def extract(query: str, lang: str) -> ParseResponse:
    """Run the full NLP pipeline on *query* and assemble a ParseResponse.

    Steps: parse the query with the pipeline for *lang*, collect person
    entities, derive the person role, extract the date range, and collect
    keywords from tokens outside any named entity.

    Args:
        query: Raw user query.
        lang: Language code; must be supported by get_nlp().

    Returns:
        A ParseResponse carrying person names, role, date bounds, keywords
        and the unmodified query.

    Raises:
        ValueError: Propagated from get_nlp() for an unsupported *lang*.
    """
    nlp = get_nlp(lang)
    doc = nlp(query)

    # The German and Spanish pipelines label people "PER"; the English
    # pipeline labels them "PERSON". Accept both so English queries are not
    # silently left without person spans.
    per_spans = [ent for ent in doc.ents if ent.label_ in ("PER", "PERSON")]

    # Names come from the same spans that drive role detection, so the name
    # list and the role can never refer to different entities.
    person_names = [span.text for span in per_spans]
    person_role = detect_person_role(doc, per_spans, lang)
    date_from, date_to = extract_dates(doc, lang)
    # Every named entity (not only persons) is excluded from the keywords.
    keywords = extract_keywords(doc, list(doc.ents))

    return ParseResponse(
        personNames=person_names,
        personRole=person_role,
        dateFrom=date_from,
        dateTo=date_to,
        keywords=keywords,
        rawQuery=query,
    )
|