from __future__ import annotations
import os, json, time
from typing import Any
import yaml
import pandas as pd
from tqdm import tqdm

from schemas import ConversationReport, BatchSummary, EntitiesByCategory, EntityItem, Evidence, PatientPerspective, EntityPresenceSentiment, BarrierFacilitatorPresence

from pdf_text import load_folder, group_conversations, merge_conversation_text
from llm_ollama import ollama_generate
from dedupe import dedupe_strings

# System prompt shared by every extraction call: grounding, quote-evidence,
# and strict-JSON output rules. The triple-quoted text below is sent to the
# LLM verbatim — it is runtime data, not documentation.
SYSTEM_RULES = """
You are a medical transcript analyzer extracting structured information from diabetes eye care interviews.

CRITICAL RULES:
1) Extract ONLY information explicitly stated in the transcript
2) Do NOT invent, infer, or assume anything
3) Include SHORT exact quotes (max 15 words) as evidence
4) Return ONLY valid JSON (no markdown, no preamble, no explanations)
5) If information is not present, use empty lists []
"""

# One comprehensive prompt extracting questions, entities, patient
# perspective, entity presence/sentiment, and barriers/facilitators in a
# SINGLE LLM call (previously two separate prompts — halves runtime).
# Placeholders ({rules}, {entity_list_json}, {barriers_facilitators_json},
# {barriers_facilitators_flat_list}, {bf_items_count}, {transcript}) are
# filled via str.format; literal braces in the JSON schema are escaped {{ }}.
# The string itself is runtime prompt text — edit with care.
COMPREHENSIVE_PROMPT_TEMPLATE = """
{rules}

YOU ARE A STRICT JSON EXTRACTION TOOL. YOU MUST FOLLOW THE EXACT SCHEMA PROVIDED.
DO NOT INVENT NEW CATEGORIES. DO NOT ADD EXTRA FIELDS. ONLY USE THE CATEGORIES LISTED BELOW.

=== ALLOWED MEDICAL ENTITY CATEGORIES (USE ONLY THESE 6) ===
1. Symptoms
2. Ophthalmic Findings
3. Diagnostic Tools
4. Systemic Risk Factors
5. Treatment Options
6. Demographics/History

PREDEFINED MEDICAL ENTITIES TO CHECK:
{entity_list_json}

=== ALLOWED BARRIER/FACILITATOR CATEGORIES (USE ONLY THESE 12) ===
1. Transportation_Barriers
2. Access_Barriers
3. Cost_Barriers
4. Belief_Barriers
5. Provider_System_Barriers
6. Education_Barriers
7. Family_Social_Barriers
8. Support_Facilitators
9. Provider_Facilitators
10. Access_Facilitators
11. Technology_Facilitators
12. Education_Facilitators

PREDEFINED BARRIERS/FACILITATORS TO CHECK:
{barriers_facilitators_json}

FLATTENED LIST (CHECK EVERY ITEM - APPROXIMATELY {bf_items_count} ITEMS):
{barriers_facilitators_flat_list}

=== STRICT EXTRACTION RULES ===
1. For "entities" array: ONLY use the 6 categories listed above
2. For "entity_presence_sentiment" array: MUST have EXACTLY 24 rows (one per predefined medical entity)
3. For "barriers_facilitators" array: MUST check ALL items from the flattened list (~{bf_items_count} rows)
4. DO NOT create new categories like "Diabetes", "Insurance", "Communication_Facilitators", etc.
5. If something doesn't fit a predefined category, DO NOT include it
6. Return ONLY the JSON schema shown below - no extra fields, no extra categories

=== REQUIRED JSON OUTPUT SCHEMA ===
{{
  "questions_asked": ["question template 1", "question template 2"],
  
  "entities": [
    {{"category": "Symptoms", "items": [{{"name": "entity", "evidence": [{{"quote": "...", "speaker": "S2"}}]}}]}},
    {{"category": "Ophthalmic Findings", "items": []}},
    {{"category": "Diagnostic Tools", "items": []}},
    {{"category": "Systemic Risk Factors", "items": []}},
    {{"category": "Treatment Options", "items": []}},
    {{"category": "Demographics/History", "items": []}}
  ],
  
  "patient_perspective": {{
    "occurrence": ["past year", "every morning"],
    "severity": ["vision getting worse", "can't read anymore"],
    "concerns": ["worried about going blind", "afraid of losing job"],
    "goals": ["want to keep driving", "see grandchildren clearly"],
    "occurrence_evidence": [{{"quote": "...", "speaker": "S2"}}],
    "severity_evidence": [{{"quote": "...", "speaker": "S2"}}],
    "concerns_evidence": [{{"quote": "...", "speaker": "S2"}}],
    "goals_evidence": [{{"quote": "...", "speaker": "S2"}}]
  }},
  
  "entity_presence_sentiment": [
    {{"category": "Symptoms", "entity": "blurred vision", "present": true, "sentiment": "negative", "evidence": [{{"quote": "...", "speaker": "S2"}}]}},
    {{"category": "Symptoms", "entity": "fluctuating vision", "present": false, "sentiment": "unknown", "evidence": []}},
    ... (continue for all 24 predefined medical entities - DO NOT skip any)
  ],
  
  "barriers_facilitators": [
    {{"category": "Transportation_Barriers", "item": "no transportation", "present": false, "sentiment": "unknown", "evidence": [], "impact_level": null}},
    {{"category": "Transportation_Barriers", "item": "too far to travel", "present": false, "sentiment": "unknown", "evidence": [], "impact_level": null}},
    ... (continue for ALL items from the flattened list - approximately {bf_items_count} rows)
  ]
}}

=== CRITICAL REMINDERS ===
- entity_presence_sentiment MUST have EXACTLY 24 rows
- barriers_facilitators MUST have approximately {bf_items_count} rows
- ONLY use the 6 medical entity categories listed above
- ONLY use the 12 barrier/facilitator categories listed above
- DO NOT invent categories like "Diabetes", "Insurance", "Communication_Facilitators"
- If you cannot categorize something, leave it out
- Evidence quotes: 5-15 words, direct from transcript
- Sentiment: positive, neutral, negative, mixed, or unknown
- Impact level (barriers/facilitators only): high, medium, low, or null

=== TRANSCRIPT TO ANALYZE ===
{transcript}

=== OUTPUT (JSON ONLY - NO MARKDOWN, NO EXTRA TEXT) ===
"""
def safe_json_load(s: str) -> Any:
    """Parse an LLM response into JSON, tolerating fences and surrounding noise.

    Args:
        s: Raw model output. May be None/empty, wrapped in ``` markdown
           fences, or contain preamble/trailing text around the JSON object.

    Returns:
        The parsed JSON value.

    Raises:
        ValueError: If the response is empty or no valid JSON can be extracted.
    """
    s = (s or "").strip()

    if not s:
        raise ValueError("LLM returned EMPTY response")

    # Strip markdown code fences (``` or ```json) if present.
    if s.startswith("```"):
        lines = s.split("\n")
        s = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
        s = s.replace("```json", "").replace("```", "").strip()

    # Try direct parse first.
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        # Fall back to the outermost {...} span — models often add prose
        # around an otherwise valid JSON object.
        start = s.find("{")
        end = s.rfind("}")
        if start != -1 and end > start:
            try:
                return json.loads(s[start:end + 1])
            except json.JSONDecodeError:  # narrowed from bare `except:`
                pass

        # Still failing: surface a helpful, truncated preview.
        print(f"❌ JSON parsing failed: {e}")
        print(f"Response preview (first 500 chars): {s[:500]}")
        raise ValueError(f"Could not parse LLM response as JSON: {e}") from e

