import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from utils import load_json, setup_logging
from wordcloud import WordCloud

# Module-wide logger obtained from the project helper ``utils.setup_logging``;
# every function below reports progress (✓ messages) through it.
logger = setup_logging()

def _category_of(code):
    """Return the code's category, using 'Unknown' when the key is absent or None."""
    category = code.get('category')
    return 'Unknown' if category is None else category


def _count_matrix(codes):
    """Build a float count matrix of codes per (file_id, category).

    Rows are the sorted unique ``file_id`` values, columns the sorted unique
    categories; each cell counts the codes falling in that combination.
    """
    file_ids = sorted({c['file_id'] for c in codes})
    categories = sorted({_category_of(c) for c in codes})
    # Dict lookups keep the fill O(n); list.index() per code was O(n * labels).
    row_of = {fid: i for i, fid in enumerate(file_ids)}
    col_of = {cat: j for j, cat in enumerate(categories)}

    matrix = np.zeros((len(file_ids), len(categories)))
    for code in codes:
        matrix[row_of[code['file_id']], col_of[_category_of(code)]] += 1

    return pd.DataFrame(matrix, index=file_ids, columns=categories)


def _percentage_tables(df_counts):
    """Return (row %, column %, total %) tables derived from a count matrix.

    Row % divides by each row sum, column % by each column sum, and total %
    by the grand total, so rows / columns / the whole table each sum to 100.
    """
    row_pct = df_counts.div(df_counts.sum(axis=1), axis=0) * 100
    col_pct = df_counts.div(df_counts.sum(axis=0), axis=1) * 100
    total = df_counts.sum().sum()
    total_pct = (df_counts / total) * 100
    return row_pct, col_pct, total_pct


def _annotation_grid(df_counts, template, **pct_tables):
    """Format one text label per cell of ``df_counts``.

    ``template`` is a ``str.format`` pattern: ``{count}`` receives the integer
    cell count, and every keyword in ``pct_tables`` (e.g. ``row=df_row_pct``)
    receives that table's value for the same cell. The tables must share the
    shape/order of ``df_counts`` (true for tables from ``_percentage_tables``).
    """
    counts = df_counts.to_numpy()
    pct_arrays = {name: table.to_numpy() for name, table in pct_tables.items()}
    grid = np.empty(counts.shape, dtype=object)
    for idx in np.ndindex(counts.shape):
        cell_pcts = {name: values[idx] for name, values in pct_arrays.items()}
        grid[idx] = template.format(count=int(counts[idx]), **cell_pcts)
    return grid


def _save_heatmap(df_counts, annot, path, title, cmap, figsize=(20, 14), annot_kws=None):
    """Render ``df_counts`` as an annotated heatmap and save a 300-dpi image to ``path``."""
    fig = plt.figure(figsize=figsize)
    try:
        sns.heatmap(df_counts, annot=annot, fmt='', cmap=cmap,
                    cbar_kws={'label': 'Code Count'},
                    linewidths=0.5, linecolor='gray',
                    annot_kws=annot_kws)
        plt.title(title, fontsize=18, fontweight='bold', pad=20)
        plt.xlabel('Category', fontsize=14, fontweight='bold')
        plt.ylabel('Sample ID', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45, ha='right', fontsize=10)
        plt.yticks(fontsize=9)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if rendering fails so repeated runs don't leak memory.
        plt.close(fig)


