import logging
import re

from app.utils.constants import KNOWN_UNITS, TEST_ALIASES
from app.utils.helpers import (
    calculate_status,
    finalize_test_output,
    is_differential_raw,
    normalize_ocr_units,
    normalize_range_string,
    parse_numeric_value,
    parse_reference_range,
    pick_value,
)

logger = logging.getLogger("lab_analyzer")

# Fallback percentage reference ranges for WBC differential cell types, keyed by
# a lowercase substring of the test name. Substituted when the extracted
# percentage range looks implausible (e.g. the absolute-count range was copied
# onto the % value).
TYPICAL_PCT_RANGES: dict[str, tuple[float, float]] = dict(
    neutrophil=(40, 70),
    lymphocyte=(20, 40),
    monocyte=(2, 10),
    eosinophil=(1, 6),
    basophil=(0, 2),
)


def _canonical_name(test_name: str) -> str:
    key = test_name.strip().lower()
    if key in TEST_ALIASES:
        return TEST_ALIASES[key]
    for alias, canonical in TEST_ALIASES.items():
        if alias in key or key in alias:
            return canonical
    return test_name.strip()


def _normalize_unit(unit: str) -> str:
    """Clean an extracted unit string and map it onto its canonical spelling.

    The unit is stripped and passed through ``normalize_ocr_units``; a lone
    "n" (any case) is replaced by "fL". If the result matches an entry of
    ``KNOWN_UNITS`` case-insensitively, that entry's spelling is returned,
    otherwise the cleaned unit is returned unchanged. Empty input yields "".
    """
    if not unit:
        return ""
    cleaned = normalize_ocr_units(unit.strip())
    if cleaned.lower() == "n":
        # NOTE(review): presumably an OCR misread of "fL" — confirm.
        cleaned = "fL"
    wanted = cleaned.lower()
    return next((known for known in KNOWN_UNITS if known.lower() == wanted), cleaned)


def _dedupe_key(name: str, result: str, pct: str, abs_val: str) -> str:
    """Return a case-insensitive identity key for duplicate detection.

    The four fields are joined with "|" and the whole key is lowercased.
    """
    return "|".join(map(str, (name, result, pct, abs_val))).lower()


def _typical_pct_range(test_name: str) -> str | None:
    """Return a fallback "lo-hi" percentage range for a WBC differential cell.

    The first ``TYPICAL_PCT_RANGES`` key found (case-insensitively) inside
    ``test_name`` decides the range; ``None`` is returned when no key matches.
    """
    name = test_name.lower()
    return next(
        (f"{low}-{high}" for cell, (low, high) in TYPICAL_PCT_RANGES.items() if cell in name),
        None,
    )


def _percentage_range_plausible(pct_val: str, pct_range: str) -> bool:
    """Return True when ``pct_range`` looks like a genuine percentage range for ``pct_val``.

    Guards against the extractor putting an absolute-count range (e.g. 1-4.8)
    on a percentage value (e.g. 23.4). The range is rejected when:

    * the value or either range bound cannot be parsed;
    * the value or the upper bound exceeds 100;
    * the value is above 15 while the upper bound is below 15.
    """
    value = parse_numeric_value(pct_val)
    if value is None:
        return False

    low, high = parse_reference_range(pct_range)
    if high is None or low is None:
        return False

    out_of_percent_scale = value > 100 or high > 100
    # A % value above 15 paired with a max below 15 suggests the wrong range type.
    looks_like_absolute_range = value > 15 and high < 15
    return not (out_of_percent_scale or looks_like_absolute_range)


