"""
Generate detailed justification report with LLM validation
This script reads processed JSON files and asks LLM to justify each code assignment

FIX APPLIED:
- Previously searched in raw .txt files (data/processed/dr_XXXX.txt)
- Now searches in cleaned+chunked JSON files (data/processed/chunked/*_processed.json)
- This is the same text the quotes were originally extracted from
- Also handles part files (dr_HD1103(part1), dr_DA1010_Part1, etc.) automatically
"""
import json
import os
from pathlib import Path
import pandas as pd
import ollama
from tqdm import tqdm
import time

# Paths
# Patient-level coding results workbook produced by the upstream analysis step.
INPUT_EXCEL = Path("/home/nsm/dr-transcripts-deductive-V1/outputs/analysis/analysis_results_PATIENT_LEVEL.xlsx")
INPUT_SHEET = "All_Codes_Patient_Level"

# FIXED: Point to chunked folder instead of raw processed folder
# (the *_processed.json files here hold the cleaned text the quotes came from).
CHUNKED_DIR = Path("/home/nsm/dr-transcripts-deductive-V1/data/processed/chunked")

# Destination workbook for a full run; its parent directory is created eagerly
# at import time so later saves cannot fail on a missing folder.
OUTPUT_FILE = Path("/home/nsm/dr-transcripts-deductive-V1/outputs/analysis/detailed_justifications_new.xlsx")
OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)

# ============================================================
# TEST MODE SETTINGS
# Set TEST_MODE = True  → runs only TEST_PATIENTS, saves to a
#                         separate test output file so your
#                         real output is NOT overwritten
# Set TEST_MODE = False → runs ALL patients (full run)
# ============================================================
TEST_MODE     = False                       # ← change to True for test run
TEST_PATIENTS = ["dr_BN1103", "dr_CJ0406"] # ← 2 patients to test with
TEST_OUTPUT   = Path("/home/nsm/dr-transcripts-deductive-V1/outputs/analysis/detailed_justifications_TEST.xlsx")


# Category definitions for LLM reference.
# NOTE: this string is embedded verbatim in every prompt sent to the model
# (see call_llm_for_justification) — edit its wording with care.
CATEGORY_DEFINITIONS = """
CATEGORY DEFINITIONS:

1. BARRIERS - Beliefs & Perceptions
   - Self-efficacy & Belief Barriers: Lack of confidence, feelings of helplessness
   - Emotional & Psychological Barriers: Fear, anxiety, stress, depression, worry
   - Other Belief Barriers: Misconceptions, cultural beliefs, lack of awareness

2. BARRIERS - Healthcare System
   - Healthcare Access Barriers: Transportation, distance, availability, wait times
   - Provider-Patient Relationship Issues: Poor communication, lack of trust, lack of empathy
   - Healthcare System Support: Lack of care coordination, system complexity, scheduling issues

3. BARRIERS - Financial & Access
   - Financial Barriers: Medication costs, insurance issues, cost of care, affordability
   - Other SDOH Barriers: Social determinants (food, housing, transportation)

4. FACILITATORS
   - Self-Management Behaviors: Monitoring, diet, exercise, medication adherence
   - Positive Provider Relationships: Good communication, trust, empathy, availability
   - Health Education & Empowerment: Understanding condition, asking questions, education programs
   - Healthcare System Support: Care coordination, accessible services
   - Other Facilitators: Family support, technology use, resources
"""