def _explanation_text(n_samples, n_categories, total_codes):
    """Return the plain-text guide to the percentage tables and generated files.

    ``total_codes`` must be non-zero; it feeds the worked Total % example.
    """
    return f"""
HEAT MAP PERCENTAGE EXPLANATIONS:
================================================================================

DATASET OVERVIEW:
- Total Samples (Rows): {n_samples}
- Total Categories (Columns): {n_categories}
- Total Codes Across All Cells: {int(total_codes)}

================================================================================
PERCENTAGE CALCULATION METHODS:
================================================================================

1. ROW PERCENTAGE (Row %):
   -------------------------
   Denominator: Total codes in EACH ROW (sum of all categories for that sample)
   Formula: (cell_value / row_sum) × 100
   
   Interpretation: 
   "What percentage of THIS SAMPLE's codes fall into this category?"
   
   Each row sums to 100%
   
   Example: 
   Sample dr_MD0610 has 50 total codes across all categories.
   If 10 codes are "BARRIER-SDOH":
   → Cell shows: 10 (20.0%)  [because 10/50 × 100 = 20%]

2. COLUMN PERCENTAGE (Col %):
   ---------------------------
   Denominator: Total codes in EACH COLUMN (sum across all samples for that category)
   Formula: (cell_value / column_sum) × 100
   
   Interpretation:
   "What percentage of THIS CATEGORY's codes come from this sample?"
   
   Each column sums to 100%
   
   Example:
   Category "BARRIER-SDOH" has 200 total codes across all samples.
   If Sample dr_MD0610 contributes 10 codes:
   → Cell shows: 10 (5.0%)  [because 10/200 × 100 = 5%]

3. TOTAL PERCENTAGE (Total %):
   ----------------------------
   Denominator: Total codes across ALL CELLS ({int(total_codes)})
   Formula: (cell_value / total_codes) × 100
   
   Interpretation:
   "What percentage of ALL codes in the entire dataset does this cell represent?"
   
   All cells sum to 100%
   
   Example:
   Total codes = {int(total_codes)}
   If a cell has 10 codes:
   → Cell shows: 10 ({(10/total_codes)*100:.2f}%)

================================================================================
HEATMAP FILES GENERATED:
================================================================================

1. heatmap_count_row_pct.png
   → Shows: Count (Row%)
   → Use when: Comparing how each sample distributes codes across categories

2. heatmap_count_col_pct.png
   → Shows: Count (Col%)
   → Use when: Comparing how each category is distributed across samples

3. heatmap_count_total_pct.png
   → Shows: Count (Total%)
   → Use when: Finding which sample-category combinations dominate the dataset

4. heatmap_count_all_pct.png
   → Shows: Count, Row%, Col%, Total% (all in one)
   → Use when: Need comprehensive view of all metrics

================================================================================
INTERPRETATION GUIDE:
================================================================================

HIGH ROW %: This category is a major focus for this particular sample
HIGH COL %: This sample contributes significantly to this category
HIGH TOTAL %: This combination is highly prevalent in the entire dataset

Example Cell: "25 (R:30.0% | C:12.5% | T:2.1%)"
- 25 codes in this cell
- 30% of this sample's codes are in this category (high focus)
- 12.5% of this category's codes come from this sample (moderate contributor)
- 2.1% of all codes in the dataset are in this cell (small overall proportion)

================================================================================
"""