def _build_standard_test(
    test_name: str,
    canonical: str,
    raw: dict,
    notes: list[str],
) -> dict | None:
    """Build the output dict for a single-value (non-differential) test row.

    Args:
        test_name: Display name of the test.
        canonical: Canonical name; stored only when it differs from ``test_name``.
        raw: Raw extracted row ("result"/"value", "unit",
            "normal_range"/"reference_range").
        notes: Collector for diagnostic messages; a note is appended when the
            row is skipped.

    Returns:
        The test dict, or None when the row has no result. Unit, normal range
        and status are included only when present / known.
    """
    value = pick_value(raw, "result", "value")
    unit = _normalize_unit(pick_value(raw, "unit"))
    reference = normalize_range_string(pick_value(raw, "normal_range", "reference_range"))

    if not value:
        notes.append("skipped: standard test without result")
        return None

    entry: dict = {"test_name": test_name}
    if canonical and canonical != test_name:
        entry["canonical_name"] = canonical
    entry["result"] = value
    # Optional fields are kept only when non-empty.
    optional_fields = {"unit": unit, "normal_range": reference}
    entry.update({field: text for field, text in optional_fields.items() if text})

    status = calculate_status(value, reference)
    if status != "unknown":
        entry["status"] = status
    return entry


def _build_differential_test(
    test_name: str,
    canonical: str,
    raw: dict,
    notes: list[str],
) -> dict | None:
    """Build the output dict for a WBC differential row (percentage and/or absolute).

    Args:
        test_name: Display name of the test.
        canonical: Canonical name; stored only when it differs from ``test_name``.
        raw: Raw extracted row with "percentage_*" and "absolute_*" fields.
        notes: Collector for diagnostic messages; a note is appended when the
            percentage range is replaced by a typical one.

    Returns:
        The test dict, or None when neither a percentage nor an absolute value
        is present. The percentage unit defaults to "%" when a percentage value
        exists. A percentage range that fails ``_percentage_range_plausible`` is
        replaced by a typical range for the cell type, or dropped if none is known.
    """
    percent = pick_value(raw, "percentage_value")
    percent_unit = _normalize_unit(pick_value(raw, "percentage_unit"))
    if not percent_unit:
        percent_unit = "%" if percent else ""
    percent_range = normalize_range_string(pick_value(raw, "percentage_range"))

    absolute = pick_value(raw, "absolute_value")
    absolute_unit = _normalize_unit(pick_value(raw, "absolute_unit"))
    absolute_range = normalize_range_string(pick_value(raw, "absolute_range"))

    if not (percent or absolute):
        return None

    entry: dict = {"test_name": test_name}
    if canonical and canonical != test_name:
        entry["canonical_name"] = canonical

    if percent:
        entry["percentage_value"] = percent
        if percent_unit:
            entry["percentage_unit"] = percent_unit

        if not _percentage_range_plausible(percent, percent_range):
            # The extracted range looks implausible for a % value: fall back to
            # a typical range for this cell type, or drop the range entirely.
            percent_range = _typical_pct_range(test_name) or ""
            if percent_range:
                notes.append("fixed: percentage_range from typical values")

        if percent_range:
            entry["percentage_range"] = percent_range
            pct_status = calculate_status(percent, percent_range)
            if pct_status != "unknown":
                entry["status_percentage"] = pct_status

    if absolute:
        entry["absolute_value"] = absolute
        if absolute_unit:
            entry["absolute_unit"] = absolute_unit
        if absolute_range:
            entry["absolute_range"] = absolute_range
            abs_status = calculate_status(absolute, absolute_range)
            if abs_status != "unknown":
                entry["status_absolute"] = abs_status

    return entry


# "esr" as a standalone token: no letter directly before or after it, so names
# like "ESR 1st hour" or "ESR1" match while unrelated words that merely contain
# the letters "esr" do not.
_ESR_TOKEN_RE = re.compile(r"(?<![a-z])esr(?![a-z])")
# Used only when the first ESR row carries no normal range / unit of its own.
_ESR_DEFAULT_RANGE = "Up to 12"
_ESR_DEFAULT_UNIT = "mm"


def _is_esr_name(name_lower: str) -> bool:
    """Return True if a lowercased test name refers to the erythrocyte sedimentation rate."""
    return "erythrocyte sedimentation" in name_lower or _ESR_TOKEN_RE.search(name_lower) is not None


