from __future__ import annotations
import os, json
from typing import Any
import yaml
import pandas as pd
from tqdm import tqdm

from schemas import ConversationReport, BatchSummary, EntitiesByCategory, EntityItem, Evidence, PatientPerspective, EntityPresenceSentiment

from pdf_text import load_folder, group_conversations, merge_conversation_text
from llm_ollama import ollama_generate
from dedupe import dedupe_strings

# Shared guardrail text. main() passes it as ``rules=`` when formatting both
# PROMPT_TEMPLATE and PRESENCE_PROMPT_TEMPLATE (their ``{rules}`` placeholder).
# This is runtime prompt text: any edit changes what the model is told.
SYSTEM_RULES = """
You are extracting information from a diabetes-related eye care interview transcript.
Follow these rules strictly:
1) Do NOT invent information. If not stated, return empty list for that field.
2) Always include short exact quotes as evidence for entities and for patient perspective.
3) Output MUST be valid JSON only (no markdown, no extra text).
"""

# Prompt for the base extraction pass (questions asked, free-form entities,
# patient perspective). Filled with ``str.format(rules=..., transcript=...)``,
# so every literal brace in the JSON example is doubled (``{{`` / ``}}``).
# The category names listed under RULES must match ``allowed_categories`` in
# normalize_entities(); any other category the model returns is dropped there.
PROMPT_TEMPLATE = """
{rules}

Return ONLY valid JSON.

REQUIRED JSON SCHEMA (must follow exactly):
{{
  "questions_asked": ["..."],
  "entities": [
    {{
      "category": "Symptoms",
      "items": [
        {{
          "name": "blurred vision",
          "evidence": [{{"quote": "exact short quote", "speaker": "S2"}}]
        }}
      ]
    }}
  ],
  "patient_perspective": {{
    "occurrence": ["..."],
    "severity": ["..."],
    "concerns": ["..."],
    "goals": ["..."],
    "occurrence_evidence": [{{"quote":"...", "speaker":"S2"}}],
    "severity_evidence": [{{"quote":"...", "speaker":"S2"}}],
    "concerns_evidence": [{{"quote":"...", "speaker":"S2"}}],
    "goals_evidence": [{{"quote":"...", "speaker":"S2"}}]
  }}
}}

RULES:
- "entities" MUST be a list of objects (NOT a string).
- Each entity category MUST be one of:
  Symptoms, Ophthalmic Findings, Diagnostic Tools, Systemic Risk Factors, Treatment Options, Demographics/History
- If nothing found for a category, include the category with items: []
- Do not invent anything. If not stated, keep lists empty.

TRANSCRIPT:
<<<
{transcript}
>>>
"""

# Prompt for the second pass: presence + patient sentiment for each entity in
# the config's ``predefined_entities`` list. Filled with
# ``str.format(rules=..., entity_list_json=..., transcript=...)``; literal
# braces are doubled for the same reason as in PROMPT_TEMPLATE.
# NOTE(review): the few-shot "Output entity row" examples omit the "category"
# key that the schema at the end requires — consider adding it so the model
# does not learn to drop it.
PRESENCE_PROMPT_TEMPLATE = """
{rules}

You MUST return ONLY valid JSON.

TASK:
Given the transcript, check a PRE-DEFINED entity list.
For each entity:
- present: true/false
- sentiment: positive/neutral/negative/mixed/unknown
Sentiment should reflect the PATIENT attitude about that entity if present.
If the entity is present but no attitude is expressed, use "neutral".
If unclear, use "unknown".

PREDEFINED ENTITY LIST (categories -> entities):
{entity_list_json}

FEW-SHOT EXAMPLES:
Example 1:
Transcript: "S2: I see floaters and I'm worried it will get worse."
Output entity row: {{"entity":"floaters","present":true,"sentiment":"negative","evidence":[{{"quote":"I see floaters and I'm worried","speaker":"S2"}}]}}

Example 2:
Transcript: "S1: We will do OCT today. S2: Okay."
Output entity row: {{"entity":"OCT","present":true,"sentiment":"neutral","evidence":[{{"quote":"We will do OCT today.","speaker":"S1"}}]}}

NOW PROCESS THIS TRANSCRIPT:
<<<
{transcript}
>>>

Return JSON with this exact schema:
{{
  "entity_presence_sentiment": [
    {{
      "category": "Symptoms",
      "entity": "blurred vision",
      "present": true,
      "sentiment": "negative",
      "evidence": [{{"quote":"...", "speaker":"S2"}}]
    }}
  ]
}}
"""