def call_llm_for_justification(quote, context, code, theme, group):
    """
    Ask the local Ollama model to justify one quote -> (code, theme, group)
    assignment.

    Returns the model's response text (stripped), or a string starting with
    "Error:" if the call fails.
    """

    prompt = f"""You are a qualitative research validator. A quote from a patient interview was coded as follows:

QUOTE: "{quote}"

SURROUNDING CONTEXT: {context}

ASSIGNED CODING:
- Code: {code}
- Theme: {theme}
- Group: {group}

{CATEGORY_DEFINITIONS}

TASK: Provide a detailed justification explaining:
1. Why this quote fits the assigned code "{code}"
2. Why it belongs to the theme "{theme}"
3. Why it's categorized in the group "{group}"
4. What specific words or phrases in the quote support this coding
5. Whether this coding is accurate or if there might be a better fit

Be specific and reference the exact words from the quote. Your justification will be used to validate the coding accuracy.

FORMAT YOUR RESPONSE AS:
**Justification:** [Your detailed explanation here in 3-5 sentences]
**Supporting Evidence:** [Specific words/phrases from the quote]
**Confidence:** [High/Medium/Low]
**Alternative Coding (if any):** [Suggest if this could fit elsewhere]
"""

    # System + user turn; temperature 0 keeps the validator deterministic.
    chat_messages = [
        {'role': 'system', 'content': 'You are an expert qualitative research validator.'},
        {'role': 'user', 'content': prompt},
    ]

    try:
        reply = ollama.chat(
            model='llama3.3',
            messages=chat_messages,
            options={'temperature': 0, 'num_predict': 2000},
        )
        return reply['message']['content'].strip()
    except Exception as e:
        # Surface the failure on stdout, then return a sentinel string so the
        # caller can still record the row.
        print(f"Error calling LLM: {e}")
        import traceback
        print(traceback.format_exc())
        return f"Error: {str(e)}"


# Helper to clean text (unicode normalization + whitespace)
def _clean_text(s: str) -> str:
    if s is None:
        return ""
    s = str(s)
    s = s.replace("\u2013", "-").replace("\u2014", "-")  # en/em dash
    s = s.replace("\u2018", "'").replace("\u2019", "'")  # curly apostrophe
    s = s.replace("\u201c", '"').replace("\u201d", '"')  # curly quotes
    s = " ".join(s.split())
    return s


# Cache: patient_id -> cleaned, concatenated transcript text (see load_cleaned_text).
_CLEANED_TEXT_CACHE = {}


# Cache for chunks too: patient_id -> list of chunks (strings or dicts) from the JSON files.
_CHUNKS_CACHE = {}


def load_cleaned_text(patient_id: str) -> str:
    """
    Load and concatenate full_text from *_processed.json files in the chunked
    folder, caching both the combined text and the chunk list per patient.

    Part files are picked up automatically:
      dr_HD1103 -> dr_HD1103(part1)_processed.json + dr_HD1103(part2)_processed.json
      dr_DA1010 -> dr_DA1010_Part1_processed.json + dr_DA1010_Part2_processed.json

    Returns "" (and caches empty results) when no matching file exists.
    """
    pid = str(patient_id).strip()
    if pid in _CLEANED_TEXT_CACHE:
        return _CLEANED_TEXT_CACHE[pid]

    # FIX: the bare glob f"{pid}*_processed.json" also matched LONGER patient
    # ids sharing this prefix (e.g. pid "dr_HD110" would pull in dr_HD1103's
    # files, silently merging two patients' transcripts). Keep a candidate
    # only when the character immediately after the pid is NOT alphanumeric —
    # legitimate part files continue with "_" or "(".
    matches = sorted(
        p for p in CHUNKED_DIR.glob(f"{pid}*_processed.json")
        if not p.name[len(pid):len(pid) + 1].isalnum()
    )

    if not matches:
        print(f"  WARNING: No processed JSON found for {pid} in {CHUNKED_DIR}")
        # Cache the miss so we do not re-glob (and re-warn) on every row.
        _CLEANED_TEXT_CACHE[pid] = ""
        _CHUNKS_CACHE[pid] = []
        return ""

    combined = ""
    all_chunks = []
    for json_path in matches:
        with open(json_path, "r", encoding="utf-8", errors="ignore") as f:
            data = json.load(f)
        # Space-join part files so text from adjacent parts never fuses.
        combined += " " + data.get("full_text", "")
        all_chunks.extend(data.get("chunks", []))

    combined = _clean_text(combined.strip())
    _CLEANED_TEXT_CACHE[pid] = combined
    _CHUNKS_CACHE[pid] = all_chunks
    return combined


