import re
from textblob import TextBlob
from pathlib import Path
from utils import setup_logging

# Module-level logger shared by every function in this file.
# (The previous duplicate `from pathlib import Path` and the import-time
# mkdir were removed: Path is already imported above, and the output
# directory is created inside preprocess_all_transcripts, so creating it
# as an import side effect was redundant.)
logger = setup_logging()

def clean_transcript(text):
    """Clean a raw interview transcript.

    Strips speaker labels, common filler words, and bracketed
    transcription annotations, then normalizes whitespace while
    preserving paragraph breaks (blank lines).

    Args:
        text: Raw transcript text.

    Returns:
        Cleaned text with leading/trailing whitespace stripped.
    """
    # Remove speaker labels at the start of lines (Interviewer:, Patient:, etc.)
    text = re.sub(r'^(Interviewer|Patient|Doctor|Provider):\s*', '', text, flags=re.MULTILINE)

    # Remove filler words/phrases as whole words, case-insensitively.
    # re.escape keeps multi-word phrases safe if the list ever gains
    # regex metacharacters.
    fillers = ['um', 'uh', 'like', 'you know', 'kind of', 'sort of']
    for filler in fillers:
        text = re.sub(rf'\b{re.escape(filler)}\b', '', text, flags=re.IGNORECASE)

    # Remove transcription annotations such as [inaudible], [crosstalk].
    text = re.sub(r'\[.*?\]', '', text)

    # Normalize whitespace. Blank-line runs are collapsed FIRST; doing
    # `\s+` -> ' ' first (as the original did) destroys every newline,
    # which made the blank-line rule dead code and lost paragraph breaks.
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ \t]+', ' ', text)

    # Optional: spelling correction via TextBlob (left disabled — risky
    # for medical terminology).
    # blob = TextBlob(text)
    # text = str(blob.correct())

    return text.strip()

def chunk_text(text, chunk_size=20, overlap=5):
    """Split text into overlapping chunks of sentences.

    Follows the paper's approach of chunking by sentence count.
    Sentences are assumed to be separated by '. ' (naive split).

    Args:
        text: Input text.
        chunk_size: Number of sentences per chunk.
        overlap: Number of sentences shared between consecutive chunks.
            Must be smaller than chunk_size.

    Returns:
        List of chunk strings.

    Raises:
        ValueError: If overlap >= chunk_size — the window could never
            advance (the original code let range() raise an opaque
            "arg 3 must not be zero" / produced a negative step).
    """
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    sentences = text.split('. ')
    step = chunk_size - overlap
    chunks = []

    for start in range(0, len(sentences), step):
        chunk = '. '.join(sentences[start:start + chunk_size])
        if chunk:
            chunks.append(chunk)
        # Stop once this chunk reached the end of the text; further
        # windows would contain only sentences already emitted as
        # overlap (pure-duplicate tail chunks).
        if start + chunk_size >= len(sentences):
            break

    return chunks

def preprocess_all_transcripts(input_dir, output_dir, chunk_size=20):
    """Clean and chunk every ``*.txt`` transcript in *input_dir*.

    Each transcript is cleaned, split into sentence chunks, and written
    to *output_dir* as ``<stem>_processed.json``.

    Args:
        input_dir: Directory containing raw ``*.txt`` transcripts.
        output_dir: Directory for processed JSON files (created if missing).
        chunk_size: Sentences per chunk (see ``chunk_text``).
    """
    # Imported here rather than only under the __main__ guard: the
    # original raised NameError on save_json whenever this module was
    # imported and the function called programmatically.
    from utils import save_json

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # sorted() makes processing order deterministic across filesystems.
    for txt_file in sorted(Path(input_dir).glob('*.txt')):
        logger.info(f"Preprocessing: {txt_file.name}")

        text = txt_file.read_text(encoding='utf-8')

        # Clean, then chunk.
        cleaned = clean_transcript(text)
        chunks = chunk_text(cleaned, chunk_size)

        # Persist alongside metadata for downstream steps.
        output_data = {
            'file_id': txt_file.stem,
            'full_text': cleaned,
            'chunks': chunks,
            'num_chunks': len(chunks),
        }
        output_file = Path(output_dir) / f"{txt_file.stem}_processed.json"
        save_json(output_data, output_file)

        logger.info(f"  -> {len(chunks)} chunks created")

if __name__ == "__main__":
    # Make the project root importable BEFORE project imports — the
    # original appended to sys.path only after `from utils import ...`.
    import sys
    sys.path.append('..')

    from utils import save_json  # module-global name used by preprocess_all_transcripts

    input_dir = "data/processed"
    output_dir = "data/processed/chunked"

    preprocess_all_transcripts(input_dir, output_dir)