def normalize_entities(raw_entities) -> list[EntitiesByCategory]:
    """Coerce raw LLM entity output into the six fixed EntitiesByCategory buckets.

    Unknown or None categories are dropped with a warning; every allowed
    category always appears in the result, possibly with an empty item list.
    """
    allowed = (
        "Symptoms",
        "Ophthalmic Findings",
        "Diagnostic Tools",
        "Systemic Risk Factors",
        "Treatment Options",
        "Demographics/History",
    )

    bucket = {name: [] for name in allowed}

    for entry in (raw_entities if isinstance(raw_entities, list) else []):
        if not isinstance(entry, dict):
            continue

        cat_name = entry.get("category")
        if cat_name not in bucket:
            # Model invented a category (or returned None) — skip it.
            print(f"  ⚠️  Skipping invalid entity category: {cat_name}")
            continue

        collected = []
        for item in (entry.get("items") or []):
            if not isinstance(item, dict):
                continue

            quotes = [
                Evidence(quote=(ev.get("quote", "")).strip(), speaker=ev.get("speaker"))
                for ev in (item.get("evidence") or [])
                if isinstance(ev, dict)
            ]
            collected.append(EntityItem(name=(item.get("name", "")).strip(), evidence=quotes))

        bucket[cat_name] = collected

    return [EntitiesByCategory(category=name, items=bucket[name]) for name in allowed]


def normalize_patient_perspective(raw_pp: dict) -> PatientPerspective:
    """Build a PatientPerspective model from the raw LLM dict.

    Missing keys default to empty lists. Fixes vs. original: evidence
    entries that are not dicts, and None "quote" values, are now skipped
    instead of raising AttributeError — consistent with the other
    normalize_* helpers in this module.

    Args:
        raw_pp: The "patient_perspective" sub-dict of the parsed response.

    Returns:
        A fully populated PatientPerspective (empty lists where data absent).
    """
    if not isinstance(raw_pp, dict):
        raw_pp = {}

    def ev_list(key: str):
        # Collect well-formed Evidence entries; silently drop malformed ones.
        evs = []
        for e in (raw_pp.get(key, []) or []):
            if not isinstance(e, dict):
                continue
            evs.append(Evidence(quote=(e.get("quote", "") or "").strip(), speaker=e.get("speaker")))
        return evs

    return PatientPerspective(
        occurrence=raw_pp.get("occurrence", []) or [],
        severity=raw_pp.get("severity", []) or [],
        concerns=raw_pp.get("concerns", []) or [],
        goals=raw_pp.get("goals", []) or [],
        occurrence_evidence=ev_list("occurrence_evidence"),
        severity_evidence=ev_list("severity_evidence"),
        concerns_evidence=ev_list("concerns_evidence"),
        goals_evidence=ev_list("goals_evidence"),
    )

