from __future__ import annotations
import os
import re
from dataclasses import dataclass
from typing import List, Tuple
import fitz  # pymupdf

@dataclass()
class DocText:
    """Text extracted from one PDF file, plus identifiers parsed from its filename."""

    path: str  # full path passed to the PDF reader
    filename: str  # bare filename as listed in the input directory
    conversation_id: str  # id shared by all parts of one conversation
    part_index: int  # 1-based part number (1 when the name has no _PartN)
    text: str  # extracted plain text of the whole file

def infer_conversation_id(filename: str) -> str:
    """Derive the conversation id shared by all parts of an exported PDF.

    Examples:
        dr_WS1217.pdf            -> WS1217
        dr_DA1010_Part2.pdf      -> DA1010
        dr_DA1010_Part2 (1).pdf  -> DA1010

    Args:
        filename: A bare filename or a path; only the basename is used.

    Returns:
        The id with the ``.pdf`` extension, a trailing `` (N)`` duplicate
        counter, a leading ``dr_`` prefix and a trailing ``_PartN`` suffix
        removed (all matched case-insensitively except the counter).
    """
    base = os.path.basename(filename)
    base = re.sub(r"\.pdf$", "", base, flags=re.I)
    # Drop the " (N)" duplicate-download counter *before* the _PartN suffix:
    # the "$"-anchored _PartN pattern cannot match "X_Part2 (1)" otherwise.
    base = re.sub(r"\s*\(\d+\)$", "", base)
    # Remove "dr_" only as a prefix; str.replace would also delete it from
    # the middle of an id and would miss mixed casings such as "Dr_".
    base = re.sub(r"^dr_", "", base, flags=re.I)
    base = re.sub(r"_Part\d+$", "", base, flags=re.I)
    return base.strip()

def infer_part_index(filename: str) -> int:
    """Return the part number encoded as ``_PartN`` in a PDF filename.

    Args:
        filename: A bare filename or a path; only the basename is inspected,
            so a ``_PartN`` in a directory name is not mistaken for a part.

    Returns:
        ``N`` from the first ``_PartN`` (case-insensitive) in the basename,
        or 1 when the name carries no explicit part number.
    """
    match = re.search(r"_Part(\d+)", os.path.basename(filename), flags=re.I)
    # No explicit part marker means this is a single-part (part 1) file.
    return int(match.group(1)) if match else 1

def extract_pdf_text(pdf_path: str) -> str:
    """Return the plain text of every page of a PDF, joined with newlines.

    Pages whose extracted text is empty are skipped, and the combined
    result has leading/trailing whitespace stripped.

    Args:
        pdf_path: Path of the PDF file to read.
    """
    # Use the document as a context manager so the underlying file handle is
    # released even if extraction fails part-way through; the original never
    # closed it, leaking one handle per PDF processed.
    with fitz.open(pdf_path) as doc:
        chunks = [text for text in (page.get_text("text") for page in doc) if text]
    return "\n".join(chunks).strip()

def load_folder(input_dir: str) -> List[DocText]:
    """Read every ``*.pdf`` entry of a directory into ``DocText`` records.

    Entries are processed in sorted filename order, and the extension check
    is case-insensitive. The conversation id and part index come from the
    filename; the text comes from the PDF itself.
    """
    pdf_names = [
        name for name in sorted(os.listdir(input_dir))
        if name.lower().endswith(".pdf")
    ]
    records: List[DocText] = []
    for name in pdf_names:
        full_path = os.path.join(input_dir, name)
        records.append(
            DocText(
                path=full_path,
                filename=name,
                conversation_id=infer_conversation_id(name),
                part_index=infer_part_index(name),
                text=extract_pdf_text(full_path),
            )
        )
    return records

def group_conversations(docs: List[DocText]) -> List[Tuple[str, List[DocText]]]:
    """Bucket documents by conversation id.

    Returns:
        ``(conversation_id, parts)`` pairs ordered by conversation id, where
        each ``parts`` list is ordered by ``part_index``. Documents with equal
        part indexes keep their input order.
    """
    buckets: dict = {}
    for doc in docs:
        conv_id = doc.conversation_id
        if conv_id not in buckets:
            buckets[conv_id] = []
        buckets[conv_id].append(doc)
    return [
        (conv_id, sorted(buckets[conv_id], key=lambda part: part.part_index))
        for conv_id in sorted(buckets)
    ]

def merge_conversation_text(parts: List[DocText]) -> str:
    """Concatenate the text of a conversation's parts into one string.

    Each part is preceded by a ``===== FILE: <name> (Part <n>) =====`` banner
    so a downstream model can tell the text spans several files. The result
    is stripped of leading/trailing whitespace; an empty list yields ``""``.
    """
    sections = (
        f"\n\n===== FILE: {part.filename} (Part {part.part_index}) =====\n\n{part.text}"
        for part in parts
    )
    return "\n".join(sections).strip()
