from __future__ import annotations
import os, json
from typing import Any
import yaml
import pandas as pd
from tqdm import tqdm

from schemas import ConversationReport, BatchSummary, EntitiesByCategory, EntityItem, Evidence, PatientPerspective
from pdf_text import load_folder, group_conversations, merge_conversation_text
from llm_ollama import ollama_generate
from dedupe import dedupe_strings

SYSTEM_RULES = """
You are extracting information from a diabetes-related eye care interview transcript.
Follow these rules strictly:
1) Do NOT invent information. If not stated, return empty list for that field.
2) Always include short exact quotes as evidence for entities and for patient perspective.
3) Output MUST be valid JSON only (no markdown, no extra text).
"""

PROMPT_TEMPLATE = """
{rules}

Return ONLY valid JSON.

REQUIRED JSON SCHEMA (must follow exactly):
{{
  "questions_asked": ["..."],
  "entities": [
    {{
      "category": "Symptoms",
      "items": [
        {{
          "name": "blurred vision",
          "evidence": [{{"quote": "exact short quote", "speaker": "S2"}}]
        }}
      ]
    }}
  ],
  "patient_perspective": {{
    "occurrence": ["..."],
    "severity": ["..."],
    "concerns": ["..."],
    "goals": ["..."],
    "occurrence_evidence": [{{"quote":"...", "speaker":"S2"}}],
    "severity_evidence": [{{"quote":"...", "speaker":"S2"}}],
    "concerns_evidence": [{{"quote":"...", "speaker":"S2"}}],
    "goals_evidence": [{{"quote":"...", "speaker":"S2"}}]
  }}
}}

RULES:
- "entities" MUST be a list of objects (NOT a string).
- Each entity category MUST be one of:
  Symptoms, Ophthalmic Findings, Diagnostic Tools, Systemic Risk Factors, Treatment Options, Demographics/History
- If nothing found for a category, include the category with items: []
- Do not invent anything. If not stated, keep lists empty.

TRANSCRIPT:
<<<
{transcript}
>>>
"""