def config_entity_list_json(cfg: dict) -> str:
    """Serialize the config's predefined medical entities as pretty-printed JSON."""
    predefined = cfg.get("predefined_entities") or {}
    return json.dumps(predefined, ensure_ascii=False, indent=2)

def config_barriers_facilitators_json(cfg: dict) -> str:
    """Serialize the config's barriers/facilitators mapping as pretty JSON.

    Fix: dropped the redundant function-local ``import json`` — the module
    already imports json at the top of the file.
    """
    bf = cfg.get("barriers_facilitators", {}) or {}
    return json.dumps(bf, ensure_ascii=False, indent=2)

def flatten_barriers_facilitators_list(cfg: dict) -> str:
    """Flatten the barriers/facilitators config into a JSON array of rows.

    Each row is {"category": ..., "item": ...}, one per item, so the prompt
    can enumerate every item the model must check.

    Fixes: dropped the redundant function-local ``import json`` (already
    imported at module level) and replaced the append loop with a
    comprehension.
    """
    bf_dict = cfg.get("barriers_facilitators", {}) or {}

    all_items = [
        {"category": category, "item": item}
        for category, items in bf_dict.items()
        for item in items
    ]

    return json.dumps(all_items, ensure_ascii=False, indent=2)

def count_barriers_facilitators_items(cfg: dict) -> int:
    """Return the total number of barrier/facilitator items across all categories."""
    bf_dict = cfg.get("barriers_facilitators", {}) or {}
    return sum(map(len, bf_dict.values()))

def normalize_presence_rows(rows) -> list[EntityPresenceSentiment]:
    """Validate raw presence/sentiment rows into EntityPresenceSentiment models.

    Rows with categories outside the six allowed ones are dropped with a
    warning; duplicates on (category, entity) are removed case-insensitively;
    malformed evidence entries are ignored.
    """
    if not isinstance(rows, list):
        return []

    valid_categories = {
        "Symptoms", "Ophthalmic Findings", "Diagnostic Tools",
        "Systemic Risk Factors", "Treatment Options", "Demographics/History"
    }

    normalized = []
    visited = set()

    for raw in rows:
        if not isinstance(raw, dict):
            continue

        cat = raw.get("category")
        ent_lower = (raw.get("entity", "").strip()).lower()

        # Drop invented/None categories rather than polluting the output.
        if cat not in valid_categories:
            print(f"  ⚠️  Skipping invalid entity category: {cat}")
            continue

        dedupe_key = (cat, ent_lower)
        if dedupe_key in visited:
            continue
        visited.add(dedupe_key)

        # Keep only well-formed evidence dicts.
        quotes = []
        raw_evidence = raw.get("evidence") or []
        if isinstance(raw_evidence, list):
            quotes = [
                Evidence(quote=(e.get("quote", "").strip()), speaker=e.get("speaker"))
                for e in raw_evidence
                if isinstance(e, dict)
            ]

        try:
            normalized.append(EntityPresenceSentiment(
                category=cat,
                entity=(raw.get("entity", "").strip()),
                present=bool(raw.get("present", False)),
                sentiment=raw.get("sentiment", "unknown"),
                evidence=quotes,
            ))
        except Exception as err:
            print(f"  ⚠️  Failed to create EntityPresenceSentiment: {err}")
            continue

    return normalized


def ensure_complete_entity_presence(presence_rows, predefined_entities_dict):
    """Guarantee a presence row exists for every predefined medical entity.

    Any of the 24 hard-coded entities missing from ``presence_rows`` is
    appended with present=False / sentiment='unknown', so downstream
    tabulation always sees a complete grid.

    Fix: removed the unused ``from schemas import Category`` import.

    NOTE(review): ``predefined_entities_dict`` is accepted but never
    consulted — the 24 entities are hard-coded below. TODO: confirm the
    config's predefined_entities matches this map, or derive the map from it.

    Args:
        presence_rows: List of EntityPresenceSentiment rows (mutated in place).
        predefined_entities_dict: Config mapping (currently unused, see NOTE).

    Returns:
        The same list with any missing entities appended.
    """
    # Canonical 24 predefined entities across the 6 allowed categories.
    category_map = {
        "Symptoms": ["blurred vision", "fluctuating vision", "floaters", "poor night vision", "faded colors"],
        "Ophthalmic Findings": ["microaneurysms", "cotton wool spots", "retinal hemorrhages", "macular edema", "neovascularization", "scar tissue"],
        "Diagnostic Tools": ["dilated eye exam", "OCT", "fundus photography", "fluorescein angiography"],
        "Systemic Risk Factors": ["HbA1c", "blood pressure", "cholesterol"],
        "Treatment Options": ["intravitreal injection", "laser treatment", "vitrectomy"],
        "Demographics/History": ["duration of diabetes", "age over 65", "smoking"]
    }

    required_entities = [
        (category, entity.lower())
        for category, entities in category_map.items()
        for entity in entities
    ]

    # Case-insensitive index of what the model already reported.
    existing = {(row.category, row.entity.lower()) for row in presence_rows}

    # Append absent entities as explicit "not present" rows.
    for cat, ent in required_entities:
        if (cat, ent) not in existing:
            presence_rows.append(EntityPresenceSentiment(
                category=cat,
                entity=ent,
                present=False,
                sentiment='unknown',
                evidence=[]
            ))

    return presence_rows