def safe_json_load(s: str) -> Any:
    """Parse JSON from a raw LLM response, tolerating surrounding chatter.

    The whole (stripped) response is parsed first. If that fails, the function
    tries to decode a JSON object starting at each ``{`` in turn and returns the
    first one that decodes. This tolerates markdown fences and prose before or
    after the object, even when that prose itself contains braces.

    Args:
        s: Raw model output; ``None`` is treated as empty.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If the response is empty/blank.
        json.JSONDecodeError: If no parseable JSON is found. This is the error
            from the whole-string parse; it is a ValueError subclass.
    """
    s = (s or "").strip()

    # Fail clearly when the model returned nothing at all.
    if not s:
        raise ValueError("LLM returned EMPTY response")

    try:
        return json.loads(s)
    except json.JSONDecodeError as first_err:
        # raw_decode stops at the end of the first complete value, so trailing
        # text (even text containing "}") no longer breaks extraction — unlike
        # slicing from the first "{" to the last "}".
        decoder = json.JSONDecoder()
        pos = s.find("{")
        while pos != -1:
            try:
                obj, _ = decoder.raw_decode(s, pos)
            except json.JSONDecodeError:
                pos = s.find("{", pos + 1)
                continue
            return obj
        raise first_err

def normalize_entities(raw_entities) -> list[EntitiesByCategory]:
    """
    Coerce the model's ``entities`` payload into one entry per allowed category.

    Expected shape::

        [{"category": str,
          "items": [{"name": str,
                     "evidence": [{"quote": str, "speaker": ...}]}]}]

    Because this comes straight from an LLM, malformed parts are dropped
    instead of raising, so a single bad response cannot abort the batch:

    - a non-list payload (e.g. a string) is treated as empty;
    - non-dict categories/items/evidence and unknown or non-string
      categories are skipped;
    - a null or non-string ``name`` / ``quote`` becomes ``""``;
    - a category that appears more than once has its items concatenated.

    Returns:
        ``EntitiesByCategory`` objects in the fixed order of the allowed
        categories; categories with nothing found get ``items=[]``.
    """
    allowed_categories = [
        "Symptoms",
        "Ophthalmic Findings",
        "Diagnostic Tools",
        "Systemic Risk Factors",
        "Treatment Options",
        "Demographics/History",
    ]

    def _text(value) -> str:
        # The model may emit null/number where a string is expected; calling
        # .strip() on those would raise AttributeError.
        return value.strip() if isinstance(value, str) else ""

    def _as_list(value) -> list:
        return value if isinstance(value, list) else []

    # Pre-seed every category so the output always lists all of them.
    by_cat: dict[str, list] = {c: [] for c in allowed_categories}

    for cat in _as_list(raw_entities):
        if not isinstance(cat, dict):
            continue

        category = cat.get("category")
        # isinstance check first: an unhashable value (list/dict) would make
        # the dict membership test raise TypeError.
        if not isinstance(category, str) or category not in by_cat:
            continue

        for it in _as_list(cat.get("items")):
            if not isinstance(it, dict):
                continue
            evs = [
                Evidence(quote=_text(e.get("quote")), speaker=e.get("speaker"))
                for e in _as_list(it.get("evidence"))
                if isinstance(e, dict)
            ]
            # Append (not assign) so a repeated category keeps earlier items.
            by_cat[category].append(EntityItem(name=_text(it.get("name")), evidence=evs))

    return [EntitiesByCategory(category=c, items=by_cat[c]) for c in allowed_categories]


