feat(ocr): add SSRF protection for PDF URL downloads
Validates PDF download URLs against an ALLOWED_PDF_HOSTS allowlist (default: minio,localhost,127.0.0.1) and disables redirect following to prevent redirect-based SSRF. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit was merged in pull request #229.
This commit is contained in:
@@ -4,7 +4,9 @@ import asyncio
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pypdfium2 as pdfium
|
import pypdfium2 as pdfium
|
||||||
@@ -22,6 +24,18 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_models_ready = False
|
_models_ready = False
|
||||||
|
|
||||||
|
ALLOWED_PDF_HOSTS = set(
|
||||||
|
h.strip() for h in os.getenv("ALLOWED_PDF_HOSTS", "minio,localhost,127.0.0.1").split(",")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_url(url: str) -> None:
|
||||||
|
"""Validate that the PDF URL points to an allowed host (SSRF protection)."""
|
||||||
|
parsed = urlparse(url)
|
||||||
|
hostname = parsed.hostname or ""
|
||||||
|
if hostname not in ALLOWED_PDF_HOSTS:
|
||||||
|
raise HTTPException(status_code=400, detail=f"PDF host not allowed: {hostname}")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -157,7 +171,10 @@ async def run_ocr_stream(request: OcrRequest):
|
|||||||
|
|
||||||
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
async def _download_and_convert_pdf(url: str) -> list[Image.Image]:
|
||||||
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
"""Download a PDF from a presigned URL and convert each page to a PIL Image."""
|
||||||
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
_validate_url(url)
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(300.0), follow_redirects=False
|
||||||
|
) as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|||||||
@@ -239,6 +239,38 @@ async def test_ocr_stream_returns_400_when_kraken_unavailable_for_kurrent(mock_i
|
|||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
# ─── SSRF protection ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ocr_stream_rejects_disallowed_host():
|
||||||
|
with patch("main._models_ready", True):
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post("/ocr/stream", json={
|
||||||
|
"pdfUrl": "http://evil.example.com/malicious.pdf",
|
||||||
|
"scriptType": "TYPEWRITER",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "not allowed" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ocr_stream_allows_minio_host(mock_images):
|
||||||
|
with patch("main._download_and_convert_pdf", new_callable=AsyncMock, return_value=mock_images), \
|
||||||
|
patch("main._models_ready", True), \
|
||||||
|
patch("main.surya_engine") as mock_surya:
|
||||||
|
mock_surya.extract_page_blocks.return_value = [_make_block(0)]
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||||
|
response = await client.post("/ocr/stream", json={
|
||||||
|
"pdfUrl": "http://minio/test.pdf",
|
||||||
|
"scriptType": "TYPEWRITER",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ocr_stream_applies_confidence_markers(mock_images):
|
async def test_ocr_stream_applies_confidence_markers(mock_images):
|
||||||
"""Low-confidence words should be replaced with [unleserlich] in the stream output."""
|
"""Low-confidence words should be replaced with [unleserlich] in the stream output."""
|
||||||
|
|||||||
Reference in New Issue
Block a user