import json
import logging
import re
from typing import Any

import httpx

# Module-wide logger used by the helpers in this file.
logger = logging.getLogger("lab_analyzer")


def setup_logging() -> None:
    """Configure root logging at INFO with a pipe-separated, timestamped format.

    Delegates to ``logging.basicConfig``, which does nothing if the root
    logger already has handlers.
    """
    line_format = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
    logging.basicConfig(format=line_format, level=logging.INFO)


def remove_symbol_garbage_only(text: str) -> str:
    """Delete the ©, ™ and ® symbols and invisible unicode noise, nothing else.

    The invisible characters removed are U+200B..U+200F (zero-width space,
    joiners, direction marks), U+2028 / U+2029 (line / paragraph separators)
    and U+FEFF (byte-order mark).
    """
    noise_chars = (
        "©™®"
        + "".join(chr(code) for code in range(0x200B, 0x2010))
        + "\u2028\u2029\ufeff"
    )
    deletion_table = {ord(ch): None for ch in noise_chars}
    return text.translate(deletion_table)


# Unit OCR fixes applied during light cleaning and validation.
# Each entry is (regex pattern, replacement); normalize_ocr_units applies them
# in list order with re.IGNORECASE, so a pattern that must stay case-sensitive
# has to opt out locally with a scoped "(?-i:...)" group.
UNIT_OCR_FIXES: list[tuple[str, str]] = [
    (r"gldL", "g/dL"),
    (r"g\/dl\b", "g/dL"),
    (r"g\/dL", "g/dL"),
    (r"mgldL", "mg/dL"),
    (r"mg\/dl\b", "mg/dL"),
    (r"x103\s*jul", "x10^3/uL"),
    (r"x103\s*/\s*ul", "x10^3/uL"),
    (r"x10\s*\*\s*9/L", "x10^9/L"),
    (r"x109\/L", "x10^9/L"),
    (r"x1019\/L", "x10^9/L"),
    (r"x10\^9\/L", "x10^9/L"),
    (r"x10'\s*/\s*ut", "x10^3/uL"),
    (r"10\^3\/uL", "10^3/uL"),
    (r"Millions\s*I\s*cmm", "Millions/cmm"),
    (r"Thousands\s*I\s*cmm", "Thousands/cmm"),
    (r"\bUIL\b", "U/L"),
    (r"u\/L", "U/L"),
    (r"iu\/l", "IU/L"),
    (r"mmol\/l", "mmol/L"),
    # OCR misread: "fl" read as "n" (MCV unit), e.g. "85.2 n" -> "85.2 fL".
    # Only a lowercase "n" right after a number is rewritten: a bare \bn\b
    # under IGNORECASE also hit "N/A" and the "N" (normal) flag column.
    (r"(\d[ \t]*)(?-i:n)\b(?!/)", r"\g<1>fL"),
]


def normalize_ocr_units(text: str) -> str:
    """Run every ``UNIT_OCR_FIXES`` rewrite over *text*, case-insensitively.

    Rules are applied in list order, each on the output of the previous one.
    Falsy input (empty string or ``None``) returns ``""``.
    """
    if not text:
        return ""
    result = text
    for rule in UNIT_OCR_FIXES:
        pattern, replacement = rule
        result = re.compile(pattern, re.IGNORECASE).sub(replacement, result)
    return result


def _fix_spaced_range(match: re.Match) -> str:
    """Repair '150 450' -> '150-450' when it looks like a numeric range.

    ``re.sub`` callback for :func:`normalize_range_string`; groups 1 and 2 of
    *match* hold the two numbers. Returns the joined range, or the original
    matched text when the pair does not look like a range.
    """
    a_str, b_str = match.group(1), match.group(2)
    try:
        a, b = float(a_str), float(b_str)
    except ValueError:
        return match.group(0)
    # Plausible lab range: second number greater, but at most 5x the first.
    if a < b <= a * 5:
        return f"{a_str}-{b_str}"
    return match.group(0)