def safe_json_load(s: str) -> Any:
    """Parse a JSON value out of a raw LLM response.

    LLMs often wrap their JSON in prose or markdown code fences, so parsing
    is attempted in increasingly lenient steps:

    1. the whole (stripped) response;
    2. the span from the first ``{`` to the last ``}``;
    3. the first complete JSON object starting at the first ``{``, ignoring
       whatever follows it (handles trailing text that itself contains
       braces, which defeats step 2).

    Args:
        s: Raw model output. ``None`` is treated as an empty response.

    Returns:
        The decoded JSON value (normally a ``dict``).

    Raises:
        ValueError: If the response is empty, or no JSON could be decoded
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    s = (s or "").strip()

    # Fail clearly on an empty reply instead of with a cryptic decode error.
    if not s:
        raise ValueError("LLM returned EMPTY response")

    try:
        return json.loads(s)
    except json.JSONDecodeError as first_err:
        start = s.find("{")
        if start == -1:
            raise

        end = s.rfind("}")
        if end > start:
            try:
                return json.loads(s[start:end + 1])
            except json.JSONDecodeError:
                pass  # fall through to the first-object attempt below

        # raw_decode stops at the end of the first complete value, so text
        # after the object (a remark, a second object, ...) is ignored.
        try:
            obj, _ = json.JSONDecoder().raw_decode(s, start)
        except json.JSONDecodeError:
            raise first_err from None
        return obj

def normalize_entities(raw_entities) -> list[EntitiesByCategory]:
    """Coerce the model's ``entities`` field into one entry per allowed category.

    The expected shape is::

        [{"category": str,
          "items": [{"name": str,
                     "evidence": [{"quote": str, "speaker": ...}]}]}]

    Models do not always comply, so malformed parts are dropped rather than
    raised:

    * a non-list ``raw_entities`` (e.g. a string) is treated as empty;
    * non-dict categories, items and evidence entries are skipped;
    * non-string or unknown categories are ignored; matching against the
      allowed names is case-insensitive and ignores surrounding whitespace;
    * a missing, ``null`` or non-string ``name``/``quote`` becomes ``""``;
    * if a category appears more than once, its items are concatenated.

    Args:
        raw_entities: The ``"entities"`` value from the parsed LLM JSON.

    Returns:
        One ``EntitiesByCategory`` per allowed category, always in the same
        order, with ``items=[]`` for categories the model did not populate.
    """
    allowed_categories = [
        "Symptoms",
        "Ophthalmic Findings",
        "Diagnostic Tools",
        "Systemic Risk Factors",
        "Treatment Options",
        "Demographics/History",
    ]
    # Case/whitespace-insensitive lookup back to the canonical category name.
    canonical = {c.casefold(): c for c in allowed_categories}

    def _text(value) -> str:
        # dict.get(key, "") still returns None for an explicit JSON null, and
        # calling .strip() on None/int/list would crash the whole batch run.
        return value.strip() if isinstance(value, str) else ""

    def _dicts(value) -> list:
        # Only a JSON array of objects is meaningful here; anything else
        # (string, object, number, null) contributes nothing.
        if not isinstance(value, list):
            return []
        return [v for v in value if isinstance(v, dict)]

    # Pre-seed every category so the output always lists all of them.
    by_cat = {c: [] for c in allowed_categories}

    for cat in _dicts(raw_entities):
        category = cat.get("category")
        if not isinstance(category, str):
            # An unhashable value (e.g. a list) would raise TypeError on lookup.
            continue
        name = canonical.get(category.strip().casefold())
        if name is None:
            continue

        for it in _dicts(cat.get("items")):
            evidence = [
                Evidence(quote=_text(e.get("quote")), speaker=e.get("speaker"))
                for e in _dicts(it.get("evidence"))
            ]
            # Append (not assign) so a repeated category keeps earlier items.
            by_cat[name].append(EntityItem(name=_text(it.get("name")), evidence=evidence))

    return [EntitiesByCategory(category=c, items=by_cat[c]) for c in allowed_categories]


def normalize_patient_perspective(raw_pp: Any) -> PatientPerspective:
    """Coerce the model's ``patient_perspective`` object into a PatientPerspective.

    Expects a dict with list-of-string fields ``occurrence``, ``severity``,
    ``concerns`` and ``goals``, plus matching ``*_evidence`` lists of
    ``{"quote": str, "speaker": ...}`` objects. Malformed input is tolerated
    rather than raised:

    * a non-dict ``raw_pp`` is treated as ``{}``;
    * a bare string where a list is expected becomes a one-element list;
    * non-string list entries and blank strings are dropped, the rest stripped;
    * non-dict evidence entries are skipped; a missing, ``null`` or
      non-string ``quote`` becomes ``""``.

    Args:
        raw_pp: The ``"patient_perspective"`` value from the parsed LLM JSON.

    Returns:
        A ``PatientPerspective`` whose list fields are never ``None``.
    """
    if not isinstance(raw_pp, dict):
        raw_pp = {}

    def _strings(key: str) -> list:
        value = raw_pp.get(key)
        if isinstance(value, str):
            # The model sometimes answers with a bare string instead of a list.
            value = [value]
        if not isinstance(value, list):
            return []
        return [v.strip() for v in value if isinstance(v, str) and v.strip()]

    def _evidence(key: str) -> list:
        value = raw_pp.get(key)
        if not isinstance(value, list):
            return []
        out = []
        for e in value:
            if not isinstance(e, dict):
                continue
            quote = e.get("quote")
            out.append(
                Evidence(
                    quote=quote.strip() if isinstance(quote, str) else "",
                    speaker=e.get("speaker"),
                )
            )
        return out

    return PatientPerspective(
        occurrence=_strings("occurrence"),
        severity=_strings("severity"),
        concerns=_strings("concerns"),
        goals=_strings("goals"),
        occurrence_evidence=_evidence("occurrence_evidence"),
        severity_evidence=_evidence("severity_evidence"),
        concerns_evidence=_evidence("concerns_evidence"),
        goals_evidence=_evidence("goals_evidence"),
    )

def _as_str_list(value: Any) -> list[str]:
    """Return *value* as a list of stripped, non-blank strings (a bare string becomes one item)."""
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, list):
        return []
    return [v.strip() for v in value if isinstance(v, str) and v.strip()]


def _load_report(path: str) -> ConversationReport:
    """Read and validate one per-conversation report JSON previously written by ``main``."""
    with open(path, "r", encoding="utf-8") as f:
        return ConversationReport.model_validate_json(f.read())


def _write_outputs(out_dir: str, conversation_reports: list[ConversationReport]) -> BatchSummary:
    """Write batch_summary.json plus the patient-perspective, entity and question CSVs; return the summary."""
    all_questions: list[str] = []
    for r in conversation_reports:
        all_questions.extend(r.questions_asked)
    global_questions = dedupe_strings(all_questions, threshold=90)

    summary = BatchSummary(
        total_unique_conversations=len(conversation_reports),
        conversation_ids=[r.conversation_id for r in conversation_reports],
        global_distinct_questions=global_questions,
    )
    with open(os.path.join(out_dir, "batch_summary.json"), "w", encoding="utf-8") as f:
        f.write(summary.model_dump_json(indent=2))

    # patient_perspective.csv: one row per conversation, list fields joined by " | ".
    pp_rows = [
        {
            "conversation_id": r.conversation_id,
            "source_files": "; ".join(r.source_files),
            "occurrence": " | ".join(r.patient_perspective.occurrence),
            "severity": " | ".join(r.patient_perspective.severity),
            "concerns": " | ".join(r.patient_perspective.concerns),
            "goals": " | ".join(r.patient_perspective.goals),
        }
        for r in conversation_reports
    ]
    # Explicit columns so an empty run still produces a CSV with a header row.
    pd.DataFrame(
        pp_rows,
        columns=["conversation_id", "source_files", "occurrence", "severity", "concerns", "goals"],
    ).to_csv(os.path.join(out_dir, "patient_perspective.csv"), index=False)

    # entities.csv: one row per entity item, keeping at most three evidence quotes.
    entity_rows = []
    for r in conversation_reports:
        for cat in r.entities:
            for item in cat.items:
                entity_rows.append({
                    "conversation_id": r.conversation_id,
                    "category": cat.category,
                    "entity": item.name,
                    "evidence_quotes": " || ".join([e.quote for e in item.evidence][:3]),
                })
    pd.DataFrame(
        entity_rows,
        columns=["conversation_id", "category", "entity", "evidence_quotes"],
    ).to_csv(os.path.join(out_dir, "entities.csv"), index=False)

    # global_questions.csv: questions deduplicated across the whole batch.
    pd.DataFrame({"question": global_questions}).to_csv(
        os.path.join(out_dir, "global_questions.csv"), index=False
    )
    return summary


def main():
    """Run the transcript-extraction pipeline configured by ``config.yaml``.

    Normal mode: for each conversation (optionally restricted to
    ``only_conversations``), send the merged transcript to the Ollama model,
    save the raw reply as ``<cid>_RAW.txt`` and the normalized report as
    ``<cid>.json`` in ``output_dir``, then write the batch summary and CSVs.
    Conversations whose ``<cid>.json`` already exists are not sent to the
    model again. Their saved report is loaded instead, so the summary and
    CSVs cover every conversation, not only the ones processed in this run.
    A saved report that cannot be read or validated is reprocessed.

    Rebuild mode (``rebuild_only: true``) skips the model and the PDF input
    entirely. It regenerates the summary and CSVs from the ``*.json`` reports
    already in ``output_dir``.
    """
    with open("config.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}

    out_dir = cfg.get("output_dir", "outputs")
    os.makedirs(out_dir, exist_ok=True)

    # REBUILD MODE: regenerate summary/CSVs from existing JSON, no LLM calls.
    if cfg.get("rebuild_only", False):
        conversation_reports = [
            _load_report(os.path.join(out_dir, fn))
            for fn in sorted(os.listdir(out_dir))  # sorted for deterministic output
            if fn.endswith(".json") and fn != "batch_summary.json"
        ]
        summary = _write_outputs(out_dir, conversation_reports)
        print("\n✅ REBUILD DONE (no LLM calls)")
        print(f"- Total conversations: {summary.total_unique_conversations}")
        print(f"- Outputs saved to: {out_dir}")
        return

    input_dir = cfg["input_dir"]
    only_ids = set(cfg.get("only_conversations", []) or [])
    model = cfg.get("model", "llama3.1:8b-instruct")
    host = cfg.get("ollama_host", "http://localhost:11434")
    num_ctx = int(cfg.get("num_ctx", 32768))
    temperature = float(cfg.get("temperature", 0.2))

    docs = load_folder(input_dir)
    groups = group_conversations(docs)

    conversation_reports: list[ConversationReport] = []

    for cid, parts in tqdm(groups, desc="Processing conversations"):
        # If only_conversations is provided, run only those.
        if only_ids and cid not in only_ids:
            continue

        # RESUME SUPPORT: reuse an existing report instead of calling the LLM,
        # but still include it in the batch outputs written below.
        out_path = os.path.join(out_dir, f"{cid}.json")
        if os.path.exists(out_path):
            try:
                conversation_reports.append(_load_report(out_path))
                continue
            except (OSError, ValueError) as e:  # pydantic ValidationError is a ValueError
                print(f"⚠️ Could not reload {out_path} ({e}); reprocessing {cid}")

        merged = merge_conversation_text(parts)
        prompt = PROMPT_TEMPLATE.format(rules=SYSTEM_RULES, transcript=merged)

        resp = ollama_generate(
            model=model,
            prompt=prompt,
            host=host,
            num_ctx=num_ctx,
            temperature=temperature,
        )

        # Always save the raw output so failures can be inspected afterwards.
        with open(os.path.join(out_dir, f"{cid}_RAW.txt"), "w", encoding="utf-8") as f:
            f.write(resp or "")

        try:
            data = safe_json_load(resp)
        except ValueError as e:
            print(f"❌ Failed for conversation {cid}: {e}")
            continue

        # A top-level array/string would crash every data.get(...) below.
        if not isinstance(data, dict):
            print(f"❌ Failed for conversation {cid}: expected a JSON object, got {type(data).__name__}")
            continue

        questions = dedupe_strings(_as_str_list(data.get("questions_asked")), threshold=90)
        entities = normalize_entities(data.get("entities", []) or [])
        raw_pp = data.get("patient_perspective")
        pp = normalize_patient_perspective(raw_pp if isinstance(raw_pp, dict) else {})

        report = ConversationReport(
            conversation_id=cid,
            source_files=[p.filename for p in parts],
            questions_asked=questions,
            entities=entities,
            patient_perspective=pp,
        )
        conversation_reports.append(report)

        # Save per-conversation JSON (this is also the resume marker).
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(report.model_dump_json(indent=2))

    summary = _write_outputs(out_dir, conversation_reports)

    print("\n✅ DONE")
    print(f"- Total conversations: {summary.total_unique_conversations}")
    print(f"- Outputs saved to: {out_dir}")


if __name__ == "__main__":
    main()