def normalize_patient_perspective(raw_pp: dict) -> PatientPerspective:
    """
    Coerce the model's ``patient_perspective`` payload into a PatientPerspective.

    Expected input is a dict with string-list fields ``occurrence``,
    ``severity``, ``concerns`` and ``goals``, plus matching ``*_evidence``
    lists of ``{"quote": str, "speaker": ...}`` dicts. Because the payload
    comes from an LLM, every level is checked rather than trusted:

    - a non-dict ``raw_pp`` is treated as empty;
    - a bare string where a string list is expected becomes a one-item list;
    - non-string list entries and non-dict evidence entries are dropped;
    - a null or non-string ``quote`` becomes ``""``.
    """
    if not isinstance(raw_pp, dict):
        raw_pp = {}

    def str_list(key: str) -> list[str]:
        value = raw_pp.get(key) or []
        if isinstance(value, str):
            return [value]
        if not isinstance(value, list):
            return []
        return [v for v in value if isinstance(v, str)]

    def ev_list(key: str) -> list[Evidence]:
        value = raw_pp.get(key) or []
        if not isinstance(value, list):
            return []
        evs = []
        for e in value:
            if not isinstance(e, dict):
                continue
            quote = e.get("quote")
            evs.append(Evidence(
                quote=quote.strip() if isinstance(quote, str) else "",
                speaker=e.get("speaker"),
            ))
        return evs

    return PatientPerspective(
        occurrence=str_list("occurrence"),
        severity=str_list("severity"),
        concerns=str_list("concerns"),
        goals=str_list("goals"),
        occurrence_evidence=ev_list("occurrence_evidence"),
        severity_evidence=ev_list("severity_evidence"),
        concerns_evidence=ev_list("concerns_evidence"),
        goals_evidence=ev_list("goals_evidence"),
    )
def config_entity_list_json(cfg: dict) -> str:
    """Serialize the config's ``predefined_entities`` for the presence prompt.

    A missing or falsy ``predefined_entities`` value serializes as ``{}``.
    Non-ASCII text is kept verbatim and the JSON is indented by two spaces.
    """
    predefined = cfg.get("predefined_entities", {})
    if not predefined:
        predefined = {}
    return json.dumps(predefined, indent=2, ensure_ascii=False)

def normalize_presence_rows(rows) -> list[EntityPresenceSentiment]:
    """
    Coerce the model's ``entity_presence_sentiment`` rows into schema objects.

    Each row is expected to look like ``{"category", "entity", "present",
    "sentiment", "evidence": [{"quote", "speaker"}]}``. LLM output is
    sanitized rather than trusted:

    - a non-list ``rows`` yields ``[]``; non-dict rows/evidence are skipped;
    - ``present`` given as a string ("true"/"false"/"yes"/...) is parsed —
      ``bool("false")`` would otherwise be True;
    - ``sentiment`` is lower-cased and mapped to "unknown" unless it is one of
      the values the presence prompt allows;
    - a null or non-string ``entity`` / ``quote`` becomes ``""``.
    """
    # The closed set the presence prompt instructs the model to use.
    allowed_sentiments = {"positive", "neutral", "negative", "mixed", "unknown"}

    def _text(value) -> str:
        return value.strip() if isinstance(value, str) else ""

    def _as_bool(value) -> bool:
        if isinstance(value, str):
            return value.strip().lower() in {"true", "yes", "y", "1"}
        return bool(value)

    if not isinstance(rows, list):
        return []

    out = []
    for r in rows:
        if not isinstance(r, dict):
            continue

        raw_evidence = r.get("evidence")
        evs = [
            Evidence(quote=_text(e.get("quote")), speaker=e.get("speaker"))
            for e in (raw_evidence if isinstance(raw_evidence, list) else [])
            if isinstance(e, dict)
        ]

        sentiment = _text(r.get("sentiment", "unknown")).lower()
        if sentiment not in allowed_sentiments:
            sentiment = "unknown"

        out.append(EntityPresenceSentiment(
            category=r.get("category"),
            entity=_text(r.get("entity")),
            present=_as_bool(r.get("present", False)),
            sentiment=sentiment,
            evidence=evs,
        ))
    return out