def _build_esr_entry(esr_rows: list[dict]) -> dict:
    """Combine one or more per-hour ESR rows into a single ESR test dict.

    Only the first two rows are represented in the result text; status is
    computed from the first-hour result against the first row's normal range
    (or the default range when that row has none).
    """
    parts = [str(row.get("result", "")) for row in esr_rows]
    if len(parts) == 1:
        result = parts[0]
    else:
        # Labels are Arabic for "first hour" / "second hour".
        result = f"{parts[0]} (ساعة أولى) / {parts[1]} (ساعة ثانية)"

    first = esr_rows[0]
    normal_range = first.get("normal_range") or _ESR_DEFAULT_RANGE
    esr_test = {
        "test_name": "Erythrocyte Sedimentation Rate (ESR)",
        "result": result,
        "unit": first.get("unit") or _ESR_DEFAULT_UNIT,
        "normal_range": normal_range,
    }
    status = calculate_status(parts[0], normal_range)
    if status != "unknown":
        esr_test["status"] = status
    return esr_test


def _merge_duplicate_entries(validated: list[dict]) -> list[dict]:
    """Collapse per-hour ESR rows into one entry and drop prothrombin-time control rows.

    * Every ESR row is removed from its position; a single combined ESR entry
      is appended at the end of the list instead.
    * A "prothrombin time" row without a normal range is dropped.
    * All other tests pass through unchanged and in order.

    Args:
        validated: Already-built test dicts.

    Returns:
        A new list; the input list is not modified.
    """
    merged: list[dict] = []
    esr_rows: list[dict] = []

    for test in validated:
        name_lower = (test.get("test_name") or "").lower()

        if _is_esr_name(name_lower):
            esr_rows.append(test)
            continue

        if "prothrombin time" in name_lower and not test.get("normal_range"):
            continue  # skip control row without range

        merged.append(test)

    if esr_rows:
        merged.append(_build_esr_entry(esr_rows))

    return merged


def validate_and_enrich_tests(raw_tests: list[dict]) -> list[dict]:
    """Validate raw extracted test rows and turn them into clean output dicts.

    Each dict row is routed to the differential or standard builder based on
    ``is_differential_raw``. Rows without a name, or that a builder rejects,
    are dropped. Duplicates are skipped when their canonical name, result,
    percentage and absolute value match case-insensitively. Survivors are
    passed through ``finalize_test_output`` and then ``_merge_duplicate_entries``.

    Builder notes (skips and automatic fixes) are logged at DEBUG level.

    Args:
        raw_tests: Rows from the extractor; non-dict entries are ignored and
            ``None`` is treated as an empty list.

    Returns:
        The cleaned list of test dicts.
    """
    rows = raw_tests or []
    validated: list[dict] = []
    seen: set[str] = set()

    for raw in rows:
        if not isinstance(raw, dict):
            continue

        test_name = normalize_ocr_units(pick_value(raw, "test_name"))
        if not test_name:
            continue

        canonical = _canonical_name(test_name)
        notes: list[str] = []

        builder = _build_differential_test if is_differential_raw(raw) else _build_standard_test
        test = builder(test_name, canonical, raw, notes)

        if notes:
            # Builders report skips and auto-fixes here; surface them instead of discarding.
            logger.debug("[VALIDATION] %s: %s", test_name, "; ".join(notes))

        if not test:
            continue

        dedupe = _dedupe_key(
            canonical,
            test.get("result", ""),
            test.get("percentage_value", ""),
            test.get("absolute_value", ""),
        )
        if dedupe in seen:
            continue
        seen.add(dedupe)

        validated.append(finalize_test_output(test))

    clean_results = _merge_duplicate_entries(validated)

    logger.info("[VALIDATION] %d tests after validation (from %d raw)", len(clean_results), len(rows))
    return clean_results