def normalize_barriers_facilitators(rows) -> list[BarrierFacilitatorPresence]:
    """Validate raw barrier/facilitator rows into BarrierFacilitatorPresence models.

    Rows with categories outside the twelve allowed ones are dropped with a
    warning; duplicates on (category, item) are removed case-insensitively;
    malformed evidence entries are ignored.
    """
    if not isinstance(rows, list):
        return []

    valid_categories = {
        "Transportation_Barriers", "Access_Barriers", "Cost_Barriers", 
        "Belief_Barriers", "Provider_System_Barriers", "Education_Barriers",
        "Family_Social_Barriers", "Support_Facilitators", "Provider_Facilitators",
        "Access_Facilitators", "Technology_Facilitators", "Education_Facilitators"
    }

    results = []
    visited = set()

    for raw in rows:
        if not isinstance(raw, dict):
            continue

        cat = raw.get("category")
        item_lower = (raw.get("item", "").strip()).lower()

        # Drop invented/None categories rather than polluting the output.
        if cat not in valid_categories:
            print(f"  ⚠️  Skipping invalid barrier/facilitator category: {cat}")
            continue

        dedupe_key = (cat, item_lower)
        if dedupe_key in visited:
            continue
        visited.add(dedupe_key)

        # Keep only well-formed evidence dicts.
        quotes = []
        raw_evidence = raw.get("evidence") or []
        if isinstance(raw_evidence, list):
            quotes = [
                Evidence(quote=(e.get("quote", "").strip()), speaker=e.get("speaker"))
                for e in raw_evidence
                if isinstance(e, dict)
            ]

        try:
            results.append(BarrierFacilitatorPresence(
                category=cat,
                item=(raw.get("item", "").strip()),
                present=bool(raw.get("present", False)),
                sentiment=raw.get("sentiment", "unknown"),
                evidence=quotes,
                impact_level=raw.get("impact_level"),
            ))
        except Exception as err:
            print(f"  ⚠️  Failed to create BarrierFacilitatorPresence: {err}")
            continue

    return results
    
def estimate_token_count(text: str) -> int:
    """Estimate token count of *text* using the ~4 characters/token heuristic."""
    char_count = len(text)
    return char_count // 4

def chunk_transcript_if_needed(text: str, max_tokens: int = 10000) -> list[str]:
    """Split *text* into chunks of at most ~max_tokens estimated tokens.

    Returns ``[text]`` unchanged when it already fits. Splitting happens on
    line boundaries only, so ``"\\n".join(chunks)`` reconstructs the input.
    """
    if estimate_token_count(text) <= max_tokens:
        return [text]

    chunks = []
    pending = []
    pending_tokens = 0

    for line in text.split("\n"):
        cost = estimate_token_count(line)

        if pending and pending_tokens + cost > max_tokens:
            # Current chunk is full — flush it and start a new one.
            chunks.append("\n".join(pending))
            pending = [line]
            pending_tokens = cost
        else:
            pending.append(line)
            pending_tokens += cost

    # Flush the final partial chunk, if any.
    if pending:
        chunks.append("\n".join(pending))

    return chunks

def merge_chunked_results(chunk_results: list[dict]) -> dict:
    """Merge per-chunk LLM extraction results into one result dict.

    Questions and patient-perspective lists are concatenated; entity items
    are pooled per category; presence/sentiment and barrier rows are unioned
    by key, with evidence combined, ``present`` OR-ed, and the last
    non-"unknown" sentiment winning.

    Fixes vs. original:
    - "barriers_facilitators" is initialized up front, so the key exists
      even when ``chunk_results`` is empty.
    - Evidence/items merging uses setdefault, avoiding KeyError when an
      earlier chunk's row lacked the "evidence"/"items" key.
    - Null lists from a chunk no longer raise TypeError on extend.

    Args:
        chunk_results: Parsed JSON dicts, one per transcript chunk.

    Returns:
        A single dict in the same schema as an unchunked result.
    """
    merged = {
        "questions_asked": [],
        "entities": [],
        "patient_perspective": {
            "occurrence": [], "severity": [], "concerns": [], "goals": [],
            "occurrence_evidence": [], "severity_evidence": [],
            "concerns_evidence": [], "goals_evidence": []
        },
        "entity_presence_sentiment": [],
        "barriers_facilitators": [],
    }

    for result in chunk_results:
        if not isinstance(result, dict):
            continue

        merged["questions_asked"].extend(result.get("questions_asked", []) or [])
        _merge_entity_categories(merged["entities"], result.get("entities", []) or [])
        _merge_perspective(merged["patient_perspective"], result.get("patient_perspective", {}) or {})
        _merge_keyed_rows(
            merged["entity_presence_sentiment"],
            result.get("entity_presence_sentiment", []) or [],
            key_field="entity",
        )
        _merge_keyed_rows(
            merged["barriers_facilitators"],
            result.get("barriers_facilitators", []) or [],
            key_field="item",
            merge_impact=True,
        )

    return merged