def get_chunk_text(patient_id: str, chunk_id) -> str:
    """
    Look up the text of one chunk for a patient.

    Used as fallback evidence when fuzzy matching against the full transcript
    fails. Returns "" for unknown patients, non-numeric or out-of-range chunk
    ids, and chunk entries of an unrecognized shape.
    """
    pid = str(patient_id).strip()

    # Ensure both caches are populated for this patient.
    if pid not in _CHUNKS_CACHE:
        load_cleaned_text(pid)

    chunks = _CHUNKS_CACHE.get(pid, [])
    if not chunks:
        return ""

    try:
        index = int(chunk_id)
    except (ValueError, TypeError):
        return ""

    if not (0 <= index < len(chunks)):
        return ""

    entry = chunks[index]
    # A chunk may be stored as a plain string or a dict with a "text" key.
    if isinstance(entry, dict):
        return _clean_text(entry.get("text", ""))
    if isinstance(entry, str):
        return _clean_text(entry)
    return ""


def _strip_for_fuzzy(s: str) -> str:
    """Remove punctuation, quotes, annotations, speaker labels for fuzzy matching."""
    import re
    # Remove parenthetical annotations like (when asked if...)
    s = re.sub(r'\(.*?\)', ' ', s)
    # Remove speaker labels like "Speaker 2:" or "S1:"
    s = re.sub(r'\b(?:Speaker\s*\d+|S\d+)\s*:', ' ', s, flags=re.IGNORECASE)
    # Remove ellipsis
    s = s.replace("...", " ")
    # Remove all punctuation except spaces
    s = re.sub(r'[^\w\s]', ' ', s)
    # Normalize whitespace
    s = " ".join(s.split())
    return s.lower()


