import ollama
import json
import pandas as pd
from pathlib import Path
from collections import Counter, defaultdict
from utils import setup_logging, save_json, load_json

logger = setup_logging()

def load_all_codes(codes_file='outputs/codes/all_codes_deductive.json'):
    """Return the ``'codes'`` entry of the file read by ``load_json``.

    Args:
        codes_file: Path handed unchanged to ``load_json``.

    Returns:
        Whatever is stored under the ``'codes'`` key of the loaded data.
    """
    return load_json(codes_file)['codes']

def _parse_theme_response(response_text):
    """Extract the list of themes from the raw text of a model reply.

    Raises:
        json.JSONDecodeError: No parseable JSON could be found in the reply.
        KeyError, TypeError: The parsed JSON has no ``'themes'`` entry.
        ValueError: The ``'themes'`` entry is not a list.
    """
    # Prefer the contents of a fenced code block when the model used one.
    if '```json' in response_text:
        json_str = response_text.split('```json')[1].split('```')[0]
    elif '```' in response_text:
        json_str = response_text.split('```')[1].split('```')[0]
    else:
        json_str = response_text
    json_str = json_str.strip()

    try:
        theme_data = json.loads(json_str)
    except json.JSONDecodeError:
        # The reply may wrap the JSON object in prose without any fence;
        # retry on the outermost {...} span before giving up.
        start = json_str.find('{')
        end = json_str.rfind('}')
        if start == -1 or end <= start:
            raise
        theme_data = json.loads(json_str[start:end + 1])

    theme_list = theme_data['themes']
    if not isinstance(theme_list, list):
        raise ValueError(
            f"expected 'themes' to be a list, got {type(theme_list).__name__}"
        )
    return theme_list


def generate_themes_from_codes(codes, model='llama3.3', max_codes=100):
    """Ask the model to group each category's codes into themes.

    Codes are bucketed by their ``'category'`` value (``'Unknown'`` when the
    key is absent) and one prompt is sent per category through
    ``ollama.generate``.

    Args:
        codes: Iterable of dicts, each with a ``'code'`` key and optionally
            a ``'category'`` key.
        model: Model name passed to ``ollama.generate``.
        max_codes: Maximum number of distinct codes included in one prompt,
            to keep the prompt size bounded.

    Returns:
        Dict mapping each category to the list found under ``'themes'`` in
        the model's JSON reply. A category whose request or parsing failed
        maps to an empty list; the failure is logged.
    """
    # Group codes by category
    category_codes = defaultdict(list)
    for code in codes:
        category = code.get('category', 'Unknown')
        category_codes[category].append(code['code'])

    themes = {}

    for category, code_list in category_codes.items():
        # Drop repeats (keeping first-seen order) so the size cap is spent on
        # distinct codes rather than on many copies of a frequent one.
        prompt_codes = list(dict.fromkeys(code_list))[:max_codes]

        # Everything inside this literal is sent to the model verbatim, so it
        # must not contain Python-style comments.
        prompt = f"""You are a qualitative researcher. Given these codes from patient interviews about diabetic retinopathy follow-up adherence, group them into 3-5 meaningful THEMES.

CODES:
{json.dumps(prompt_codes, indent=2)}

TASK: Group similar codes into themes. A theme is a broader pattern that captures multiple related codes.

OUTPUT FORMAT (JSON):
{{
  "themes": [
    {{
      "theme_name": "Access and Transportation Barriers",
      "codes": ["transportation barriers", "lack of access", "distance to clinic"],
      "description": "Challenges related to physically reaching healthcare facilities"
    }}
  ]
}}

Generate themes:"""

        try:
            response = ollama.generate(
                model=model,
                prompt=prompt,
                options={'temperature': 0.3, 'num_predict': 2000}
            )
            themes[category] = _parse_theme_response(response['response'])
        except Exception as e:
            # Deliberately broad: one failing category (request error or an
            # unparseable reply) must not abort the remaining categories.
            logger.error(f"Error generating themes for {category}: {str(e)}")
            themes[category] = []

    return themes

def categorize_into_groups(codes):
    """Split codes into barriers, facilitators and everything else.

    The decision is a case-insensitive substring match on each entry's
    ``'category'`` value: ``'BARRIER'`` is checked first, then
    ``'FACILITATOR'``. Entries with a missing, ``None`` or otherwise
    non-matching category go to ``'other'``.

    Args:
        codes: Iterable of dicts, optionally carrying a ``'category'`` key.

    Returns:
        Dict with the keys ``'barriers'``, ``'facilitators'`` and
        ``'other'``, each holding the matching entries in input order.
    """
    groups = {'barriers': [], 'facilitators': [], 'other': []}

    for code in codes:
        # `or ''` covers a key that is present but None; str() keeps a
        # non-string category from raising on .upper().
        category = str(code.get('category') or '').upper()

        if 'BARRIER' in category:
            groups['barriers'].append(code)
        elif 'FACILITATOR' in category:
            groups['facilitators'].append(code)
        else:
            groups['other'].append(code)

    return groups

