#!/usr/bin/env python3
"""
SYSTEM TEST SCRIPT
Tests all components before running full analysis
"""

import os
import sys
from pathlib import Path

def print_header(text):
    """Print *text* as a section banner framed by two 60-character rules.

    A blank line is emitted before the top rule, and the text is indented
    by two spaces.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"  {text}")
    print(rule)

def test_python_packages():
    """Check that every required Python package can be imported.

    Prints one status line per package.

    Returns:
        bool: True when all packages import, False if any raised ImportError.
    """
    print_header("TEST 1: Python Packages")

    # Display name -> module name handed to __import__.
    modules_by_package = {
        'PyPDF2': 'PyPDF2',
        'pandas': 'pandas',
        'openpyxl': 'openpyxl',
        'requests': 'requests',
        'json': 'json',
        're': 're'
    }

    missing = []
    for label, module_name in modules_by_package.items():
        try:
            __import__(module_name)
        except ImportError:
            print(f"  ✗ {label} - NOT INSTALLED")
            missing.append(label)
        else:
            print(f"  ✓ {label}")

    return not missing

def test_ollama_connection():
    """Test the connection to the Ollama server and list its models.

    Queries ``http://localhost:11434/api/tags`` with a 3-second timeout,
    prints every reported model with its size in GB, and prints a pull hint
    when no model name contains ``llama3.1`` or ``gemma``.

    Returns:
        bool: True when the server answers HTTP 200 and reports at least one
        model. False when ``requests`` is not installed, the connection
        fails, the server returns another status, no models are listed, or
        any other error occurs.
    """
    print_header("TEST 2: Ollama Connection")

    # Import outside the request try-block: the `except requests...` clause
    # below needs the name bound, otherwise a missing package would surface
    # as an UnboundLocalError instead of a reported test failure.
    try:
        import requests
    except ImportError:
        print("  ✗ Error: the 'requests' package is not installed")
        print("    Solution: pip install requests")
        return False

    try:
        response = requests.get('http://localhost:11434/api/tags', timeout=3)

        if response.status_code != 200:
            print(f"  ✗ Ollama returned error: {response.status_code}")
            return False

        print("  ✓ Ollama server is running")

        models = response.json().get('models', [])
        if not models:
            print("  ⚠ No models found")
            print("    Run: ollama pull llama3.1:8b")
            return False

        print("\n  Available models:")
        for model in models:
            size_gb = model.get('size', 0) / (1024**3)
            print(f"    - {model['name']} ({size_gb:.1f} GB)")

        # Check for the models this project uses; a missing one is only a
        # warning, it does not fail the test.
        model_names = [m['name'] for m in models]
        if any('llama3.1' in name for name in model_names):
            print("\n  ✓ Llama 3.1 detected")
        else:
            print("\n  ⚠ Llama 3.1 not found")
            print("    Run: ollama pull llama3.1:8b")

        if any('gemma' in name for name in model_names):
            print("  ✓ Gemma detected")
        else:
            print("  ⚠ Gemma not found")
            print("    Run: ollama pull gemma:7b")

        return True

    except requests.exceptions.ConnectionError:
        print("  ✗ Cannot connect to Ollama")
        print("    Solution: Run 'ollama serve' in another terminal")
        return False
    except Exception as e:
        # Covers timeouts, malformed JSON and unexpected payload shapes.
        print(f"  ✗ Error: {e}")
        return False

def test_llm_simple_query():
    """Send a trivial prompt to ``llama3.1:8b`` and report the reply.

    Posts a non-streaming request to ``http://localhost:11434/api/generate``
    with a 30-second timeout and prints the first 50 characters of the
    ``response`` field.

    Returns:
        bool: True when the server answers HTTP 200; False on any other
        status or on any exception (including a missing ``requests``).
    """
    print_header("TEST 3: LLM Simple Query")

    try:
        import requests

        print("  Sending test query to llama3.1:8b...")

        response = requests.post(
            'http://localhost:11434/api/generate',
            json={
                'model': 'llama3.1:8b',
                'prompt': 'Say "Hello" and nothing else.',
                'stream': False,
                # Sampling parameters belong in the nested `options` object;
                # a top-level `temperature` key is not applied by the API.
                'options': {'temperature': 0.1},
            },
            timeout=30
        )

        if response.status_code == 200:
            result = response.json()
            llm_response = result.get('response', '')
            print(f"  ✓ LLM responded: {llm_response[:50]}...")
            return True
        else:
            print(f"  ✗ LLM error: {response.status_code}")
            return False

    except Exception as e:
        print(f"  ✗ Error: {e}")
        return False

def test_pdf_directory():
    """Check that the upload directory exists and contains ``*.pdf`` files.

    Lists up to five file names and a count of any further files.

    Returns:
        bool: True when at least one PDF is found; False when the directory
        is missing or holds no PDFs.
    """
    print_header("TEST 4: PDF Files")

    pdf_dir = "/mnt/user-data/uploads"

    if not os.path.exists(pdf_dir):
        print(f"  ✗ Directory not found: {pdf_dir}")
        return False

    found = list(Path(pdf_dir).glob("*.pdf"))

    if not found:
        print(f"  ⚠ No PDF files found in {pdf_dir}")
        print("    Please upload your PDF files first")
        return False

    total = len(found)
    print(f"  ✓ Found {total} PDF files")
    print("\n  Sample files:")
    for entry in found[:5]:
        print(f"    - {entry.name}")

    # Only mention the overflow when more than five files were found.
    remaining = total - 5
    if remaining > 0:
        print(f"    ... and {remaining} more")
    return True

def test_pdf_extraction():
    """Extract text from the first PDF in the upload directory.

    Prints the page count, the extracted word count and the first 100
    characters of the text.

    Returns:
        bool: True when text was extracted, or when there is no PDF to test
        (the check is skipped). False when no text came out or any
        exception was raised.
    """
    print_header("TEST 5: PDF Text Extraction")

    candidates = list(Path("/mnt/user-data/uploads").glob("*.pdf"))

    if not candidates:
        # Nothing to test is not counted as a failure.
        print("  ⚠ Skipped - no PDF files available")
        return True

    sample_pdf = candidates[0]

    try:
        import PyPDF2

        print(f"  Testing: {sample_pdf.name}")

        with open(sample_pdf, 'rb') as handle:
            reader = PyPDF2.PdfReader(handle)

            print(f"  Pages: {len(reader.pages)}")

            extracted = ""
            for page in reader.pages:
                extracted += page.extract_text()

            if not extracted.strip():
                print("  ⚠ No text extracted (might be scanned PDF)")
                print("    You may need OCR for this file")
                return False

            print(f"  ✓ Extracted {len(extracted.split())} words")
            print(f"  Sample: {extracted[:100]}...")
            return True

    except Exception as e:
        print(f"  ✗ Error: {e}")
        return False

def test_output_directory():
    """Check that the output directory can be created, written and read.

    Creates ``/mnt/user-data/outputs`` if needed, writes a small probe file
    named ``.test_write``, reads it back, verifies the content and removes
    the probe again.

    Returns:
        bool: True when the round trip succeeds and the content matches;
        False on a mismatch or on any exception.
    """
    print_header("TEST 6: Output Directory")

    output_dir = "/mnt/user-data/outputs"
    probe_path = os.path.join(output_dir, ".test_write")
    payload = "test"

    try:
        os.makedirs(output_dir, exist_ok=True)

        try:
            # Test write
            with open(probe_path, 'w') as f:
                f.write(payload)

            # Test read
            with open(probe_path, 'r') as f:
                content = f.read()
        finally:
            # Clean up even when the write or read failed part-way, so a
            # failed run does not leave the probe file behind.
            if os.path.exists(probe_path):
                os.remove(probe_path)

        if content != payload:
            print(f"  ✗ Cannot write to {output_dir}: read-back content mismatch")
            return False

        print(f"  ✓ Output directory is writable: {output_dir}")
        return True

    except Exception as e:
        print(f"  ✗ Cannot write to {output_dir}: {e}")
        return False

def test_analyzer_import():
    """Import ``MedicalTranscriptAnalyzer`` and construct one instance.

    The instance is created with ``model_name="llama3.1:8b"`` and the length
    of its ``all_entities`` attribute is printed.

    Returns:
        bool: True when import and construction succeed; False on any
        exception.
    """
    print_header("TEST 7: Analyzer Import")

    try:
        from medical_transcript_analyzer import MedicalTranscriptAnalyzer
        print("  ✓ MedicalTranscriptAnalyzer imported successfully")

        # Test initialization
        instance = MedicalTranscriptAnalyzer(model_name="llama3.1:8b")
        print("  ✓ Analyzer initialized")
        print(f"  ✓ Entities defined: {len(instance.all_entities)}")
    except Exception as exc:
        # Import failures get their own label; everything else is generic.
        label = "Import error" if isinstance(exc, ImportError) else "Error"
        print(f"  ✗ {label}: {exc}")
        return False
    else:
        return True

def run_mini_analysis():
    """Optionally run the analyzer on one PDF as an end-to-end test.

    Asks for confirmation on stdin (default "n"). When confirmed, analyzes
    the first PDF in ``/mnt/user-data/uploads`` with ``llama3.1:8b``, prints
    a short summary and saves the full result to ``/tmp/test_result.json``.

    Returns:
        bool: True when the step is skipped or the analysis returns a truthy
        result; False when there is no PDF, the analysis returns nothing,
        or any exception is raised (the traceback is printed).
    """
    print_header("TEST 8: Mini Analysis (Optional)")

    try:
        answer = input("\n  Run analysis on one PDF to test? (y/n) [n]: ").strip().lower()
    except EOFError:
        # Non-interactive run (closed or piped stdin): fall back to the
        # default answer instead of aborting the whole test script.
        print()
        answer = ''

    if answer != 'y':
        print("  Skipped")
        return True

    try:
        from medical_transcript_analyzer import MedicalTranscriptAnalyzer
        import json

        pdf_dir = "/mnt/user-data/uploads"
        pdf_files = list(Path(pdf_dir).glob("*.pdf"))

        if not pdf_files:
            print("  ✗ No PDF files to test")
            return False

        test_pdf = pdf_files[0]
        print(f"\n  Analyzing: {test_pdf.name}")
        print("  (This may take 1-2 minutes)...")

        analyzer = MedicalTranscriptAnalyzer(model_name="llama3.1:8b")
        result = analyzer.analyze_transcript(str(test_pdf))

        if not result:
            print("  ✗ Analysis failed")
            return False

        print("\n  ✓ Analysis completed!")
        print(f"  Conversation ID: {result.get('conversation_id')}")
        print(f"  Text length: {result.get('raw_text_length')} characters")

        # Entities are the dict-valued entries flagged as present.
        present = [
            (entity, data) for entity, data in result.items()
            if isinstance(data, dict) and data.get('present', False)
        ]

        print(f"  Entities found: {len(present)}")

        # Show sample extraction (first present entity only)
        print("\n  Sample extraction:")
        if present:
            entity, data = present[0]
            # `details` may be missing, None or a non-string value; coerce
            # before slicing so a successful analysis is not reported as failed.
            details = str(data.get('details') or '')
            print(f"    - {entity}")
            print(f"      Sentiment: {data.get('sentiment')}")
            print(f"      Details: {details[:60]}...")

        # Save to temp file
        temp_output = "/tmp/test_result.json"
        with open(temp_output, 'w') as f:
            json.dump(result, f, indent=2)
        print(f"\n  Full result saved to: {temp_output}")

        return True

    except Exception as e:
        print(f"  ✗ Error: {e}")
        import traceback
        traceback.print_exc()
        return False

def main():
    """Run all system tests and print a pass/fail summary.

    Runs the seven component tests in order, then the optional mini
    analysis (its result is not counted in the summary), and finally prints
    the per-test status with remediation hints for every failed test.

    Returns:
        int: 0 when all seven component tests passed, 1 otherwise.
    """
    print("\n" + "=" * 60)
    print("  MEDICAL TRANSCRIPT ANALYZER - SYSTEM TEST")
    print("=" * 60)
    print("\nThis will verify all components are working correctly")
    print("before running the full analysis.\n")

    tests = [
        ("Python Packages", test_python_packages),
        ("Ollama Connection", test_ollama_connection),
        ("LLM Query", test_llm_simple_query),
        ("PDF Files", test_pdf_directory),
        ("PDF Extraction", test_pdf_extraction),
        ("Output Directory", test_output_directory),
        ("Analyzer Import", test_analyzer_import),
    ]

    results = {test_name: test_func() for test_name, test_func in tests}

    # Run optional mini analysis; it is informational and not part of the
    # pass/fail tally.
    run_mini_analysis()

    # Summary
    print_header("TEST SUMMARY")

    passed = sum(1 for ok in results.values() if ok)
    total = len(results)

    print(f"\n  Tests passed: {passed}/{total}\n")

    for test_name, ok in results.items():
        status = "✓ PASS" if ok else "✗ FAIL"
        print(f"  {status}  {test_name}")

    print()

    if all(results.values()):
        print("=" * 60)
        print("  ✓ ALL TESTS PASSED!")
        print("  System is ready for full analysis")
        print("=" * 60)
        print("\n  Run full analysis with:")
        print("    python3 run_analysis.py")
        print("  or")
        print("    ./quick_start.sh")
        print()
        return 0

    print("=" * 60)
    print("  ⚠ SOME TESTS FAILED")
    print("  Please fix the issues before running analysis")
    print("=" * 60)
    print("\n  Check:")

    # One hint per test so the checklist is never empty when something
    # failed.
    hints = [
        ("Python Packages",
         "Install packages: pip install PyPDF2 pandas openpyxl requests"),
        ("Ollama Connection", "Start Ollama: ollama serve"),
        ("LLM Query", "Download model: ollama pull llama3.1:8b"),
        ("PDF Files", "Upload PDF files to /mnt/user-data/uploads"),
        ("PDF Extraction",
         "Check that the PDFs contain extractable text (scanned PDFs need OCR)"),
        ("Output Directory", "Make sure /mnt/user-data/outputs is writable"),
        ("Analyzer Import",
         "Make sure the medical_transcript_analyzer module is importable"),
    ]
    for test_name, hint in hints:
        if not results.get(test_name):
            print(f"    - {hint}")

    print("\n  See TROUBLESHOOTING.md for detailed solutions")
    print()
    return 1

if __name__ == "__main__":
    # Script entry point: exit with main()'s status, or 1 on Ctrl-C.
    try:
        exit_code = main()
    except KeyboardInterrupt:
        print("\n\nTest interrupted by user.")
        exit_code = 1
    sys.exit(exit_code)