def normalize_range_string(text: str) -> str:
    """
    Repair common OCR range mistakes.
    Examples: 28:40 -> 28-40, 150 450 -> 150-450, 31-3 37 -> 31.3-37

    Falsy input returns "".
    """
    if not text:
        return ""

    # Colon used instead of dash (e.g. 28:40).
    # NOTE(review): this also rewrites clock times such as "10:30".
    text = re.sub(r"\b(\d{1,4}(?:\.\d+)?)\s*:\s*(\d{1,4}(?:\.\d+)?)\b", r"\1-\2", text)

    # Space-separated range (e.g. Platelets 150 450) — before broken-decimal fix.
    # (?<!\.) stops the first number from starting inside a decimal's
    # fractional part: in "4.10 35" the old pattern compared 10 vs 35 and
    # produced "4.10-35" although 35 is far more than 5x 4.10.
    text = re.sub(
        r"(?<!\.)\b(\d{2,4}(?:\.\d+)?)\s+(\d{2,4}(?:\.\d+)?)\b",
        _fix_spaced_range,
        text,
    )

    # Broken decimal: 31-3 37 -> 31.3-37
    text = re.sub(r"\b(\d{1,3})-(\d)\s+(\d{2,3})\b", r"\1.\2-\3", text)
    return text


def is_gemma_model(model: str) -> bool:
    """Return True if the model identifier mentions "gemma" in any letter case."""
    return model.lower().find("gemma") != -1


def _test_has_data(item: dict) -> bool:
    """True if row has a name and at least one measurable value."""
    name = str(item.get("test_name", "")).strip()
    if not name:
        return False
    value_fields = (
        "result", "value", "percentage_value", "absolute_value",
    )
    return any(str(item.get(f, "")).strip() for f in value_fields)


def validate_extraction_json(data: dict[str, Any] | None) -> bool:
    if not data or not isinstance(data, dict):
        return False
    tests = data.get("tests")
    if not isinstance(tests, list) or len(tests) == 0:
        return False
    valid_count = sum(1 for t in tests if isinstance(t, dict) and _test_has_data(t))
    return valid_count > 0


def validate_interpretation_json(data: dict[str, Any] | None) -> bool:
    if not data or not isinstance(data, dict):
        return False
    if not str(data.get("overall_summary", "")).strip():
        return False
    return isinstance(data.get("tests"), list)


def try_parse_json(raw: str) -> dict[str, Any] | None:
    if not raw or not raw.strip():
        return None

    text = raw.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.IGNORECASE)
        text = re.sub(r"\s*```$", "", text)

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass

    fixed = re.sub(r",\s*}", "}", text)
    fixed = re.sub(r",\s*]", "]", fixed)
    try:
        return json.loads(fixed)
    except json.JSONDecodeError:
        logger.warning("Failed to parse JSON from model output")
        return None


def parse_numeric_value(value: str) -> float | None:
    if not value:
        return None
    match = re.search(r"[-+]?\d*\.?\d+", str(value).replace(",", ""))
    if match:
        try:
            return float(match.group())
        except ValueError:
            return None
    return None


def _extract_numeric_range_from_text(ref: str) -> tuple[float | None, float | None]:
    """Pull the first numeric range out of a label such as 'Normal: 70 - 139'.

    The text is first cleaned (commas removed, ``normalize_range_string``
    applied). Then, in order:

    1. an explicit range joined by a hyphen, en dash or em dash;
    2. two whitespace-separated numbers where the second is larger.
       This covers e.g. TSH '0.55 4.78', which ``normalize_range_string``
       leaves alone because its spaced-range repair needs numbers with at
       least two integer digits.

    Returns ``(low, high)``, or ``(None, None)`` when no range is found.
    """
    ref = normalize_range_string(ref.strip().replace(",", ""))

    range_match = re.search(
        r"([-+]?\d*\.?\d+)\s*[-–—]\s*([-+]?\d*\.?\d+)", ref
    )
    if range_match:
        return float(range_match.group(1)), float(range_match.group(2))

    # Both numbers use the same shape; the old second group (\d+\.?\d+)
    # demanded two digits, so single-digit upper bounds like "0.5 4" failed.
    spaced = re.search(r"(\d+\.?\d*)\s+(\d+\.?\d*)", ref)
    if spaced:
        a, b = float(spaced.group(1)), float(spaced.group(2))
        if b > a:
            return a, b

    return None, None


