"""
Medical Transcript Entity Extraction System
Using Local LLM (Llama 3.1 8B Instruct / Gemma)
"""

import json
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
from datetime import datetime

# PDF extraction
import PyPDF2

# For LLM - we'll use API-style approach since local models need Ollama/similar
# You'll need to have Ollama running with llama3.1:8b or gemma models


class MedicalTranscriptAnalyzer:
    """Analyzes medical transcripts and extracts structured entities"""
    
    def __init__(self, model_name="llama3.1:8b"):
        """
        Set up the analyzer with the chosen local model and the fixed
        entity taxonomy used for extraction.

        Args:
            model_name: Name of the LLM model (llama3.1:8b or gemma)
        """
        self.model_name = model_name

        # Entity taxonomy the LLM is asked to extract, grouped by category.
        # These exact strings are reused as JSON keys in the LLM response,
        # so they must not be altered.
        self.entities = {
            "Symptoms": [
                "Blurred vision",
                "Fluctuating vision",
                "Dark or floating spots (floaters)",
                "Poor night vision",
                "Faded colors"
            ],
            "Ophthalmic Findings": [
                "Microaneurysms",
                "Hard exudates",
                "Soft exudates (cotton wool spots)",
                "Retinal hemorrhages",
                "Macular edema",
                "Neovascularization",
                "Scar tissue"
            ],
            "Diagnostic Tools": [
                "Dilated eye exam",
                "Optical coherence tomography (OCT)",
                "Fundus photography",
                "Fluorescein angiography"
            ],
            "Systemic Risk Factors": [
                "Blood sugar levels (HbA1c)",
                "Blood pressure (BP)",
                "Cholesterol levels"
            ],
            "Treatment Options": [
                "Intravitreal injections",
                "Laser treatment",
                "Surgical interventions (vitrectomy)"
            ],
            "Demographics/History": [
                "Duration of diabetes",
                "Age over 65",
                "Smoking status"
            ]
        }

        # Flattened "Category: Item" labels, in taxonomy order, for prompt
        # building and result lookups.
        self.all_entities = [
            f"{category}: {item}"
            for category, items in self.entities.items()
            for item in items
        ]
    
    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract all page text from a PDF file.

        Args:
            pdf_path: Path to the PDF file.

        Returns:
            Concatenated text of every page (a "\\n" after each page), or
            whatever was read before a failure — "" if nothing was read.
        """
        pages = []
        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can return None (e.g. image-only pages);
                    # guard so concatenation doesn't raise TypeError.
                    pages.append((page.extract_text() or "") + "\n")
        except Exception as e:
            # Best-effort: report and fall through with the pages read so far.
            print(f"Error reading {pdf_path}: {e}")
        return "".join(pages)
    
    def identify_conversation_id(self, filename: str) -> str:
        """
        Derive a conversation/patient ID from a transcript filename.

        A filename containing a patient marker like "dr_AB1234" maps to
        "AB1234", so multi-part files (part1, part2) of the same visit
        share one ID. Anything else falls back to the extension-less name.
        """
        stem = Path(filename).stem
        found = re.search(r'dr_([A-Z]{2}\d{4})', stem)
        return found.group(1) if found else stem
    
    def create_few_shot_prompt(self) -> str:
        """Return worked extraction examples prepended to the main prompt
        (few-shot conditioning keeps the LLM's JSON output consistent)."""
        return """
EXAMPLE 1:
Doctor: "How is your vision lately?"
Patient: "I've been seeing floaters, like dark spots moving around. It's been getting worse over the past month."
Doctor: "Any blurred vision?"
Patient: "Yes, especially when I try to read. Everything looks fuzzy."

EXTRACTION:
{
  "Symptoms: Blurred vision": {"present": true, "sentiment": "negative", "details": "Vision is fuzzy, especially when reading"},
  "Symptoms: Dark or floating spots (floaters)": {"present": true, "sentiment": "negative", "details": "Dark spots moving around, worsening over past month"},
  "Patient Concerns": "Vision problems affecting daily activities like reading",
  "Severity": "Moderate to severe - worsening over time"
}

EXAMPLE 2:
Doctor: "What brings you in today?"
Patient: "I've had diabetes for 15 years now. My blood sugar has been hard to control."
Doctor: "What's your latest HbA1c?"
Patient: "It was 8.5 last month. I'm really worried about my eyes."

EXTRACTION:
{
  "Demographics/History: Duration of diabetes": {"present": true, "sentiment": "neutral", "details": "15 years"},
  "Systemic Risk Factors: Blood sugar levels (HbA1c)": {"present": true, "sentiment": "negative", "details": "HbA1c 8.5 - poorly controlled"},
  "Patient Concerns": "Worried about eye complications from diabetes",
  "Patient Goals": "Better control of blood sugar to prevent eye damage"
}
"""
    
    def build_extraction_prompt(self, transcript_text: str) -> str:
        """Build the full LLM prompt for entity extraction with sentiment analysis.

        Args:
            transcript_text: Raw transcript text. Only the first 4000
                characters are embedded, to stay within the local model's
                context window — longer transcripts lose their tail.

        Returns:
            One prompt string containing: few-shot examples, the (truncated)
            transcript, task instructions, the entity checklist, and the
            required JSON response schema.
        """
        
        # One "- Category: Item" checklist line per entity.
        entity_list = "\n".join([f"- {entity}" for entity in self.all_entities])
        
        # Doubled braces ({{ }}) below are literal braces in the f-string,
        # showing the model the exact JSON shape to emit.
        prompt = f"""You are a medical data extraction expert. Your task is to analyze a doctor-patient conversation transcript and extract specific medical entities.

{self.create_few_shot_prompt()}

Now analyze the following transcript:

TRANSCRIPT:
{transcript_text[:4000]}  

TASK:
1. For each entity below, determine if it is mentioned (present/absent)
2. If present, extract:
   - Sentiment (positive/neutral/negative)
   - Specific details/context from the conversation
3. Also identify:
   - Main questions asked by the doctor
   - Patient's concerns and worries
   - Patient's goals and what they hope to achieve
   - Occurrence patterns (how often, when)
   - Severity level mentioned

ENTITIES TO EXTRACT:
{entity_list}

RESPONSE FORMAT (JSON only):
{{
  "entity_name": {{
    "present": true/false,
    "sentiment": "positive/neutral/negative",
    "details": "specific quote or summary from conversation",
    "occurrence": "frequency/timing if mentioned",
    "severity": "mild/moderate/severe if mentioned"
  }},
  "doctor_questions": ["question 1", "question 2"],
  "patient_concerns": "summary of patient worries",
  "patient_goals": "what patient hopes to achieve",
  "overall_severity": "assessment"
}}

Respond ONLY with valid JSON, no additional text."""

        return prompt
    
    def call_llm(self, prompt: str) -> str:
        """
        Send a prompt to the local Ollama server and return the raw completion.

        Requires a running Ollama instance (e.g. `ollama run llama3.1:8b`)
        listening on localhost:11434.

        Args:
            prompt: Full prompt text to send.

        Returns:
            The model's response text, or "" on any HTTP/connection error.
        """
        try:
            import requests
            
            response = requests.post(
                'http://localhost:11434/api/generate',
                json={
                    'model': self.model_name,
                    'prompt': prompt,
                    'stream': False,
                    # Sampling parameters must be nested under 'options' —
                    # Ollama's /api/generate silently ignores top-level
                    # 'temperature'/'top_p' keys.
                    'options': {
                        'temperature': 0.1,  # Low temperature for consistent extraction
                        'top_p': 0.9
                    }
                },
                timeout=120
            )
            
            if response.status_code == 200:
                result = response.json()
                return result.get('response', '')
            else:
                print(f"LLM API Error: {response.status_code}")
                return ""
                
        except Exception as e:
            # Broad catch is deliberate: any failure (server down, timeout,
            # bad JSON) degrades to "" so batch processing can continue.
            print(f"Error calling LLM: {e}")
            return ""
    
    def parse_llm_response(self, response: str) -> Dict:
        """Extract the first-to-last brace span from the LLM reply and decode
        it as JSON; return {} when nothing parseable is found."""
        # Greedy match grabs from the first '{' to the last '}' across lines,
        # stripping any prose the model wrapped around the JSON.
        candidate = re.search(r'\{.*\}', response, re.DOTALL)
        if candidate is None:
            return {}
        try:
            return json.loads(candidate.group())
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            return {}
    
    def analyze_transcript(self, pdf_path: str) -> Dict:
        """Run the full pipeline on one PDF: extract text, prompt the LLM,
        parse its JSON reply, and attach file metadata.

        Returns {} when no text could be extracted from the PDF.
        """
        source = Path(pdf_path)
        print(f"\nAnalyzing: {source.name}")

        transcript_text = self.extract_text_from_pdf(pdf_path)
        if not transcript_text.strip():
            print(f"Warning: No text extracted from {pdf_path}")
            return {}

        prompt = self.build_extraction_prompt(transcript_text)

        print(f"Calling {self.model_name}...")
        extracted_data = self.parse_llm_response(self.call_llm(prompt))

        # Tag the record so downstream sheets/reports can group by conversation.
        extracted_data['conversation_id'] = self.identify_conversation_id(source.name)
        extracted_data['filename'] = source.name
        extracted_data['raw_text_length'] = len(transcript_text)

        return extracted_data
    
    def analyze_all_transcripts(self, pdf_directory: str) -> List[Dict]:
        """Analyze every *.pdf in a directory (processed in sorted name
        order) and return the non-empty extraction dicts."""
        pdf_files = sorted(Path(pdf_directory).glob("*.pdf"))
        print(f"Found {len(pdf_files)} PDF files")
        # Drop empty results (unreadable/blank PDFs) as we go.
        return [
            analysis
            for analysis in (self.analyze_transcript(str(p)) for p in pdf_files)
            if analysis
        ]
    
    def create_tabulated_output(self, results: List[Dict], output_path: str):
        """Write a multi-sheet Excel workbook summarizing all extractions.

        Sheets: Summary, Entity_Matrix (presence/sentiment per entity),
        Detailed_Extractions (present entities only), Doctor_Questions
        (only if any were captured), Patient_Perspectives.

        Args:
            results: Per-transcript extraction dicts from analyze_transcript().
            output_path: Destination .xlsx path (openpyxl engine).
        """
        # Context manager guarantees the workbook is finalized/closed even
        # if a sheet write raises (the old manual close leaked on error).
        with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
            # Sheet 1: Summary Statistics
            summary_data = {
                'Total Conversations': len(set([r.get('conversation_id', '') for r in results])),
                'Total Documents': len(results),
                'Analysis Date': datetime.now().strftime('%Y-%m-%d %H:%M'),
                'Model Used': self.model_name
            }
            pd.DataFrame([summary_data]).to_excel(writer, sheet_name='Summary', index=False)

            # Sheet 2: Entity Presence/Absence Matrix — two columns per entity.
            entity_matrix = []
            for result in results:
                row = {
                    'Conversation_ID': result.get('conversation_id', ''),
                    'Filename': result.get('filename', '')
                }
                for entity in self.all_entities:
                    entity_data = result.get(entity, {})
                    # LLM output may be malformed (non-dict) — treat as absent.
                    if isinstance(entity_data, dict):
                        row[f"{entity}_Present"] = 'Yes' if entity_data.get('present', False) else 'No'
                        row[f"{entity}_Sentiment"] = entity_data.get('sentiment', 'N/A')
                    else:
                        row[f"{entity}_Present"] = 'No'
                        row[f"{entity}_Sentiment"] = 'N/A'
                entity_matrix.append(row)

            pd.DataFrame(entity_matrix).to_excel(writer, sheet_name='Entity_Matrix', index=False)

            # Sheet 3: Detailed Extractions — one row per present entity.
            detailed_data = []
            for result in results:
                for entity in self.all_entities:
                    entity_data = result.get(entity, {})
                    if isinstance(entity_data, dict) and entity_data.get('present', False):
                        detailed_data.append({
                            'Conversation_ID': result.get('conversation_id', ''),
                            'Filename': result.get('filename', ''),
                            'Entity': entity,
                            'Sentiment': entity_data.get('sentiment', ''),
                            'Details': entity_data.get('details', ''),
                            'Occurrence': entity_data.get('occurrence', ''),
                            'Severity': entity_data.get('severity', '')
                        })

            pd.DataFrame(detailed_data).to_excel(writer, sheet_name='Detailed_Extractions', index=False)

            # Sheet 4: Questions Asked — only written if any were captured.
            questions_data = []
            for result in results:
                questions = result.get('doctor_questions', [])
                if questions:
                    questions_data.append({
                        'Conversation_ID': result.get('conversation_id', ''),
                        'Questions': '\n'.join(questions) if isinstance(questions, list) else str(questions)
                    })

            if questions_data:
                pd.DataFrame(questions_data).to_excel(writer, sheet_name='Doctor_Questions', index=False)

            # Sheet 5: Patient Perspectives — concerns/goals/severity per record.
            patient_data = []
            for result in results:
                patient_data.append({
                    'Conversation_ID': result.get('conversation_id', ''),
                    'Filename': result.get('filename', ''),
                    'Concerns': result.get('patient_concerns', ''),
                    'Goals': result.get('patient_goals', ''),
                    'Overall_Severity': result.get('overall_severity', '')
                })

            pd.DataFrame(patient_data).to_excel(writer, sheet_name='Patient_Perspectives', index=False)

        print(f"\n✓ Tabulated output saved to: {output_path}")
    
    def generate_analysis_report(self, results: List[Dict], output_path: str):
        """Write a plain-text summary report of all extractions.

        Sections: summary statistics, common doctor questions, entity
        prevalence, and sentiment overview.

        Args:
            results: Per-transcript extraction dicts from analyze_transcript().
            output_path: Destination .txt path (UTF-8).
        """
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write("MEDICAL TRANSCRIPT ANALYSIS REPORT\n")
            f.write("=" * 80 + "\n\n")
            
            # Basic Statistics
            unique_conversations = len(set([r.get('conversation_id', '') for r in results]))
            f.write(f"1. SUMMARY STATISTICS\n")
            f.write(f"   - Total Unique Conversations: {unique_conversations}\n")
            f.write(f"   - Total Documents Analyzed: {len(results)}\n")
            f.write(f"   - Model Used: {self.model_name}\n")
            f.write(f"   - Analysis Date: {datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n")
            
            # Most Common Questions (deduplicated; at most 10 listed)
            f.write(f"2. COMMON DOCTOR QUESTIONS\n")
            all_questions = []
            for result in results:
                questions = result.get('doctor_questions', [])
                if isinstance(questions, list):
                    all_questions.extend(questions)
            
            unique_questions = list(set(all_questions))
            for i, q in enumerate(unique_questions[:10], 1):
                f.write(f"   {i}. {q}\n")
            f.write("\n")
            
            # Entity Prevalence (top 15 by count)
            f.write(f"3. ENTITY PREVALENCE\n")
            entity_counts = {}
            for entity in self.all_entities:
                # Guard with isinstance: the LLM may emit a non-dict value,
                # which previously crashed here with AttributeError.
                count = sum(
                    1 for r in results
                    if isinstance(r.get(entity), dict) and r[entity].get('present', False)
                )
                if count > 0:
                    entity_counts[entity] = count
            
            sorted_entities = sorted(entity_counts.items(), key=lambda x: x[1], reverse=True)
            for entity, count in sorted_entities[:15]:
                percentage = (count / len(results)) * 100
                f.write(f"   - {entity}: {count}/{len(results)} ({percentage:.1f}%)\n")
            f.write("\n")
            
            # Sentiment Analysis — only for entities actually present;
            # absent entities previously defaulted to "neutral" and
            # inflated that bucket.
            f.write(f"4. SENTIMENT OVERVIEW\n")
            sentiment_counts = {'positive': 0, 'neutral': 0, 'negative': 0}
            for result in results:
                for entity in self.all_entities:
                    entity_data = result.get(entity, {})
                    if isinstance(entity_data, dict) and entity_data.get('present', False):
                        sentiment = entity_data.get('sentiment', 'neutral')
                        if sentiment in sentiment_counts:
                            sentiment_counts[sentiment] += 1
            
            total = sum(sentiment_counts.values())
            for sentiment, count in sentiment_counts.items():
                percentage = (count / total * 100) if total > 0 else 0
                f.write(f"   - {sentiment.capitalize()}: {count} ({percentage:.1f}%)\n")
            f.write("\n")
            
            f.write("=" * 80 + "\n")
            f.write("END OF REPORT\n")
            f.write("=" * 80 + "\n")
        
        print(f"✓ Analysis report saved to: {output_path}")


def main():
    """Interactive entry point: pick a model, analyze every PDF in the
    input directory, and write JSON/Excel/text outputs.

    Requires a running Ollama server with the chosen model pulled.
    """
    
    print("=" * 80)
    print("MEDICAL TRANSCRIPT ENTITY EXTRACTION SYSTEM")
    print("=" * 80)
    print()
    
    # Configuration — both paths resolved from the user's home directory so
    # the script is not tied to one machine's absolute layout.
    PDF_DIRECTORY = os.path.expanduser("~/DR-Transcripts-Claude/pdf_uploads")  # Change this to your PDF directory
    OUTPUT_DIR = os.path.expanduser("~/DR-Transcripts-Claude/outputs")
    
    # Create output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    # Choose model
    print("Available models:")
    print("1. llama3.1:8b (Llama 3.1 8B Instruct)")
    print("2. gemma:7b (Google Gemma 7B)")
    print()
    
    # Empty input defaults to option 1; any non-"1" answer selects gemma.
    model_choice = input("Select model (1 or 2, default=1): ").strip() or "1"
    model_name = "llama3.1:8b" if model_choice == "1" else "gemma:7b"
    
    print(f"\nUsing model: {model_name}")
    print("\nNote: Make sure Ollama is running with the selected model!")
    print(f"Run: ollama run {model_name}")
    print()
    
    input("Press Enter to continue...")
    
    # Initialize analyzer
    analyzer = MedicalTranscriptAnalyzer(model_name=model_name)
    
    # Analyze all transcripts
    print("\nStarting analysis...")
    results = analyzer.analyze_all_transcripts(PDF_DIRECTORY)
    
    if not results:
        print("No results generated. Check if PDFs exist and Ollama is running.")
        return
    
    # Timestamp shared by all three output files so they group together.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    # Save raw JSON
    json_output = f"{OUTPUT_DIR}/extracted_data_{timestamp}.json"
    with open(json_output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"✓ Raw JSON saved to: {json_output}")
    
    # Create Excel tabulation
    excel_output = f"{OUTPUT_DIR}/analysis_tabulation_{timestamp}.xlsx"
    analyzer.create_tabulated_output(results, excel_output)
    
    # Generate text report
    report_output = f"{OUTPUT_DIR}/analysis_report_{timestamp}.txt"
    analyzer.generate_analysis_report(results, report_output)
    
    print("\n" + "=" * 80)
    print("ANALYSIS COMPLETE!")
    print("=" * 80)
    print(f"\nOutput files:")
    print(f"1. {json_output}")
    print(f"2. {excel_output}")
    print(f"3. {report_output}")
    print()