def _load_config(path: str) -> dict:
    """Read the YAML run configuration at *path*, closing the file afterwards."""
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def _load_report(path: str) -> ConversationReport | None:
    """Load a saved per-conversation report; warn and return None if unreadable."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return ConversationReport.model_validate_json(f.read())
    except Exception as e:
        print(f"⚠️ Could not load existing report {path}: {e}")
        return None


def _load_existing_reports(out_dir: str) -> list[ConversationReport]:
    """Load every per-conversation ``*.json`` in *out_dir* (sorted by filename)."""
    reports = []
    for fn in sorted(os.listdir(out_dir)):
        if fn.endswith(".json") and fn != "batch_summary.json":
            report = _load_report(os.path.join(out_dir, fn))
            if report is not None:
                reports.append(report)
    return reports


def _write_csv(out_dir: str, filename: str, rows: list[dict], columns: list[str]) -> None:
    """Write *rows* as a CSV with a fixed header, even when *rows* is empty."""
    pd.DataFrame(rows, columns=columns).to_csv(os.path.join(out_dir, filename), index=False)


def _write_outputs(out_dir: str, conversation_reports: list[ConversationReport]) -> BatchSummary:
    """Write batch_summary.json and all CSV tables; return the batch summary."""
    all_questions = [q for r in conversation_reports for q in r.questions_asked]
    global_questions = dedupe_strings(all_questions, threshold=90)
    summary = BatchSummary(
        total_unique_conversations=len(conversation_reports),
        conversation_ids=[r.conversation_id for r in conversation_reports],
        global_distinct_questions=global_questions,
    )
    with open(os.path.join(out_dir, "batch_summary.json"), "w", encoding="utf-8") as f:
        f.write(summary.model_dump_json(indent=2))

    # patient_perspective.csv: one row per conversation.
    pp_rows = [
        {
            "conversation_id": r.conversation_id,
            "source_files": "; ".join(r.source_files),
            "occurrence": " | ".join(r.patient_perspective.occurrence),
            "severity": " | ".join(r.patient_perspective.severity),
            "concerns": " | ".join(r.patient_perspective.concerns),
            "goals": " | ".join(r.patient_perspective.goals),
        }
        for r in conversation_reports
    ]
    _write_csv(out_dir, "patient_perspective.csv", pp_rows,
               ["conversation_id", "source_files", "occurrence", "severity", "concerns", "goals"])

    # entities.csv: one row per extracted entity item (first 3 quotes).
    e_rows = [
        {
            "conversation_id": r.conversation_id,
            "category": cat.category,
            "entity": item.name,
            "evidence_quotes": " || ".join([e.quote for e in item.evidence][:3]),
        }
        for r in conversation_reports
        for cat in r.entities
        for item in cat.items
    ]
    _write_csv(out_dir, "entities.csv", e_rows,
               ["conversation_id", "category", "entity", "evidence_quotes"])

    # global_questions.csv: de-duplicated questions across the batch.
    pd.DataFrame({"question": global_questions}).to_csv(
        os.path.join(out_dir, "global_questions.csv"), index=False
    )

    # entity_presence_sentiment.csv: predefined-entity tabulation (first 2 quotes).
    ps_rows = [
        {
            "conversation_id": r.conversation_id,
            "category": row.category,
            "entity": row.entity,
            "present": row.present,
            "sentiment": row.sentiment,
            "evidence_quotes": " || ".join([e.quote for e in row.evidence][:2]),
        }
        for r in conversation_reports
        for row in (getattr(r, "entity_presence_sentiment", None) or [])
    ]
    _write_csv(out_dir, "entity_presence_sentiment.csv", ps_rows,
               ["conversation_id", "category", "entity", "present", "sentiment", "evidence_quotes"])

    return summary


def _process_conversation(
    cid,
    parts,
    *,
    model: str,
    host: str,
    num_ctx: int,
    temperature: float,
    out_dir: str,
    entity_list_json: str,
) -> ConversationReport | None:
    """Run both LLM passes for one conversation and build its report.

    Raw model responses are saved as ``{cid}_RAW.txt`` and
    ``{cid}_PRESENCE_RAW.txt`` in *out_dir*. The function returns None, after
    printing why, when the base extraction cannot be parsed into a JSON
    object or the report cannot be built. A presence-pass parse failure only
    yields an empty presence table.
    """
    merged = merge_conversation_text(parts)

    # (A) Base extraction: questions / entities / patient perspective.
    prompt = PROMPT_TEMPLATE.format(rules=SYSTEM_RULES, transcript=merged)
    resp = ollama_generate(model=model, prompt=prompt, host=host, num_ctx=num_ctx, temperature=temperature)
    with open(os.path.join(out_dir, f"{cid}_RAW.txt"), "w", encoding="utf-8") as f:
        f.write(resp or "")

    try:
        data = safe_json_load(resp)
    except Exception as e:
        print(f"❌ Failed for conversation {cid} (base): {e}")
        return None
    if not isinstance(data, dict):
        print(f"❌ Failed for conversation {cid} (base): expected a JSON object, got {type(data).__name__}")
        return None

    # (B) Predefined entity presence + sentiment.
    presence_prompt = PRESENCE_PROMPT_TEMPLATE.format(
        rules=SYSTEM_RULES,
        entity_list_json=entity_list_json,
        transcript=merged,
    )
    presence_resp = ollama_generate(model=model, prompt=presence_prompt, host=host, num_ctx=num_ctx, temperature=temperature)
    with open(os.path.join(out_dir, f"{cid}_PRESENCE_RAW.txt"), "w", encoding="utf-8") as f:
        f.write(presence_resp or "")

    try:
        presence_data = safe_json_load(presence_resp)
    except Exception as e:
        print(f"❌ Failed for conversation {cid} (presence): {e}")
        presence_data = {"entity_presence_sentiment": []}

    if isinstance(presence_data, dict):
        presence_raw = presence_data.get("entity_presence_sentiment", [])
    elif isinstance(presence_data, list):
        # The model sometimes returns the bare row list without the wrapper key.
        presence_raw = presence_data
    else:
        presence_raw = []

    raw_questions = data.get("questions_asked") or []
    if isinstance(raw_questions, str):
        raw_questions = [raw_questions]
    elif not isinstance(raw_questions, list):
        raw_questions = []

    raw_pp = data.get("patient_perspective")
    if not isinstance(raw_pp, dict):
        raw_pp = {}

    try:
        return ConversationReport(
            conversation_id=cid,
            source_files=[p.filename for p in parts],
            questions_asked=dedupe_strings([q for q in raw_questions if isinstance(q, str)], threshold=90),
            entities=normalize_entities(data.get("entities", []) or []),
            patient_perspective=normalize_patient_perspective(raw_pp),
            entity_presence_sentiment=normalize_presence_rows(presence_raw),
        )
    except Exception as e:
        # Per-conversation boundary: report and move on rather than abort the batch.
        print(f"❌ Failed for conversation {cid} (normalize): {e}")
        return None


def main():
    """Entry point: process transcripts per ``config.yaml`` and write outputs.

    Config keys read:
        ``output_dir`` (default "outputs"), ``rebuild_only``, ``input_dir``,
        ``only_conversations``, ``model``, ``ollama_host``, ``num_ctx``,
        ``temperature``, ``predefined_entities``.

    If ``rebuild_only`` is true, the summary and CSVs are regenerated from the
    per-conversation JSON already in ``output_dir``, without any LLM calls.
    Otherwise each conversation is processed unless its ``{cid}.json`` already
    exists. Existing reports are reused so the summary and CSVs always cover
    every selected conversation, not only the ones processed in this run.
    """
    cfg = _load_config("config.yaml")

    out_dir = cfg.get("output_dir", "outputs")
    os.makedirs(out_dir, exist_ok=True)

    # REBUILD MODE: rebuild outputs from existing JSON files, no LLM calls
    # (and no need to load the input documents).
    if cfg.get("rebuild_only", False):
        summary = _write_outputs(out_dir, _load_existing_reports(out_dir))
        print("\n✅ REBUILD DONE (no LLM calls)")
        print(f"- Total conversations: {summary.total_unique_conversations}")
        print(f"- Outputs saved to: {out_dir}")
        return

    input_dir = cfg["input_dir"]
    # Compare as strings so YAML ints (e.g. `- 12`) still match.
    only_ids = {str(x) for x in (cfg.get("only_conversations", []) or [])}

    model = cfg.get("model", "llama3.1:8b-instruct")
    host = cfg.get("ollama_host", "http://localhost:11434")
    num_ctx = int(cfg.get("num_ctx", 32768))
    temperature = float(cfg.get("temperature", 0.2))
    entity_list_json = config_entity_list_json(cfg)  # loop-invariant

    groups = group_conversations(load_folder(input_dir))

    conversation_reports: list[ConversationReport] = []
    for cid, parts in tqdm(groups, desc="Processing conversations"):
        if only_ids and str(cid) not in only_ids:
            continue

        out_path = os.path.join(out_dir, f"{cid}.json")
        if os.path.exists(out_path):
            # RESUME SUPPORT: skip the LLM but keep the previous result so the
            # batch summary / CSVs are not overwritten with a partial set.
            existing = _load_report(out_path)
            if existing is not None:
                conversation_reports.append(existing)
            continue

        report = _process_conversation(
            cid,
            parts,
            model=model,
            host=host,
            num_ctx=num_ctx,
            temperature=temperature,
            out_dir=out_dir,
            entity_list_json=entity_list_json,
        )
        if report is None:
            continue

        conversation_reports.append(report)
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(report.model_dump_json(indent=2))

    summary = _write_outputs(out_dir, conversation_reports)

    print("\n✅ DONE")
    print(f"- Total conversations: {summary.total_unique_conversations}")
    print(f"- Outputs saved to: {out_dir}")


if __name__ == "__main__":
    main()