def _merge_entity_categories(target, categories):
    """Pool entity items into *target*, combining categories with the same name."""
    for cat in categories:
        if not isinstance(cat, dict):
            continue
        name = cat.get("category")
        existing = next((c for c in target if c.get("category") == name), None)
        if existing is not None:
            existing.setdefault("items", []).extend(cat.get("items", []) or [])
        else:
            target.append(cat)


def _merge_perspective(target, pp):
    """Concatenate patient-perspective value lists and their evidence lists."""
    if not isinstance(pp, dict):
        return
    for key in ("occurrence", "severity", "concerns", "goals",
                "occurrence_evidence", "severity_evidence",
                "concerns_evidence", "goals_evidence"):
        values = pp.get(key, [])
        if isinstance(values, list):
            target[key].extend(values)


def _merge_keyed_rows(target, rows, key_field, merge_impact=False):
    """Union *rows* into *target* by *key_field*.

    Combines evidence, ORs ``present``, keeps the latest non-"unknown"
    sentiment, and (when merge_impact) the latest non-null impact level.
    """
    for row in rows:
        if not isinstance(row, dict):
            continue

        key = row.get(key_field)
        existing = next((r for r in target if r.get(key_field) == key), None)
        if existing is None:
            target.append(row)
            continue

        evidence = row.get("evidence", [])
        if isinstance(evidence, list):
            existing.setdefault("evidence", []).extend(evidence)
        if row.get("present"):
            existing["present"] = True
        if row.get("sentiment") != "unknown":
            existing["sentiment"] = row.get("sentiment")
        if merge_impact and row.get("impact_level") and row.get("impact_level") != "null":
            existing["impact_level"] = row.get("impact_level")
def clean_patient_perspective(pp: PatientPerspective) -> PatientPerspective:
    """Strip prompt-placeholder strings the model sometimes echoes back.

    Keeps only entries that are not known placeholder phrases and are longer
    than five characters; evidence lists pass through untouched.
    """
    placeholders = {
        "when/how often", "how bad/impact", "worries/fears", 
        "what patient wants", "what patient hopes to achieve",
        "onset/timing/frequency/triggers", "severity + impact on life"
    }

    def keep(entry):
        return entry.lower() not in placeholders and len(entry) > 5

    return PatientPerspective(
        occurrence=list(filter(keep, pp.occurrence)),
        severity=list(filter(keep, pp.severity)),
        concerns=list(filter(keep, pp.concerns)),
        goals=list(filter(keep, pp.goals)),
        occurrence_evidence=pp.occurrence_evidence,
        severity_evidence=pp.severity_evidence,
        concerns_evidence=pp.concerns_evidence,
        goals_evidence=pp.goals_evidence
    )


