import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from utils import load_json, setup_logging
from wordcloud import WordCloud

logger = setup_logging()

def _heatmap_category(code):
    """Return the heat-map column label for one code record.

    A missing key *or* an explicit ``None`` becomes ``'Unknown'`` so sorting
    the labels never has to compare ``None`` against strings.
    """
    category = code.get('category')
    return 'Unknown' if category is None else category


def _save_heatmap(data, path, *, fmt, cmap, cbar_label, title):
    """Draw one annotated sample x category heat map of *data* and save it to *path*."""
    fig = plt.figure(figsize=(16, 12))
    try:
        sns.heatmap(data, annot=True, fmt=fmt, cmap=cmap,
                    cbar_kws={'label': cbar_label})
        plt.title(title, fontsize=16, fontweight='bold')
        plt.xlabel('Category', fontsize=12)
        plt.ylabel('Sample ID', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing/saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)


def _heatmap_explanation(total_codes, n_samples, n_categories):
    """Build the plain-text note that explains the three percentage tables."""
    return f"""
HEAT MAP PERCENTAGE EXPLANATIONS:

1. ROW PERCENTAGE:
   - Denominator: Total codes in each ROW (each sample/file)
   - Formula: (cell value / row sum) × 100
   - Interpretation: "What percentage of THIS SAMPLE's codes fall into each category?"
   - Each row sums to 100%
   - Example: If Sample dr_MD0610 has 50 total codes, and 10 are "BARRIER-SDOH",
             then that cell shows 20% (10/50 × 100)

2. COLUMN PERCENTAGE:
   - Denominator: Total codes in each COLUMN (each category)
   - Formula: (cell value / column sum) × 100
   - Interpretation: "What percentage of THIS CATEGORY's codes come from each sample?"
   - Each column sums to 100%
   - Example: If category "BARRIER-SDOH" has 200 total codes across all samples,
             and Sample dr_MD0610 contributes 10, then that cell shows 5% (10/200 × 100)

3. TOTAL PERCENTAGE:
   - Denominator: Total codes across ALL cells ({total_codes})
   - Formula: (cell value / total codes) × 100
   - Interpretation: "What percentage of ALL codes does this cell represent?"
   - All cells sum to 100%

MATRIX DIMENSIONS:
- Rows: {n_samples} samples (file IDs)
- Columns: {n_categories} categories
- Total cells: {n_samples * n_categories}
- Total codes: {total_codes}
"""


