"""
Barriers & Facilitators Heatmap Generator (IMPROVED VERSION)
Creates a heatmap with:
1. Sample counts in each cell (n=X)
2. Margin totals for rows and columns
"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def create_barriers_facilitators_heatmap(csv_path, output_path='barriers_facilitators_heatmap.png'):
    """
    Create a heatmap showing percentage and count of patients mentioning each barrier/facilitator
    WITH margin totals for rows and columns.

    The input CSV is read with the columns 'conversation_id', 'category',
    'item' and 'present'. Besides the image, two CSV files holding the pivoted
    counts and percentages are written next to it, named
    '<image name without extension>_counts.csv' and '..._percentages.csv'.

    Args:
        csv_path: Path to barriers_facilitators.csv
        output_path: Where to save the heatmap image

    Returns:
        Tuple (pivot_counts, pivot_percentages): DataFrames indexed by main
        category with one column per item, each including a 'TOTAL' row and a
        'TOTAL' column.

    Raises:
        ValueError: If no row of the CSV is marked present with a category
            and an item, i.e. there is nothing to plot.
    """
    import os  # local import: the file's top-level imports do not include os

    # Load data
    df = pd.read_csv(csv_path)
    total_patients = df['conversation_id'].nunique()

    # Count occurrences and calculate percentages (raises if nothing is present)
    item_counts = _count_patient_mentions(df, total_patients)

    # Select top 15 most frequent items
    top_items = item_counts.nlargest(15, 'count')

    print(f"\n{'='*80}")
    print("HEATMAP GENERATION (WITH MARGIN TOTALS)")
    print(f"{'='*80}")
    print(f"Total patients analyzed: {total_patients}")
    print(f"Total items with mentions: {len(item_counts)}")
    print("Selected top 15 items for visualization")
    print("\nTop 15 items selected:")
    for main_category, item, count, percentage in zip(
        top_items['main_category'],
        top_items['item'],
        top_items['count'],
        top_items['percentage'],
    ):
        print(f"  {main_category:12} | {item:30} | {count:2} patients ({percentage:4.1f}%)")

    pivot_counts, pivot_percentages = _build_pivots_with_margins(top_items)

    _render_heatmap(pivot_counts, pivot_percentages, total_patients, output_path)
    print(f"\n✅ Heatmap saved to: {output_path}")
    print(f"{'='*80}\n")

    # Also save the data tables. The names are derived from the image path
    # minus its extension: a plain str.replace('.png', ...) would return the
    # image path unchanged for e.g. '.pdf' output and overwrite the image.
    base_path, _ = os.path.splitext(os.fspath(output_path))
    csv_output_counts = f"{base_path}_counts.csv"
    csv_output_percentages = f"{base_path}_percentages.csv"

    pivot_counts.to_csv(csv_output_counts)
    pivot_percentages.to_csv(csv_output_percentages)

    print(f"✅ Count data saved to: {csv_output_counts}")
    print(f"✅ Percentage data saved to: {csv_output_percentages}")

    return pivot_counts, pivot_percentages


def _count_patient_mentions(df, total_patients):
    """Return one row per (category, item) with patient count, percentage and main group."""
    # Keep only rows flagged present; rows without a category or item cannot
    # be grouped, so they are discarded up front.
    present_df = df[df['present'] == True].dropna(subset=['category', 'item'])
    if present_df.empty:
        raise ValueError(
            "No rows with present == True (and a category/item) were found; "
            "nothing to plot."
        )

    # Count distinct patients rather than rows, so a duplicated row for the
    # same patient and item cannot inflate the count or push the percentage
    # past 100 (the denominator is the number of distinct patients).
    item_counts = (
        present_df.groupby(['category', 'item'])['conversation_id']
        .nunique()
        .reset_index(name='count')
    )
    item_counts['percentage'] = (item_counts['count'] / total_patients * 100).round(1)

    # Map categories to main groups (as requested by PM)
    category_mapping = {
        'Transportation_Barriers': 'SDOH',
        'Access_Barriers': 'SDOH',
        'Family_Social_Barriers': 'SDOH',
        'Cost_Barriers': 'Cost',
        'Belief_Barriers': 'Beliefs',
        'Provider_System_Barriers': 'Systems',
        'Education_Barriers': 'Systems',
        'Support_Facilitators': 'Facilitators',
        'Provider_Facilitators': 'Facilitators',
        'Access_Facilitators': 'Facilitators',
        'Technology_Facilitators': 'Facilitators',
        'Education_Facilitators': 'Facilitators',
    }
    # Categories missing from the mapping are grouped under 'Other' instead of
    # becoming NaN, which pivot_table would silently drop from the heatmap.
    item_counts['main_category'] = item_counts['category'].map(category_mapping).fillna('Other')
    return item_counts


def _build_pivots_with_margins(top_items):
    """Pivot counts and percentages to main_category x item and append TOTAL margins."""
    # Reorder main categories (Barriers first, then Facilitators)
    category_order = ['SDOH', 'Beliefs', 'Systems', 'Cost', 'Facilitators']

    pivots = []
    for value_column in ('count', 'percentage'):
        # aggfunc='sum': the same item name can occur under two raw categories
        # of one main group; the default (mean) would average those counts.
        pivot = top_items.pivot_table(
            values=value_column,
            index='main_category',
            columns='item',
            aggfunc='sum',
            fill_value=0,
        )

        # Known groups first in the requested order, any other group after
        # them, so that no row is lost by the reindex.
        ordered = [cat for cat in category_order if cat in pivot.index]
        ordered += [cat for cat in pivot.index if cat not in category_order]
        pivot = pivot.reindex(ordered)

        # Column totals (sum across all categories for each item), then row
        # totals (sum across all items for each category). For percentages
        # this sums percentages (not ideal but shows magnitude).
        pivot.loc['TOTAL'] = pivot.sum(axis=0)
        pivot['TOTAL'] = pivot.sum(axis=1)
        pivots.append(pivot)

    pivot_counts, pivot_percentages = pivots
    return pivot_counts, pivot_percentages


def _render_heatmap(pivot_counts, pivot_percentages, total_patients, output_path):
    """Draw the annotated heatmap, save it to output_path and close the figure."""
    # Create annotations with "n=X (Y%)" format
    annotations = (
        'n=' + pivot_counts.astype(int).astype(str)
        + '\n(' + pivot_percentages.round(1).astype(str) + '%)'
    )
    # For cells with 0, show empty (decided on the numeric count, not on the
    # exact text of the formatted label)
    annotations = annotations.where(pivot_counts != 0, '')

    # Create figure with larger size to accommodate margin totals
    fig, ax = plt.subplots(figsize=(22, 9))
    try:
        # Percentages drive the colour; the annotation shows count and percentage.
        # Summed TOTAL cells above 100 are drawn with the colour for vmax.
        sns.heatmap(
            pivot_percentages,
            annot=annotations,  # Show "n=X (Y%)"
            fmt='',  # Use our custom format
            cmap='YlOrRd',  # Yellow-Orange-Red color scheme
            cbar_kws={'label': 'Percentage of Patients (%)'},
            linewidths=0.5,
            linecolor='gray',
            ax=ax,
            vmin=0,
            vmax=100,
        )

        # Customize plot
        ax.set_xlabel('Specific Barriers/Facilitators', fontsize=14, fontweight='bold')
        ax.set_ylabel('Main Categories', fontsize=14, fontweight='bold')
        ax.set_title(
            f'Barriers & Facilitators to DR Follow-up Adherence\n'
            f'(Top 15 Most Frequent Items Across {total_patients} Patients)\n'
            f'Each cell shows: Sample Count (Percentage%)',
            fontsize=16,
            fontweight='bold',
            pad=20,
        )

        # Rotate x-axis labels for better readability and make the TOTAL
        # row and column labels bold
        for label in ax.get_xticklabels():
            label.set_rotation(45)
            label.set_ha('right')
            label.set_fontsize(10)
            if label.get_text() == 'TOTAL':
                label.set_weight('bold')
        for label in ax.get_yticklabels():
            label.set_rotation(0)
            label.set_fontsize(12)
            if label.get_text() == 'TOTAL':
                label.set_weight('bold')

        # Adjust layout to prevent label cutoff
        fig.tight_layout()

        # Save figure
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Input: the barriers_facilitators.csv produced in the outputs directory.
    csv_path = '/home/sandhiya/dr-transcripts-reasons-deductive/outputs/barriers_facilitators.csv'
    # Output: the heatmap image, written to the same outputs directory.
    output_path = '/home/sandhiya/dr-transcripts-reasons-deductive/outputs/barriers_facilitators_heatmap.png'

    pivot_counts, pivot_percentages = create_barriers_facilitators_heatmap(
        csv_path=csv_path,
        output_path=output_path,
    )

    # Closing summary: completion notice, tuning hints, and what the figure shows.
    print(
        "\n✅ Heatmap generation complete!",
        "\nNOTE: If you want to adjust:",
        "  - Number of items: Change 'top_items = item_counts.nlargest(15, 'count')' to your desired number",
        "  - Color scheme: Change 'cmap' parameter (try 'RdYlGn', 'viridis', 'coolwarm')",
        "  - Figure size: Change 'figsize=(22, 9)' to adjust dimensions",
        "\nThe heatmap now shows:",
        "  ✓ Sample count in each cell (n=X)",
        "  ✓ Percentage in each cell (Y%)",
        "  ✓ Margin totals for rows (rightmost column)",
        "  ✓ Margin totals for columns (bottom row)",
        sep="\n",
    )