def main(config_path: str = "config.yaml", only_ids: list[str] = None):
    """✅ IMPROVED: Better progress tracking and error handling"""
    
    # Load config
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    input_dir = cfg.get("input_dir")
    out_dir = cfg.get("output_dir", "outputs")
    os.makedirs(out_dir, exist_ok=True)

    model = cfg.get("model", "llama3.1:latest")
    host = cfg.get("ollama_host", "http://localhost:11434")
    num_ctx = int(cfg.get("num_ctx", 12288))  # ✅ FIXED: Lower default
    temperature = float(cfg.get("temperature", 0.1))  # ✅ FIXED: Lower default
    timeout = int(cfg.get("conversation_timeout", 1800))  # ✅ NEW: Configurable timeout
    max_transcript_tokens = int(cfg.get("max_transcript_tokens", 10000))  # ✅ NEW

    docs = load_folder(input_dir)
    groups = group_conversations(docs)

    print(f"\n{'='*60}")
    print(f"📊 EXTRACTION PIPELINE STARTING")
    print(f"{'='*60}")
    print(f"Model: {model}")
    print(f"Context window: {num_ctx} tokens")
    print(f"Temperature: {temperature}")
    print(f"Timeout per conversation: {timeout}s ({timeout//60} minutes)")
    print(f"Total conversations found: {len(groups)}")
    print(f"{'='*60}\n")

    # ✅ REBUILD MODE: if json files already exist, rebuild CSVs without calling LLM
    if cfg.get("rebuild_only", False):
        print("🔄 REBUILD MODE: Loading existing JSONs and regenerating CSVs...\n")
        conversation_reports = []
        for fn in os.listdir(out_dir):
            if fn.endswith(".json") and fn not in ("batch_summary.json",):
                with open(os.path.join(out_dir, fn), "r", encoding="utf-8") as f:
                    conversation_reports.append(ConversationReport.model_validate_json(f.read()))

        all_questions = []
        for r in conversation_reports:
            all_questions.extend(r.questions_asked)

        global_questions = dedupe_strings(all_questions, threshold=95)
        summary = BatchSummary(
            total_unique_conversations=len(conversation_reports),
            conversation_ids=[r.conversation_id for r in conversation_reports],
            global_distinct_questions=global_questions,
        )
        with open(os.path.join(out_dir, "batch_summary.json"), "w", encoding="utf-8") as f:
            f.write(summary.model_dump_json(indent=2))

        # patient_perspective.csv
        rows = []
        for r in conversation_reports:
            rows.append({
                "conversation_id": r.conversation_id,
                "source_files": "; ".join(r.source_files),
                "occurrence": " | ".join(r.patient_perspective.occurrence),
                "severity": " | ".join(r.patient_perspective.severity),
                "concerns": " | ".join(r.patient_perspective.concerns),
                "goals": " | ".join(r.patient_perspective.goals),
            })
        pd.DataFrame(rows).to_csv(os.path.join(out_dir, "patient_perspective.csv"), index=False)

        # entities.csv
        e_rows = []
        for r in conversation_reports:
            for cat in r.entities:
                for item in cat.items:
                    quotes = " || ".join([e.quote for e in item.evidence][:3])
                    e_rows.append({
                        "conversation_id": r.conversation_id,
                        "category": cat.category,
                        "entity": item.name,
                        "evidence_quotes": quotes,
                    })
        pd.DataFrame(e_rows).to_csv(os.path.join(out_dir, "entities.csv"), index=False)

        # entity_presence_sentiment.csv (MAIN TABULATION)
        ps_rows = []
        for r in conversation_reports:
            for row in (r.entity_presence_sentiment or []):
                quotes = " || ".join([e.quote for e in row.evidence][:2])
                ps_rows.append({
                    "conversation_id": r.conversation_id,
                    "category": row.category,
                    "entity": row.entity,
                    "present": row.present,
                    "sentiment": row.sentiment,
                    "evidence_quotes": quotes,
                })
        pd.DataFrame(ps_rows).to_csv(os.path.join(out_dir, "entity_presence_sentiment.csv"), index=False)

        # global_questions.csv
        pd.DataFrame({"question": global_questions}).to_csv(os.path.join(out_dir, "global_questions.csv"), index=False)

        print("\n✅ REBUILD DONE (no LLM calls)")
        print(f"- Total conversations: {summary.total_unique_conversations}")
        print(f"- Outputs saved to: {out_dir}")
        return


    # ✅ MAIN PROCESSING LOOP with better progress tracking
    conversation_reports: list[ConversationReport] = []
    all_questions: list[str] = []
    
    entity_list_json = config_entity_list_json(cfg)

    for idx, (cid, parts) in enumerate(groups, 1):
        # Filter by only_ids if provided
        if only_ids and cid not in only_ids:
            continue

        # Skip if already processed
        out_path = os.path.join(out_dir, f"{cid}.json")
        if os.path.exists(out_path):
            print(f"[{idx}/{len(groups)}] ⏭️  Skipping {cid} (already processed)")
            # Load existing report for summary
            with open(out_path, "r", encoding="utf-8") as f:
                existing = ConversationReport.model_validate_json(f.read())
                conversation_reports.append(existing)
                all_questions.extend(existing.questions_asked)
            continue

        print(f"\n{'─'*60}")
        print(f"[{idx}/{len(groups)}] 📄 Processing: {cid}")
        print(f"{'─'*60}")
        print(f"Source files: {', '.join([p.filename for p in parts])}")
        
        start_time = time.time()
        merged = merge_conversation_text(parts)
        
        # ✅ NEW: Check if transcript needs chunking
        tokens = estimate_token_count(merged)
        print(f"Transcript size: ~{tokens:,} tokens")
        
        chunks = chunk_transcript_if_needed(merged, max_transcript_tokens)
        if len(chunks) > 1:
            print(f"⚠️  Long transcript! Split into {len(chunks)} chunks for processing")

        # ✅ MAJOR CHANGE: SINGLE LLM CALL (or multiple for chunks)
        try:
            if len(chunks) == 1:
                # Single call for normal transcripts
                entity_list_json = config_entity_list_json(cfg)
                barriers_facilitators_json = config_barriers_facilitators_json(cfg)
                barriers_facilitators_flat_list = flatten_barriers_facilitators_list(cfg)  # ✅ ADDED
                bf_items_count = count_barriers_facilitators_items(cfg)  # ✅ ADDED
                
                prompt = COMPREHENSIVE_PROMPT_TEMPLATE.format(
                    rules=SYSTEM_RULES,
                    entity_list_json=entity_list_json,
                    barriers_facilitators_json=barriers_facilitators_json,
                    barriers_facilitators_flat_list=barriers_facilitators_flat_list,  # ✅ ADDED
                    bf_items_count=bf_items_count,  # ✅ ADDED
                    transcript=merged
                )
                
                resp = ollama_generate(
                    model=model, 
                    prompt=prompt, 
                    host=host, 
                    num_ctx=num_ctx, 
                    temperature=temperature,
                    timeout=timeout
                )
                
                # Save raw response
                with open(os.path.join(out_dir, f"{cid}_RAW.txt"), "w", encoding="utf-8") as f:
                    f.write(resp or "")
                
                data = safe_json_load(resp)
                
            else:
                # Multiple calls for chunked transcripts
                chunk_results = []
                for chunk_idx, chunk in enumerate(chunks, 1):
                    print(f"  Processing chunk {chunk_idx}/{len(chunks)}...")
                    
                    entity_list_json = config_entity_list_json(cfg)
                    barriers_facilitators_json = config_barriers_facilitators_json(cfg)
                    barriers_facilitators_flat_list = flatten_barriers_facilitators_list(cfg)  # ✅ ADDED
                    bf_items_count = count_barriers_facilitators_items(cfg)  # ✅ ADDED
                    
                    prompt = COMPREHENSIVE_PROMPT_TEMPLATE.format(
                        rules=SYSTEM_RULES,
                        entity_list_json=entity_list_json,
                        barriers_facilitators_json=barriers_facilitators_json,
                        barriers_facilitators_flat_list=barriers_facilitators_flat_list,  # ✅ ADDED
                        bf_items_count=bf_items_count,  # ✅ ADDED
                        transcript=chunk
                    )
                    
                    resp = ollama_generate(
                        model=model, 
                        prompt=prompt, 
                        host=host, 
                        num_ctx=num_ctx, 
                        temperature=temperature,
                        timeout=timeout
                    )
                    
                    chunk_data = safe_json_load(resp)
                    chunk_results.append(chunk_data)
                
                # Merge chunk results
                data = merge_chunked_results(chunk_results)
                
                # Save merged raw response
                with open(os.path.join(out_dir, f"{cid}_RAW.txt"), "w", encoding="utf-8") as f:
                    f.write(json.dumps(data, indent=2))

            # --- Parse and normalize each section of the LLM response. ---
            # Each section gets its own try/except so one malformed field
            # degrades to an empty default instead of dropping the whole
            # conversation.
            try:
                # Collapse near-duplicate questions (>=95% similarity).
                questions = dedupe_strings(data.get("questions_asked", []) or [], threshold=95)
            except Exception as e:
                print(f"  ⚠️  Failed to process questions: {e}")
                questions = []
            
            try:
                entities = normalize_entities(data.get("entities", []) or [])
            except Exception as e:
                print(f"  ⚠️  Failed to process entities: {e}")
                entities = []
            
            try:
                pp = normalize_patient_perspective(data.get("patient_perspective", {}) or {})
                pp = clean_patient_perspective(pp)
            except Exception as e:
                print(f"  ⚠️  Failed to process patient perspective: {e}")
                # Fall back to an all-defaults perspective model.
                pp = PatientPerspective()
            
            try:
                # Fill in any predefined entities the model skipped so the
                # presence table always covers the full configured list.
                presence_rows = ensure_complete_entity_presence(presence_rows, cfg.get("predefined_entities", {}))
            except Exception as e:
                print(f"  ⚠️  Failed to process entity presence: {e}")
                presence_rows = []
            
            try:
                bf_rows = normalize_barriers_facilitators(data.get("barriers_facilitators", []))
            except Exception as e:
                print(f"  ⚠️  Failed to process barriers/facilitators: {e}")
                bf_rows = []

            # Assemble the validated per-conversation report (pydantic model).
            report = ConversationReport(
                conversation_id=cid,
                source_files=[p.filename for p in parts],
                questions_asked=questions,
                entities=entities,
                patient_perspective=pp,
                entity_presence_sentiment=presence_rows,
                barriers_facilitators=bf_rows,
            )

            conversation_reports.append(report)
            all_questions.extend(questions)

            # Save per-conversation JSON
            with open(os.path.join(out_dir, f"{cid}.json"), "w", encoding="utf-8") as f:
                f.write(report.model_dump_json(indent=2))

            # Per-conversation progress summary on stdout.
            elapsed = time.time() - start_time
            print(f"✅ Completed in {elapsed:.1f}s ({elapsed/60:.1f} minutes)")
            print(f"   - Questions found: {len(questions)}")
            print(f"   - Entities found: {sum(len(cat.items) for cat in entities)}")
            print(f"   - Predefined entities checked: {len(presence_rows)}")
            print(f"   - Barriers/Facilitators checked: {len(bf_rows)}")

        except Exception as e:
            # Any unhandled failure skips only this conversation; the batch
            # keeps going so one bad transcript doesn't abort the whole run.
            print(f"❌ FAILED for conversation {cid}: {e}")
            print(f"   Continuing to next conversation...")
            continue

    # ✅ Generate batch summary
    print(f"\n{'='*60}")
    print(f"📊 GENERATING SUMMARY AND CSV OUTPUTS")
    print(f"{'='*60}\n")

    # Global question dedupe uses a looser threshold (90) than the
    # per-conversation pass (95) to merge cross-conversation near-duplicates.
    global_questions = dedupe_strings(all_questions, threshold=90)
    summary = BatchSummary(
        total_unique_conversations=len(conversation_reports),
        conversation_ids=[r.conversation_id for r in conversation_reports],
        global_distinct_questions=global_questions,
    )
    with open(os.path.join(out_dir, "batch_summary.json"), "w", encoding="utf-8") as f:
        f.write(summary.model_dump_json(indent=2))

    # patient_perspective.csv — one row per conversation; list fields are
    # flattened with " | " so each cell stays a single string.
    rows = []
    for r in conversation_reports:
        rows.append({
            "conversation_id": r.conversation_id,
            "source_files": "; ".join(r.source_files),
            "occurrence": " | ".join(r.patient_perspective.occurrence),
            "severity": " | ".join(r.patient_perspective.severity),
            "concerns": " | ".join(r.patient_perspective.concerns),
            "goals": " | ".join(r.patient_perspective.goals),
        })
    pd.DataFrame(rows).to_csv(os.path.join(out_dir, "patient_perspective.csv"), index=False)

    # entities.csv — one row per (conversation, category, entity);
    # evidence is capped at three quotes joined with " || ".
    e_rows = []
    for r in conversation_reports:
        for cat in r.entities:
            for item in cat.items:
                quotes = " || ".join([e.quote for e in item.evidence][:3])
                e_rows.append({
                    "conversation_id": r.conversation_id,
                    "category": cat.category,
                    "entity": item.name,
                    "evidence_quotes": quotes,
                })
    pd.DataFrame(e_rows).to_csv(os.path.join(out_dir, "entities.csv"), index=False)

    # ✅ MAIN TABULATION: entity_presence_sentiment.csv
    # One row per (conversation, predefined entity) with presence, sentiment,
    # and up to two supporting quotes joined with " || ".
    # FIX: this table was previously built and written twice in a row by two
    # identical blocks targeting the same CSV path; it is now emitted once.
    ps_rows = []
    for r in conversation_reports:
        for row in (r.entity_presence_sentiment or []):
            quotes = " || ".join([e.quote for e in row.evidence][:2])
            ps_rows.append({
                "conversation_id": r.conversation_id,
                "category": row.category,
                "entity": row.entity,
                "present": row.present,
                "sentiment": row.sentiment,
                "evidence_quotes": quotes,
            })
    pd.DataFrame(ps_rows).to_csv(os.path.join(out_dir, "entity_presence_sentiment.csv"), index=False)

    # ✅ NEW: barriers_facilitators.csv
    # One row per (conversation, barrier/facilitator item); getattr with a
    # default tolerates older reports that lack the attribute entirely.
    bf_rows = []
    for r in conversation_reports:
        for row in (getattr(r, "barriers_facilitators", []) or []):
            quotes = " || ".join([e.quote for e in row.evidence][:2])
            bf_rows.append({
                "conversation_id": r.conversation_id,
                "category": row.category,
                "item": row.item,
                "present": row.present,
                "sentiment": row.sentiment,
                # impact_level is optional on the model; blank when unset.
                "impact_level": row.impact_level or "",
                "evidence_quotes": quotes,
            })
    pd.DataFrame(bf_rows).to_csv(os.path.join(out_dir, "barriers_facilitators.csv"), index=False)

    # global_questions.csv — one deduplicated question per row.
    # FIX: this file was previously written twice back-to-back; once suffices.
    pd.DataFrame({"question": global_questions}).to_csv(os.path.join(out_dir, "global_questions.csv"), index=False)

    # Final console summary: where the outputs landed and what was produced.
    print("\n" + "="*60)
    print("✅ ALL PROCESSING COMPLETE!")
    print("="*60)
    print(f"📁 Output directory: {out_dir}")
    print(f"📊 Total conversations processed: {summary.total_unique_conversations}")
    print(f"❓ Total unique questions: {len(global_questions)}")
    print(f"\n📄 Generated files:")
    print(f"   1. batch_summary.json - Overall statistics")
    print(f"   2. patient_perspective.csv - Patient descriptions")
    print(f"   3. entities.csv - All extracted entities")
    print(f"   4. entity_presence_sentiment.csv - Medical entities tabulation")
    print(f"   5. barriers_facilitators.csv - ADHERENCE ANALYSIS (PM requirement)")
    print(f"   6. global_questions.csv - Distinct questions asked")
    print(f"   7. [ID].json files - Individual conversation reports")
    print("="*60 + "\n")

# Script entry point: run the full batch pipeline when executed directly.
if __name__ == "__main__":
    main()


# To test a single sample conversation, replace the entry point above with:
# if __name__ == "__main__":
#     main(only_ids=["BN1103"])