def parse_reference_range(ref: str) -> tuple[float | None, float | None]:
    """Parse a reference-range label into ``(low, high)`` bounds.

    Either bound may be ``None`` when the label gives only one side; both are
    ``None`` when nothing numeric can be recognised. Checked in order:

    1. a two-sided numeric range (via ``_extract_numeric_range_from_text``);
    2. upper-bound phrases: "less than X", "up to X" / "upto X",
       Arabic "أقل من" (less than) and "حتى" (up to);
    3. "<", "<=" or "≤" X  -> upper bound;
    4. ">", ">=" or "≥" X  -> lower bound;
    5. lower-bound phrases: "more than X", "greater than X", Arabic "أكثر من".

    Strict vs. inclusive comparison is not distinguished here.
    """
    if not ref:
        return None, None

    ref_clean = ref.strip().replace(",", "")
    low, high = _extract_numeric_range_from_text(ref_clean)
    if low is not None or high is not None:
        return low, high

    # Less than X / أقل من, Up to X / حتى
    less_match = re.search(
        r"(?:less\s+than|أقل\s+من|up\s*to|حتى)\s*([-+]?\d*\.?\d+)",
        ref_clean,
        re.IGNORECASE,
    )
    if less_match:
        return None, float(less_match.group(1))

    lt_match = re.search(r"(?:<=?|≤)\s*([-+]?\d*\.?\d+)", ref_clean)
    if lt_match:
        return None, float(lt_match.group(1))

    gt_match = re.search(r"(?:>=?|≥)\s*([-+]?\d*\.?\d+)", ref_clean)
    if gt_match:
        return float(gt_match.group(1)), None

    # Checked last so every label the older rules recognised keeps its result.
    more_match = re.search(
        r"(?:more\s+than|greater\s+than|أكثر\s+من)\s*([-+]?\d*\.?\d+)",
        ref_clean,
        re.IGNORECASE,
    )
    if more_match:
        return float(more_match.group(1)), None

    return None, None


def calculate_status(value: str, reference_range: str) -> str:
    """Calculate low / normal / high in Python only.

    Returns one of "low", "normal", "high" or "unknown". "unknown" is
    returned when *value* contains no number or *reference_range* yields no
    bound.
    """
    num = parse_numeric_value(value)
    if num is None:
        return "unknown"

    ref = reference_range or ""
    ref_lower = ref.lower()

    # HbA1c style: "Normal: Less than 5.7". Strictly below the threshold is
    # normal; reaching or exceeding it is high.
    less_than = re.search(
        r"(?:less\s+than|أقل\s+من)\s*([-+]?\d*\.?\d+)",
        ref_lower,
    )
    if less_than:
        return "normal" if num < float(less_than.group(1)) else "high"

    # "Up to 12" / "Upto 12" (\s* covers both); the limit itself is normal.
    up_to = re.search(r"(?:up\s*to|حتى)\s*([-+]?\d*\.?\d+)", ref_lower)
    if up_to:
        return "normal" if num <= float(up_to.group(1)) else "high"

    # parse_reference_range extracts the numeric range from labelled text such
    # as "Normal: 70 - 139" itself. The old pre-pass that rewrote the label as
    # f"{low}-{high}" was redundant, and it broke on floats whose repr uses an
    # exponent (e.g. 1e-05).
    low, high = parse_reference_range(ref)
    if low is None and high is None:
        return "unknown"
    if low is not None and num < low:
        return "low"
    if high is not None and num > high:
        return "high"
    return "normal"


