Upgrade the OCR service to the Surya 0.17 predictor API.

* Dockerfile: pin CPU-only torch 2.7.1 (was 2.5.1).
* engines/surya.py: replace the separate model/processor globals with
  RecognitionPredictor/DetectionPredictor instances, run recognition for all
  pages in a single call, and emit each text line's 4-point polygon
  normalized by page width/height.
* requirements.txt: surya-ocr 0.6.3 -> 0.17.1, kraken 6.0.3 -> 7.0.

NOTE(review): extract_blocks() still accepts `language`, but the new
recognition call no longer passes it on; confirm this is intended.

diff --git a/ocr-service/Dockerfile b/ocr-service/Dockerfile
index 24f74be0..57368726 100644
--- a/ocr-service/Dockerfile
+++ b/ocr-service/Dockerfile
@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 
 # PyTorch CPU-only — separate layer; the whl/cpu index strips all CUDA variants (~2 GB saved)
 RUN pip install --no-cache-dir \
-    torch==2.5.1 \
+    torch==2.7.1 \
     --index-url https://download.pytorch.org/whl/cpu
 
 COPY requirements.txt .
diff --git a/ocr-service/engines/surya.py b/ocr-service/engines/surya.py
index c6cc7768..77a895d9 100644
--- a/ocr-service/engines/surya.py
+++ b/ocr-service/engines/surya.py
@@ -4,28 +4,23 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-# Lazy-loaded at startup via load_models()
-_recognition_model = None
-_recognition_processor = None
-_detection_model = None
-_detection_processor = None
+_recognition_predictor = None
+_detection_predictor = None
 
 
 def load_models():
     """Eagerly load Surya models into memory. Called once at container startup."""
-    global _recognition_model, _recognition_processor, _detection_model, _detection_processor
+    global _recognition_predictor, _detection_predictor
 
     logger.info("Loading Surya models...")
 
-    from surya.model.detection.model import load_model as load_det_model
-    from surya.model.detection.model import load_processor as load_det_processor
-    from surya.model.recognition.model import load_model as load_rec_model
-    from surya.model.recognition.processor import load_processor as load_rec_processor
+    from surya.foundation import FoundationPredictor
+    from surya.recognition import RecognitionPredictor
+    from surya.detection import DetectionPredictor
 
-    _detection_model = load_det_model()
-    _detection_processor = load_det_processor()
-    _recognition_model = load_rec_model()
-    _recognition_processor = load_rec_processor()
+    foundation_predictor = FoundationPredictor()
+    _recognition_predictor = RecognitionPredictor(foundation_predictor)
+    _detection_predictor = DetectionPredictor()
 
     logger.info("Surya models loaded successfully")
 
@@ -33,33 +28,36 @@ def load_models():
 def extract_blocks(images: list, language: str = "de") -> list[dict]:
     """Run Surya OCR on a list of PIL images (one per page).
 
-    Returns a flat list of block dicts with pageNumber, x, y, width, height, text.
-    Coordinates are normalized to [0, 1] relative to page dimensions.
+    Returns a flat list of block dicts with pageNumber, x, y, width, height,
+    polygon, text. Coordinates are normalized to [0, 1] relative to page dimensions.
+    Surya 0.17+ returns polygon (4-point) natively on each text line.
     """
-    from surya.detection import batch_text_detection
-    from surya.recognition import batch_recognition
-
     all_blocks = []
 
-    for page_idx, image in enumerate(images):
-        page_w, page_h = image.size
+    predictions = _recognition_predictor(images, det_predictor=_detection_predictor)
 
-        det_predictions = batch_text_detection([image], _detection_model, _detection_processor)
-        rec_predictions = batch_recognition(
-            [image], det_predictions, _recognition_model, _recognition_processor, [language]
-        )
+    for page_idx, page_pred in enumerate(predictions):
+        page_w, page_h = images[page_idx].size
 
-        for line in rec_predictions[0].text_lines:
+        for line in page_pred.text_lines:
             bbox = line.bbox  # [x1, y1, x2, y2] in pixel coordinates
             x1, y1, x2, y2 = bbox
 
+            # Surya 0.17 provides polygon as list of (x, y) tuples (4 points, clockwise)
+            polygon = None
+            if hasattr(line, "polygon") and line.polygon and len(line.polygon) == 4:
+                polygon = [
+                    [p[0] / page_w, p[1] / page_h]
+                    for p in line.polygon
+                ]
+
             all_blocks.append({
                 "pageNumber": page_idx,
                 "x": x1 / page_w,
                 "y": y1 / page_h,
                 "width": (x2 - x1) / page_w,
                 "height": (y2 - y1) / page_h,
-                "polygon": None,
+                "polygon": polygon,
                 "text": line.text,
             })
 
diff --git a/ocr-service/requirements.txt b/ocr-service/requirements.txt
index 130c76ef..5a090f8b 100644
--- a/ocr-service/requirements.txt
+++ b/ocr-service/requirements.txt
@@ -1,6 +1,6 @@
 fastapi[standard]==0.115.6
-surya-ocr==0.6.3
-kraken==6.0.3
+surya-ocr==0.17.1
+kraken==7.0
 pillow>=10.2.0,<11.0.0
 pypdfium2==4.30.0
 httpx==0.28.1