def create_heatmap_all_samples(codes, output_dir='outputs/visualizations'):
    """Create heat maps of code counts per sample (file_id) x category.

    Writes into ``output_dir`` (created if missing):
      - four PNG heatmaps annotated with count + row %, column %, total %,
        and all three percentages together;
      - ``heatmap_data_combined.xlsx`` with one sheet per table plus a
        combined text sheet;
      - ``percentage_explanation_detailed.txt`` describing the percentages.

    Args:
        codes: sequence of dicts, each with a ``file_id`` key and an optional
            ``category`` key (missing or None is counted as 'Unknown').
        output_dir: directory for all generated files.

    Returns:
        ``(df_matrix, df_row_pct, df_col_pct, df_total_pct)``: float count
        matrix (rows = sorted file_ids, columns = sorted categories) and its
        row, column and total percentage tables. All four are empty
        DataFrames when ``codes`` is empty (no files are written then).
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not codes:
        # seaborn cannot draw a 0x0 heatmap; report clearly instead of crashing inside it.
        logger.warning("No codes supplied; skipping heat map generation")
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    df_matrix = _count_matrix(codes)
    file_ids = list(df_matrix.index)
    categories = list(df_matrix.columns)
    df_row_pct, df_col_pct, df_total_pct = _percentage_tables(df_matrix)
    total_codes = df_matrix.sum().sum()

    # HEATMAP 1: count + row percentage
    _save_heatmap(
        df_matrix,
        _annotation_grid(df_matrix, '{count}\n({row:.1f}%)', row=df_row_pct),
        out_dir / 'heatmap_count_row_pct.png',
        'Heat Map: Count + Row Percentage\n(Count\n(Row%))',
        cmap='YlOrRd',
    )
    logger.info("✓ Heatmap: Count + Row % created")

    # HEATMAP 2: count + column percentage
    _save_heatmap(
        df_matrix,
        _annotation_grid(df_matrix, '{count}\n({col:.1f}%)', col=df_col_pct),
        out_dir / 'heatmap_count_col_pct.png',
        'Heat Map: Count + Column Percentage\n(Count\n(Col%))',
        cmap='Blues',
    )
    logger.info("✓ Heatmap: Count + Column % created")

    # HEATMAP 3: count + total percentage
    _save_heatmap(
        df_matrix,
        _annotation_grid(df_matrix, '{count}\n({total:.1f}%)', total=df_total_pct),
        out_dir / 'heatmap_count_total_pct.png',
        'Heat Map: Count + Total Percentage\n(Count\n(Total%))',
        cmap='Greens',
    )
    logger.info("✓ Heatmap: Count + Total % created")

    # HEATMAP 4: count + all three percentages (larger figure, smaller font)
    _save_heatmap(
        df_matrix,
        _annotation_grid(
            df_matrix, '{count}\nR:{row:.1f}%\nC:{col:.1f}%\nT:{total:.1f}%',
            row=df_row_pct, col=df_col_pct, total=df_total_pct,
        ),
        out_dir / 'heatmap_count_all_pct.png',
        'Heat Map: Count + All Percentages\n(Count | R:Row% | C:Col% | T:Total%)',
        cmap='RdYlGn_r',
        figsize=(24, 16),
        annot_kws={'fontsize': 7},
    )
    logger.info("✓ Heatmap: Count + ALL percentages created")

    # Save every table, plus a human-readable combined sheet, to one workbook.
    with pd.ExcelWriter(out_dir / 'heatmap_data_combined.xlsx', engine='openpyxl') as writer:
        df_matrix.to_excel(writer, sheet_name='Counts')
        df_row_pct.to_excel(writer, sheet_name='Row_Percentage')
        df_col_pct.to_excel(writer, sheet_name='Column_Percentage')
        df_total_pct.to_excel(writer, sheet_name='Total_Percentage')

        combined = _annotation_grid(
            df_matrix, '{count} (R:{row:.1f}% | C:{col:.1f}% | T:{total:.1f}%)',
            row=df_row_pct, col=df_col_pct, total=df_total_pct,
        )
        df_combined = pd.DataFrame(combined, index=file_ids, columns=categories)
        df_combined.to_excel(writer, sheet_name='Combined_All')

    logger.info("✓ Excel file with combined data created")

    explanation = _explanation_text(len(file_ids), len(categories), total_codes)
    (out_dir / 'percentage_explanation_detailed.txt').write_text(explanation, encoding='utf-8')

    logger.info("✓ Detailed explanation document created")

    return df_matrix, df_row_pct, df_col_pct, df_total_pct


def _codes_in_category(codes, keyword):
    """Return the ``code`` texts of entries whose upper-cased category contains ``keyword``.

    ``keyword`` is expected in upper case. Entries with a missing/None category
    are treated as having an empty category; entries without a non-empty
    ``code`` value are skipped rather than aborting the whole run.
    """
    selected = []
    for entry in codes:
        category = entry.get('category') or ''
        text = entry.get('code')
        if text and keyword in category.upper():
            selected.append(text)
    return selected


def _save_wordcloud(words, colormap, title, path):
    """Render a word cloud from ``words`` and save it to ``path``.

    Returns True if an image was written, False if WordCloud found nothing to draw.
    """
    try:
        cloud = WordCloud(width=1200, height=600, background_color='white',
                          colormap=colormap, max_words=100).generate(' '.join(words))
    except ValueError as err:
        # WordCloud raises ValueError when no plottable words remain (e.g. only stopwords).
        logger.warning(f"Skipping {path.name}: {err}")
        return False

    fig = plt.figure(figsize=(15, 8))
    try:
        plt.imshow(cloud, interpolation='bilinear')
        plt.axis('off')
        plt.title(title, fontsize=18, fontweight='bold', pad=20)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return True


def create_wordclouds(codes, output_dir='outputs/visualizations'):
    """Create word clouds for barriers and facilitators.

    Codes whose category contains 'BARRIER' (case-insensitive) go into
    ``wordcloud_barriers.png``; those containing 'FACILITATOR' go into
    ``wordcloud_facilitators.png``. A cloud is skipped when it has no codes.

    Args:
        codes: sequence of dicts with ``code`` text and optional ``category``.
        output_dir: directory for the images (created if missing).
    """
    out_dir = Path(output_dir)
    # Create the directory here too, so this function works when called on its own.
    out_dir.mkdir(parents=True, exist_ok=True)

    cloud_specs = (
        ('BARRIER', 'Reds', 'Word Cloud: Barriers to Follow-Up', 'wordcloud_barriers.png'),
        ('FACILITATOR', 'Greens', 'Word Cloud: Facilitators to Follow-Up', 'wordcloud_facilitators.png'),
    )
    for keyword, colormap, title, filename in cloud_specs:
        words = _codes_in_category(codes, keyword)
        if words:
            _save_wordcloud(words, colormap, title, out_dir / filename)

    logger.info("✓ Word clouds created")


def create_bar_charts(stats, output_dir='outputs/visualizations'):
    """Create a horizontal bar chart of the most frequent codes.

    Args:
        stats: mapping whose ``'top_20_codes'`` entry is a list of dicts with
            ``'code'`` (label) and ``'count'`` (frequency) keys, in display order.
        output_dir: directory for ``top_20_codes_bar_chart.png`` (created if missing).

    Raises:
        KeyError: if ``stats`` has no ``'top_20_codes'`` entry.
    """
    top_codes = stats['top_20_codes']
    if not top_codes:
        logger.warning("No top codes found in statistics; skipping bar chart")
        return

    out_dir = Path(output_dir)
    # Create the directory here too, so this function works when called on its own.
    out_dir.mkdir(parents=True, exist_ok=True)

    labels = [item['code'] for item in top_codes]
    counts = [item['count'] for item in top_codes]
    # Plot at numeric positions: string y-values would be treated as categories,
    # collapsing duplicate labels onto one bar and misaligning the value labels.
    positions = list(range(len(labels)))

    fig, ax = plt.subplots(figsize=(14, 10))
    try:
        ax.barh(positions, counts, color='steelblue', edgecolor='black')
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)

        # Value label just right of each bar end.
        for pos, count in zip(positions, counts):
            ax.text(count + 0.5, pos, str(count), va='center', fontweight='bold')

        ax.set_xlabel('Frequency', fontsize=14, fontweight='bold')
        ax.set_ylabel('Code', fontsize=14, fontweight='bold')
        ax.set_title(f'Top {len(top_codes)} Most Frequent Codes',
                     fontsize=18, fontweight='bold', pad=20)
        ax.invert_yaxis()  # most frequent code at the top
        ax.grid(axis='x', alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_dir / 'top_20_codes_bar_chart.png', dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    logger.info("✓ Bar charts created")


def main():
    """Load coded data and summary statistics, then generate every visualization.

    Reads ``outputs/codes/all_codes_deductive.json`` (its ``'codes'`` list) and
    ``outputs/analysis/summary_statistics.json``, and writes all outputs to the
    default ``outputs/visualizations`` directory.
    """
    logger.info("Loading data...")
    codes = load_json('outputs/codes/all_codes_deductive.json')['codes']
    stats = load_json('outputs/analysis/summary_statistics.json')

    # Report the real number of samples instead of a hard-coded count.
    n_samples = len({c['file_id'] for c in codes})
    logger.info(f"Creating heat maps for all {n_samples} samples...")
    create_heatmap_all_samples(codes)

    logger.info("Creating word clouds...")
    create_wordclouds(codes)

    logger.info("Creating bar charts...")
    create_bar_charts(stats)

    generated_files = (
        "heatmap_count_row_pct.png (Count + Row %)",
        "heatmap_count_col_pct.png (Count + Column %)",
        "heatmap_count_total_pct.png (Count + Total %)",
        "heatmap_count_all_pct.png (Count + ALL %)",
        "heatmap_data_combined.xlsx (Excel with all data)",
        "percentage_explanation_detailed.txt",
        "wordcloud_barriers.png",
        "wordcloud_facilitators.png",
        "top_20_codes_bar_chart.png",
    )

    logger.info("\n" + "=" * 80)
    logger.info("ALL VISUALIZATIONS COMPLETED SUCCESSFULLY!")
    logger.info("=" * 80)
    logger.info("\nGenerated Files:")
    for name in generated_files:
        logger.info(f"  - {name}")
    logger.info("=" * 80)


if __name__ == "__main__":
    main()