def find_evidence_snippet(cleaned_text: str, quote: str, window_chars: int = 280):
    """
    Search for `quote` in `cleaned_text` and return a surrounding context snippet.

    Matching cascade, strictest to loosest:
      1. direct case-insensitive substring
      2. quote with dashes collapsed (speech interruptions like "--")
      3. quote with doubled commas collapsed (ASR artifacts like ", ,")
      4. both quote AND text with all commas/dashes removed
      5. both fully stripped of punctuation/annotations/speaker labels
      6. fuzzy sliding-window similarity (accepted at ratio >= 0.70)

    Args:
        cleaned_text: full transcript text (already normalized by _clean_text).
        quote: the quote to locate.
        window_chars: characters of context kept on each side of the match.

    Returns:
        (found_bool, evidence_snippet, match_index); match_index is -1 when
        nothing matched.

    NOTE(review): for strategies 4-6 the match index is an offset into a
    NORMALIZED copy of the text, yet the snippet is sliced from the ORIGINAL
    text at that offset — whenever normalization changed string lengths the
    snippet is only approximately centered on the match. Confirm this
    approximation is acceptable for evidence review.
    """
    import re
    from difflib import SequenceMatcher

    text = cleaned_text or ""
    q = _clean_text(quote)

    # No text or no quote -> nothing to search.
    if not text or not q:
        return False, "", -1

    low_text = text.lower()
    low_q = q.lower()

    # Try 1: direct case-insensitive match (index is exact here)
    idx = low_text.find(low_q)
    if idx != -1:
        start = max(0, idx - window_chars)
        end = min(len(text), idx + len(q) + window_chars)
        return True, text[start:end], idx

    # Try 2: remove double-dashes (speech interruptions)
    low_q2 = low_q.replace("--", " ").replace("-", " ")
    low_q2 = " ".join(low_q2.split())
    idx2 = low_text.find(low_q2)
    if idx2 != -1:
        start = max(0, idx2 - window_chars)
        end = min(len(text), idx2 + len(low_q2) + window_chars)
        return True, text[start:end], idx2

    # Try 3: remove double-commas (ASR artifacts like ", ,")
    low_q3 = low_q.replace(", ,", ",").replace(",,", ",")
    low_q3 = " ".join(low_q3.split())
    idx3 = low_text.find(low_q3)
    if idx3 != -1:
        start = max(0, idx3 - window_chars)
        end = min(len(text), idx3 + len(low_q3) + window_chars)
        return True, text[start:end], idx3

    # Try 4: remove ALL commas and dashes
    low_q4 = re.sub(r'[,\-]+', ' ', low_q)
    low_q4 = " ".join(low_q4.split())
    low_text4 = re.sub(r'[,\-]+', ' ', low_text)
    low_text4 = " ".join(low_text4.split())
    # NOTE(review): idx4 indexes low_text4, not text — the snippet below is
    # sliced from the original text at that offset, so it is approximate.
    idx4 = low_text4.find(low_q4)
    if idx4 != -1:
        start = max(0, idx4 - window_chars)
        end = min(len(text), idx4 + len(low_q4) + window_chars)
        return True, text[start:end], idx4

    # Try 5: strip ALL punctuation + annotations + speaker labels (handles
    #         LLM-joined quotes with ..., parenthetical notes, speaker labels)
    # NOTE(review): idx5 indexes the stripped text5 — same approximation as Try 4.
    q5 = _strip_for_fuzzy(q)
    text5 = _strip_for_fuzzy(text)
    if q5:
        idx5 = text5.find(q5)
        if idx5 != -1:
            start = max(0, idx5 - window_chars)
            end = min(len(text), idx5 + len(q5) + window_chars)
            return True, text[start:end], idx5

    # Try 6: fuzzy sliding-window match (handles paraphrased / joined quotes)
    # Split transcript into overlapping windows same length as quote
    # and find the window with highest similarity score
    q6 = _strip_for_fuzzy(q)
    words_text = text5.split()   # already stripped version
    words_q    = q6.split()
    win_size   = max(len(words_q), 5)   # window = quote length in words
    best_ratio = 0.0
    best_idx   = -1

    # O(len(words_text)) SequenceMatcher passes — acceptable for interview-length
    # transcripts, but the slowest path in the cascade.
    for i in range(0, max(1, len(words_text) - win_size + 1)):
        window = " ".join(words_text[i:i + win_size])
        ratio  = SequenceMatcher(None, q6, window).ratio()
        if ratio > best_ratio:
            best_ratio = ratio
            best_idx   = i

    # Accept if similarity >= 70%
    if best_ratio >= 0.70 and best_idx != -1:
        # Map word index back to character index in original text
        # NOTE(review): this maps into the STRIPPED text's coordinates, then
        # slices the original text — the snippet is approximate (see docstring).
        char_idx = len(" ".join(words_text[:best_idx]))
        start    = max(0, char_idx - window_chars)
        end      = min(len(text), char_idx + window_chars * 2)
        return True, text[start:end], char_idx

    return False, "", -1


def process_patient_from_excel(patient_id: str, df_all_codes: pd.DataFrame):
    """
    Generate LLM justifications for every coded quote belonging to one patient.

    Rows come from the Excel dataframe; evidence context is sought in the
    cleaned chunked-JSON transcript (the same source the quotes were extracted
    from), with the source chunk and the Excel context column as fallbacks.

    Returns a list of result dicts, one per complete coded row.
    """
    pid = str(patient_id).strip()
    rows = df_all_codes[df_all_codes["patient_id"].astype(str).str.strip() == pid].copy()

    results = []

    # Transcript text comes from the cleaned chunked JSON, not the raw .txt.
    cleaned_text = load_cleaned_text(pid)

    print(f"\nProcessing {pid}... total rows: {len(rows)}, text loaded: {len(cleaned_text)} chars")

    for _, row in tqdm(rows.iterrows(), total=len(rows), desc=f"Rows for {pid}"):
        code = str(row.get("code", "") or "").strip()
        theme = str(row.get("theme", "") or "").strip()
        group = str(row.get("group", "") or "").strip()
        quote = str(row.get("quote", "") or "").strip()

        excel_context = str(row.get("context", "") or "").strip()
        original_file_id = str(row.get("original_file_id", "") or "").strip()
        chunk_id = row.get("chunk_id", "")

        # Every coding field must be present; skip incomplete rows.
        if not (code and theme and group and quote):
            continue

        found, evidence, _ = find_evidence_snippet(cleaned_text, quote, window_chars=280)

        if found:
            # Best case: located (verbatim or fuzzy) in the transcript.
            raw_evidence = context_for_llm = evidence
        else:
            # Fall back to the source chunk, then to the Excel context column.
            chunk_text = get_chunk_text(pid, chunk_id)
            if chunk_text:
                raw_evidence = context_for_llm = chunk_text
            elif excel_context:
                raw_evidence = context_for_llm = excel_context
            else:
                raw_evidence = context_for_llm = "Context not available."

        justification = call_llm_for_justification(
            quote=quote,
            context=context_for_llm,
            code=code,
            theme=theme,
            group=group,
        )

        results.append({
            "patient_id": pid,
            "original_file_id": original_file_id,
            "chunk_id": chunk_id,
            "code": code,
            "theme": theme,
            "group": group,
            "quote": quote,
            "raw_evidence_snippet": raw_evidence,
            "excel_context": excel_context,
            "llm_justification": justification,
        })

        # Brief pause between LLM calls to avoid hammering the local server.
        time.sleep(0.25)

    return results