def create_heatmap_all_samples(codes, output_dir='outputs/visualizations'):
    """Create sample x category heat maps covering every sample in *codes*.

    Builds a count matrix with one row per ``file_id`` and one column per
    category, plus row-, column- and total-percentage versions of it.

    Outputs written to *output_dir*, which is created if missing:
    three PNG heat maps (counts, row %, column %), an Excel workbook with all
    four tables, and a UTF-8 text file explaining the percentages.

    Args:
        codes: Sequence of dicts. Each needs a ``'file_id'`` key. The
            ``'category'`` key is optional; missing or None maps to 'Unknown'.
        output_dir: Destination directory for the generated files.

    Returns:
        Tuple ``(df_matrix, df_row_pct, df_col_pct)`` of DataFrames indexed by
        file id with one column per category. If *codes* is empty, three empty
        DataFrames are returned and no files are written.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    codes = list(codes)  # iterated several times below
    if not codes:
        # seaborn cannot draw a heat map of a zero-size matrix.
        logger.warning("No codes supplied; heat maps were not created")
        empty = pd.DataFrame()
        return empty, empty.copy(), empty.copy()

    file_ids = sorted({c['file_id'] for c in codes})
    categories = sorted({_heatmap_category(c) for c in codes})

    # Resolve label -> matrix position once; list.index() per code was O(n) each.
    row_of = {file_id: i for i, file_id in enumerate(file_ids)}
    col_of = {category: j for j, category in enumerate(categories)}

    matrix = np.zeros((len(file_ids), len(categories)))
    for code in codes:
        matrix[row_of[code['file_id']], col_of[_heatmap_category(code)]] += 1

    df_matrix = pd.DataFrame(matrix, index=file_ids, columns=categories)

    # Every file_id and category comes from at least one code, so no row or
    # column sum is zero here.
    # Row %: denominator is the total codes of each file (each row sums to 100).
    df_row_pct = df_matrix.div(df_matrix.sum(axis=1), axis=0) * 100
    # Column %: denominator is the total codes of each category (each column sums to 100).
    df_col_pct = df_matrix.div(df_matrix.sum(axis=0), axis=1) * 100
    # Total %: denominator is every code; each code adds exactly one count.
    total_codes = len(codes)
    df_total_pct = (df_matrix / total_codes) * 100

    n_samples = len(file_ids)

    # File name kept unchanged (including "31") so downstream references still resolve.
    _save_heatmap(df_matrix, out_dir / 'heatmap_counts_all_31_samples.png',
                  fmt='.0f', cmap='YlOrRd', cbar_label='Code Count',
                  title=f'Heat Map: Code Counts Across All {n_samples} Samples')
    _save_heatmap(df_row_pct, out_dir / 'heatmap_row_percentage.png',
                  fmt='.1f', cmap='Blues', cbar_label='Row %',
                  title='Heat Map: Row Percentage (each sample = 100%)')
    _save_heatmap(df_col_pct, out_dir / 'heatmap_column_percentage.png',
                  fmt='.1f', cmap='Greens', cbar_label='Column %',
                  title='Heat Map: Column Percentage (each category = 100%)')

    with pd.ExcelWriter(out_dir / 'heatmap_data.xlsx') as writer:
        df_matrix.to_excel(writer, sheet_name='Counts')
        df_row_pct.to_excel(writer, sheet_name='Row_Percentage')
        df_col_pct.to_excel(writer, sheet_name='Column_Percentage')
        df_total_pct.to_excel(writer, sheet_name='Total_Percentage')

    logger.info("Heat maps created for all %d samples", n_samples)

    # Explicit UTF-8: the text contains non-ASCII characters ("×").
    explanation = _heatmap_explanation(total_codes, n_samples, len(categories))
    with open(out_dir / 'percentage_explanation.txt', 'w', encoding='utf-8') as f:
        f.write(explanation)

    return df_matrix, df_row_pct, df_col_pct

def _save_wordcloud(terms, colormap, title, path):
    """Render *terms* as one word-cloud image and save it to *path*.

    Empty or None terms are dropped and other terms are stringified. If no
    usable text remains, a warning is logged and no image is written.
    """
    text = ' '.join(str(term) for term in terms if term)
    if not text.strip():
        logger.warning("No text for '%s'; word cloud skipped", title)
        return
    try:
        cloud = WordCloud(width=800, height=400, background_color='white',
                          colormap=colormap).generate(text)
    except ValueError as exc:
        # WordCloud raises ValueError when no words survive tokenisation and
        # stopword filtering. Skip this image rather than abort the whole run.
        logger.warning("Word cloud '%s' skipped: %s", title, exc)
        return

    fig = plt.figure(figsize=(12, 6))
    try:
        plt.imshow(cloud, interpolation='bilinear')
        plt.axis('off')
        plt.title(title, fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        # Close the figure even when saving fails, so figures are not leaked.
        plt.close(fig)


def create_wordclouds(codes, output_dir='outputs/visualizations'):
    """Create word clouds of code text for barriers and facilitators.

    A code belongs to the barrier group when its ``category`` contains
    'BARRIER', and to the facilitator group when it contains 'FACILITATOR'.
    The match is case-insensitive. The ``'code'`` values of each group are
    joined into one text and rendered as a word cloud.

    Args:
        codes: Sequence of dicts with a ``'code'`` key and an optional
            ``'category'`` key. A missing or None category matches neither group.
        output_dir: Directory for the PNG files; created if missing.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    codes = list(codes)  # scanned once per group below

    def category_upper(code):
        # ``or ''`` also covers an explicit None category, which has no .upper().
        return (code.get('category') or '').upper()

    barrier_codes = [c['code'] for c in codes if 'BARRIER' in category_upper(c)]
    facilitator_codes = [c['code'] for c in codes if 'FACILITATOR' in category_upper(c)]

    # A group with no codes produces no image, as before.
    if barrier_codes:
        _save_wordcloud(barrier_codes, 'Reds', 'Word Cloud: Barriers',
                        out_dir / 'wordcloud_barriers.png')
    if facilitator_codes:
        _save_wordcloud(facilitator_codes, 'Greens', 'Word Cloud: Facilitators',
                        out_dir / 'wordcloud_facilitators.png')

    logger.info("Word clouds created")

def create_bar_charts(stats, output_dir='outputs/visualizations'):
    """Create a horizontal bar chart of the most frequent codes.

    Args:
        stats: Mapping whose ``'top_20_codes'`` entry is a list of dicts, each
            with ``'code'`` and ``'count'`` keys. The list is plotted in order,
            with the first item at the top.
        output_dir: Directory for ``top_20_codes_bar_chart.png``; created if
            missing, so the function also works when called on its own.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    top_codes = stats['top_20_codes']
    codes_list = [item['code'] for item in top_codes]
    counts_list = [item['count'] for item in top_codes]

    fig = plt.figure(figsize=(12, 8))
    try:
        plt.barh(codes_list, counts_list, color='steelblue')
        plt.xlabel('Frequency', fontsize=12)
        plt.ylabel('Code', fontsize=12)
        plt.title('Top 20 Most Frequent Codes', fontsize=16, fontweight='bold')
        # barh draws the first item at the bottom; invert so rank 1 is at the top.
        plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.savefig(out_dir / 'top_20_codes_bar_chart.png', dpi=300, bbox_inches='tight')
    finally:
        # Close the figure even when plotting/saving fails.
        plt.close(fig)

    logger.info("Bar charts created")

if __name__ == "__main__":
    # Inputs produced by earlier stages of the pipeline.
    coded_segments = load_json('outputs/codes/all_codes_deductive.json')['codes']
    summary_stats = load_json('outputs/analysis/summary_statistics.json')

    # Ordered (progress message, step) pairs; each step writes to the default
    # output directory of its function.
    pipeline = [
        ("Creating heat maps for all 31 samples...",
         lambda: create_heatmap_all_samples(coded_segments)),
        ("Creating word clouds...",
         lambda: create_wordclouds(coded_segments)),
        ("Creating bar charts...",
         lambda: create_bar_charts(summary_stats)),
    ]
    for progress_message, run_step in pipeline:
        logger.info(progress_message)
        run_step()

    logger.info("All visualizations completed!")