import ollama
import json
from pathlib import Path
from utils import setup_logging, save_json, load_json
from tqdm import tqdm

# Module-level logger configured once via the project's shared helper.
logger = setup_logging()

def load_framework(framework_file='prompts/deductive_framework.txt'):
    """Load the deductive coding framework text.

    Args:
        framework_file: Path to the framework description file.

    Returns:
        The full framework text as a single string.
    """
    # Explicit UTF-8 avoids failures under platform-dependent default encodings.
    return Path(framework_file).read_text(encoding='utf-8')

def create_coding_prompt(chunk, framework, template_file='prompts/coding_prompt.txt'):
    """Build the few-shot coding prompt from a template file.

    Args:
        chunk: Transcript chunk text to be coded.
        framework: Deductive framework text substituted into the template.
        template_file: Path to the prompt template; must contain
            ``{framework}`` and ``{chunk}`` placeholders.

    Returns:
        The fully formatted prompt string.
    """
    # Explicit UTF-8 avoids failures under platform-dependent default encodings.
    template = Path(template_file).read_text(encoding='utf-8')
    return template.format(framework=framework, chunk=chunk)

def extract_codes_llama(chunk, framework, model='llama3.3'):
    """Extract codes using Llama3.3 via Ollama"""
    prompt = create_coding_prompt(chunk, framework)
    
    try:
        # CHANGE: Use ollama.chat instead of ollama.generate for better control
        response = ollama.chat(
            model=model,
            messages=[
                {
                    'role': 'system',
                    'content': 'You are an expert qualitative researcher performing deductive thematic analysis.'
                },
                {
                    'role': 'user',
                    'content': prompt
                }
            ],
            options={
                'temperature': 0.3,
                'num_predict': 3000  # Increase token limit
            }
        )
        
        response_text = response['message']['content']  # CHANGED: different response structure
        
        # ADD DEBUGGING
        logger.info(f"LLM Raw Response: {response_text[:200]}...")  # Log first 200 chars
        
        # Better JSON extraction
        if '```json' in response_text:
            json_str = response_text.split('```json')[1].split('```')[0]
        elif '```' in response_text:
            json_str = response_text.split('```')[1].split('```')[0]
        else:
            json_str = response_text
        
        json_str = json_str.strip()
        
        # ADD: Handle malformed JSON
        try:
            codes_data = json.loads(json_str)
        except json.JSONDecodeError as je:
            logger.error(f"JSON Parse Error: {je}")
            logger.error(f"Attempted to parse: {json_str[:500]}")
            return []  # Return empty list instead of crashing
        
        # VALIDATE structure
        if 'codes' not in codes_data:
            logger.warning("Response missing 'codes' key")
            return []
            
        return codes_data['codes']
    
    except Exception as e:
        logger.error(f"Error extracting codes: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())  # ADD: Full error trace
        return []


def process_all_transcripts_deductive(input_dir, output_dir, model='llama3.3'):
    """Run deductive coding over every processed transcript in *input_dir*.

    Args:
        input_dir: Directory containing ``*_processed.json`` chunk files,
            each with 'file_id' and 'chunks' keys.
        output_dir: Directory receiving per-file and combined code JSON.
        model: Ollama model name passed through to extract_codes_llama.

    Returns:
        List of every extracted code dict across all transcripts, each
        annotated with 'file_id' and 'chunk_id' provenance.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Pre-create downstream pipeline directories so later stages never fail.
    for stage_dir in ('outputs/analysis', 'outputs/themes', 'outputs/visualizations'):
        Path(stage_dir).mkdir(parents=True, exist_ok=True)

    framework = load_framework()
    all_codes = []

    processed_files = list(Path(input_dir).glob('*_processed.json'))

    for json_file in tqdm(processed_files, desc="Coding transcripts"):
        logger.info(f"Processing: {json_file.name}")

        data = load_json(json_file)
        file_id = data['file_id']
        chunks = data['chunks']

        file_codes = []

        for i, chunk in enumerate(tqdm(chunks, desc=f"  Chunks", leave=False)):
            codes = extract_codes_llama(chunk, framework, model)

            # Annotate each code with its provenance for downstream analysis.
            for code in codes:
                code['file_id'] = file_id
                code['chunk_id'] = i

            file_codes.extend(codes)

        # Persist per-transcript results so a crash doesn't lose everything.
        output_file = Path(output_dir) / f"{file_id}_codes.json"
        save_json({
            'file_id': file_id,
            'total_codes': len(file_codes),
            'codes': file_codes
        }, output_file)

        all_codes.extend(file_codes)
        logger.info(f"  -> Extracted {len(file_codes)} codes")

    # Hoisted: compute the unique-code count once instead of twice.
    unique_count = len({c['code'] for c in all_codes})

    save_json({
        'total_codes': len(all_codes),
        'unique_codes': unique_count,
        'codes': all_codes
    }, Path(output_dir) / 'all_codes_deductive.json')

    logger.info(f"Total codes extracted: {len(all_codes)}")
    logger.info(f"Unique codes: {unique_count}")

    return all_codes

if __name__ == "__main__":
    # Script entry point: code every chunked transcript under data/processed
    # and write code JSON to outputs/codes.
    input_dir = "data/processed/chunked"
    output_dir = "outputs/codes"
    
    codes = process_all_transcripts_deductive(input_dir, output_dir, model='llama3.3')