def main():
    """Entry point: read coded rows, generate justifications, write the report.

    Honors TEST_MODE by restricting to TEST_PATIENTS and writing to a separate
    workbook so a real run's output is never overwritten.
    """
    banner = "=" * 80
    print(banner)
    print("GENERATING LLM JUSTIFICATION REPORT (FROM CLEANED CHUNKED JSON)")
    print(banner)

    # Fail fast if either input source is missing.
    if not INPUT_EXCEL.exists():
        print(f"ERROR: Missing Excel file: {INPUT_EXCEL}")
        return
    if not CHUNKED_DIR.exists():
        print(f"ERROR: Chunked directory not found: {CHUNKED_DIR}")
        return

    df_all = pd.read_excel(INPUT_EXCEL, sheet_name=INPUT_SHEET)

    all_patient_ids = sorted(df_all["patient_id"].astype(str).str.strip().unique().tolist())

    # TEST MODE routes a short patient list to a separate output workbook.
    if TEST_MODE:
        patient_ids = TEST_PATIENTS
        output_path = TEST_OUTPUT
        print(f"*** TEST MODE ON - running {len(patient_ids)} patients only ***")
        print(f"*** Patients : {patient_ids} ***")
        print(f"*** Output   : {output_path} ***")
    else:
        patient_ids = all_patient_ids
        output_path = OUTPUT_FILE
        print(f"*** FULL MODE - running all {len(patient_ids)} patients ***")
        print(f"*** Output   : {output_path} ***")

    print(f"Total patients to process: {len(patient_ids)}")

    all_results = []
    for pid in patient_ids:
        all_results.extend(process_patient_from_excel(pid, df_all))

    df = pd.DataFrame(all_results)

    total = len(df)
    print(f"\n=== SUMMARY ===")
    print(f"Total rows processed : {total}")

    print(f"\nSaving results to {output_path}...")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
        df.to_excel(writer, sheet_name="Justifications", index=False)

        ws = writer.sheets["Justifications"]
        # Column widths keyed by worksheet column letter; order matches the
        # result-dict keys written above.
        column_widths = {
            "A": 15,   # patient_id
            "B": 20,   # original_file_id
            "C": 10,   # chunk_id
            "D": 40,   # code
            "E": 30,   # theme
            "F": 25,   # group
            "G": 60,   # quote
            "H": 90,   # raw_evidence_snippet
            "I": 60,   # excel_context
            "J": 110,  # llm_justification
        }
        for letter, width in column_widths.items():
            ws.column_dimensions[letter].width = width

        from openpyxl.styles import Alignment
        # One shared style object; openpyxl copies styles on assignment.
        wrap_top = Alignment(wrap_text=True, vertical="top")
        for data_row in ws.iter_rows(min_row=2, max_row=ws.max_row):
            for cell in data_row:
                cell.alignment = wrap_top

    print(f"\n✓ Successfully created justification report!")
    print(f"  Total rows processed : {total}")
    print(f"  Output file          : {output_path}")


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()