# ---------------------------------------------------------------------------
# Shared HTTP client (reused across AI provider calls)
# ---------------------------------------------------------------------------
# Lazily created shared client; see get_http_client().
_http_client: httpx.AsyncClient | None = None


def get_http_client() -> httpx.AsyncClient:
    """Return the shared ``httpx.AsyncClient``, creating it on demand.

    A new client is built on first use, and again whenever the cached one has
    been closed (e.g. by ``aclose()`` at shutdown or between tests). A closed
    client rejects every further request, so handing it out again would make
    all later provider calls fail.

    The timeout is ``HTTP_TIMEOUT`` from ``app.utils.constants``, imported
    inside the function (presumably to avoid an import cycle — confirm).
    """
    global _http_client
    if _http_client is None or _http_client.is_closed:
        from app.utils.constants import HTTP_TIMEOUT

        _http_client = httpx.AsyncClient(timeout=HTTP_TIMEOUT)
    return _http_client


# ---------------------------------------------------------------------------
# Lab test schema helpers (dynamic fields per test type)
# ---------------------------------------------------------------------------
# Substrings that mark a test name as part of the white-cell differential.
# is_differential_raw lowercases the test name before checking membership,
# so every entry must be lowercase.
DIFFERENTIAL_KEYWORDS = (
    "neutrophil",
    "lymphocyte",
    "monocyte",
    "eosinophil",
    "basophil",
    "granulocyte",
    "band",
    "immature",
)


def strip_empty(d: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict without the entries of *d* whose value is "empty".

    Empty means ``None``, a string that is blank after ``strip()``, or a
    zero-length list. Other falsy values (``0``, ``False``, ``{}``, ``()``)
    are kept. Key order is preserved.
    """

    def _is_empty(val: Any) -> bool:
        if val is None:
            return True
        if isinstance(val, str):
            return not val.strip()
        if isinstance(val, list):
            return len(val) == 0
        return False

    return {key: val for key, val in d.items() if not _is_empty(val)}


def is_differential_raw(raw: dict[str, Any]) -> bool:
    """Decide whether a raw extracted row belongs to the WBC differential.

    True when the row has a truthy ``percentage_value`` or ``absolute_value``.
    Otherwise, true when its lowercased ``test_name`` contains any entry of
    ``DIFFERENTIAL_KEYWORDS``. Note the value check is plain truthiness:
    a numeric ``0`` does not count, while a whitespace-only string does.
    """
    has_split_values = bool(raw.get("percentage_value")) or bool(raw.get("absolute_value"))
    if has_split_values:
        return True
    lowered_name = str(raw.get("test_name", "")).lower()
    for keyword in DIFFERENTIAL_KEYWORDS:
        if keyword in lowered_name:
            return True
    return False


def pick_value(raw: dict[str, Any], *keys: str) -> str:
    """Return the first non-blank value among *keys*, stringified and stripped.

    Keys are tried in the order given; ``None`` and whitespace-only values
    are skipped. Returns ``""`` when no key yields a usable value.
    """
    for candidate_key in keys:
        candidate = raw.get(candidate_key)
        if candidate is None:
            continue
        text = str(candidate).strip()
        if text:
            return text
    return ""


def finalize_test_output(test: dict[str, Any]) -> dict[str, Any]:
    """Produce the final per-test dict for output.

    Empty fields are dropped via ``strip_empty``. ``canonical_name`` is then
    removed when it equals ``test_name`` (which includes the case where both
    are absent, making the removal a no-op).
    """
    result = strip_empty(test)
    duplicates_name = result.get("canonical_name") == result.get("test_name")
    if duplicates_name:
        result.pop("canonical_name", None)
    return result