def create_summary_statistics(codes, themes, groups):
    """Build the summary dictionary for a set of codes, themes and groups.

    Args:
        codes: List of dicts with a ``'code'`` key and optionally a
            ``'category'`` key (``'Unknown'`` is used when it is absent).
        themes: Dict whose values are sized collections of themes.
        groups: Dict with ``'barriers'``, ``'facilitators'`` and ``'other'``
            entries, each a sized collection.

    Returns:
        Dict with the total and unique code counts, the number of themes
        summed over all categories, the size of each group, the number of
        codes per category, and the 20 most frequent codes with their counts.
    """
    total = len(codes)

    # One pass per tally; the code tally serves both the unique count and
    # the frequency ranking.
    occurrences = Counter(entry['code'] for entry in codes)
    per_category = Counter(entry.get('category', 'Unknown') for entry in codes)

    theme_total = 0
    for theme_list in themes.values():
        theme_total += len(theme_list)

    group_sizes = {
        name: len(groups[name])
        for name in ('barriers', 'facilitators', 'other')
    }

    ranking = []
    for label, frequency in occurrences.most_common(20):
        ranking.append({'code': label, 'count': frequency})

    return {
        'total_codes': total,
        'unique_codes': len(occurrences),
        'total_themes': theme_total,
        'groups': group_sizes,
        'codes_by_category': dict(per_category),
        'top_20_codes': ranking,
    }

def export_to_excel(codes, themes, groups, stats, output_file='outputs/analysis_results.xlsx'):
    """Export all results to a single Excel file with multiple sheets.

    Sheets are written in this order: ``All_Codes``, ``Barriers``,
    ``Facilitators``, ``Themes``, ``Summary`` and ``Top_20_Codes``.

    Args:
        codes: Records for the ``All_Codes`` sheet.
        themes: Dict mapping a category to a list of theme dicts; each theme
            contributes one row built from its ``'theme_name'``,
            ``'description'`` and the length of its ``'codes'``.
        groups: Dict providing the ``'barriers'`` and ``'facilitators'``
            records.
        stats: Dict providing ``'total_codes'``, ``'unique_codes'``,
            ``'total_themes'``, ``'groups'`` and ``'top_20_codes'``.
        output_file: Destination handed to ``pandas.ExcelWriter``.
    """
    # ExcelWriter does not create missing directories. Only path-like
    # targets are touched, so any other kind of target is passed through
    # untouched.
    if isinstance(output_file, (str, Path)):
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # Flatten the per-category themes into one row per theme.
    theme_rows = []
    for category, theme_group in themes.items():
        for theme in theme_group:
            # Themes come from model output, so tolerate malformed entries
            # instead of failing the whole export.
            if not isinstance(theme, dict):
                logger.warning(f"Skipping malformed theme in {category}: {theme!r}")
                continue
            theme_rows.append({
                'Category': category,
                'Theme': theme.get('theme_name', ''),
                'Description': theme.get('description', ''),
                'Number_of_Codes': len(theme.get('codes') or [])
            })
    # Explicit columns keep the header row even when there are no themes.
    df_themes = pd.DataFrame(
        theme_rows,
        columns=['Category', 'Theme', 'Description', 'Number_of_Codes']
    )

    stats_rows = [
        {'Metric': 'Total Codes', 'Value': stats['total_codes']},
        {'Metric': 'Unique Codes', 'Value': stats['unique_codes']},
        {'Metric': 'Total Themes', 'Value': stats['total_themes']},
        {'Metric': 'Barriers Count', 'Value': stats['groups']['barriers']},
        {'Metric': 'Facilitators Count', 'Value': stats['groups']['facilitators']},
    ]

    # (sheet name, frame) pairs in the order the sheets appear in the file.
    sheets = [
        ('All_Codes', pd.DataFrame(codes)),
        ('Barriers', pd.DataFrame(groups['barriers'])),
        ('Facilitators', pd.DataFrame(groups['facilitators'])),
        ('Themes', df_themes),
        ('Summary', pd.DataFrame(stats_rows)),
        ('Top_20_Codes', pd.DataFrame(stats['top_20_codes'])),
    ]

    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        for sheet_name, frame in sheets:
            frame.to_excel(writer, sheet_name=sheet_name, index=False)

    logger.info(f"Excel file saved: {output_file}")

def main():
    """Run the whole analysis: load codes, build themes, group, summarise, export.

    Intermediate results are written with ``save_json`` under
    ``outputs/themes`` and ``outputs/analysis``, the Excel export is
    produced, and a short summary is printed to stdout.
    """
    # Create all required directories first so the later writes cannot fail
    # on a missing folder.
    for directory in ('outputs/analysis', 'outputs/themes', 'outputs/visualizations'):
        Path(directory).mkdir(parents=True, exist_ok=True)

    # Load codes
    codes = load_all_codes()
    logger.info(f"Loaded {len(codes)} codes")

    # Generate themes
    logger.info("Generating themes...")
    themes = generate_themes_from_codes(codes)
    save_json(themes, 'outputs/themes/all_themes.json')

    # Categorize into groups
    logger.info("Categorizing into groups...")
    groups = categorize_into_groups(codes)
    save_json(groups, 'outputs/themes/categorized_groups.json')

    # Create statistics
    stats = create_summary_statistics(codes, themes, groups)
    save_json(stats, 'outputs/analysis/summary_statistics.json')

    # Export to Excel
    logger.info("Exporting to Excel...")
    export_to_excel(codes, themes, groups, stats)

    # Print summary
    print("\n" + "="*60)
    print("ANALYSIS SUMMARY")
    print("="*60)
    print(f"Total Codes: {stats['total_codes']}")
    print(f"Unique Codes: {stats['unique_codes']}")
    print(f"Total Themes: {stats['total_themes']}")
    print("\nGroups:")
    print(f"  - Barriers: {stats['groups']['barriers']}")
    print(f"  - Facilitators: {stats['groups']['facilitators']}")
    print(f"  - Other: {stats['groups']['other']}")
    print("="*60)


if __name__ == "__main__":
    main()