"""
generate_report.py — 2-page Word report using local Ollama (pure Python, no Node required)

Usage:
    python3 generate_report.py \
        --justifications ../outputs/analysis/detailed_justifications.xlsx \
        --hierarchy      ../outputs/analysis/hierarchy_table_output.xlsx \
        --output         report.docx \
        --model          llama3.3:latest
"""

import argparse, json, os, sys, re, urllib.request
import pandas as pd
from docx import Document
from docx.shared import Pt, RGBColor, Inches, Cm
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.table import WD_ALIGN_VERTICAL
from docx.oxml.ns import qn
from docx.oxml import OxmlElement

# Command-line interface. Every option has a default, so the script can be run
# without arguments from a directory that already contains the input workbook.
# ArgumentDefaultsHelpFormatter makes `--help` show those defaults.
parser = argparse.ArgumentParser(
    description="Generate a 2-page Word report from coded patient-transcript "
                "data, with the narrative written by a local Ollama model.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--justifications",
    default="detailed_justifications.xlsx",
    help="Excel workbook whose 'Justifications' sheet supplies the statistics",
)
# NOTE(review): --hierarchy is parsed but no visible code reads args.hierarchy;
# it is kept so existing command lines keep working — confirm whether it is
# still needed.
parser.add_argument(
    "--hierarchy",
    default="hierarchy_table_output.xlsx",
    help="Path to the hierarchy table workbook",
)
parser.add_argument(
    "--output",
    default="report.docx",
    help="Path of the Word document to write",
)
parser.add_argument(
    "--model",
    default="llama3.2",
    help="Name of the Ollama model that writes the narrative",
)
parser.add_argument(
    "--ollama-url",
    default="http://localhost:11434",
    help="Base URL of the Ollama server",
)
args = parser.parse_args()

# ── Stats ─────────────────────────────────────────────────────────────────────
# Load the coded evidence rows; every figure below is derived from this sheet.
df = pd.read_excel(args.justifications, sheet_name="Justifications")

# The code column is looked up under its long header first, then the short one.
code_col = "code"
if "code(outcome entities)" in df.columns:
    code_col = "code(outcome entities)"


def _distinct(frame, column):
    """Return ``frame[column].nunique()`` as a plain ``int`` (JSON-serialisable)."""
    return int(frame[column].nunique())


def _group_summary(label, rows):
    """Build the per-group counts dict for the rows belonging to ``label``."""
    return {
        "name":     str(label),
        "themes":   _distinct(rows, "theme"),
        "codes":    _distinct(rows, code_col),
        "patients": _distinct(rows, "patient_id"),
        "records":  int(len(rows)),
    }


# One summary per non-null group value, in order of first appearance.
groups_data = [
    _group_summary(label, df[df["group"] == label])
    for label in df["group"].dropna().unique()
]

# Dataset-wide totals. Key order is kept stable because this dict is dumped
# verbatim into the LLM prompt.
quote_flags = df["raw_quote_found"]
stats = {
    "total_records":    int(len(df)),
    "total_patients":   _distinct(df, "patient_id"),
    "total_files":      _distinct(df, "original_file_id"),
    "total_codes":      _distinct(df, code_col),
    "total_themes":     _distinct(df, "theme"),
    "total_groups":     _distinct(df, "group"),
    "quotes_found":     int(quote_flags.eq("YES").sum()),
    "quotes_not_found": int(quote_flags.eq("NO").sum()),
    "groups":           groups_data,
}

# Plain-English meaning of each spreadsheet column, passed to the LLM as context.
col_descriptions = dict([
    ("patient_id",           "Unique identifier for each patient in the study"),
    ("original_file_id",     "Identifier of the source transcript file the data was extracted from"),
    ("chunk_id",             "Sequential chunk number within a transcript used during text processing"),
    (code_col,               "The clinical outcome entity or code assigned to a piece of evidence"),
    ("theme",                "The broader thematic category that groups related codes together"),
    ("group",                "The highest-level classification — either a Barrier type or Facilitator"),
    ("quote",                "The direct or paraphrased quote extracted from the patient transcript"),
    ("raw_quote_found",      "Whether the exact quote was found verbatim in the source text (YES/NO)"),
    ("raw_evidence_snippet", "The raw text snippet from the transcript that supports the code"),
    ("excel_context",        "Additional contextual notes from the spreadsheet for this record"),
    ("llm_justification",    "The AI-generated justification explaining why this code was assigned"),
])

# ── Call Ollama ───────────────────────────────────────────────────────────────
def call_ollama(base_url, model, prompt):
    """Send one non-streaming generation request to Ollama and return its text.

    Args:
        base_url: Root URL of the Ollama server, e.g. ``http://localhost:11434``.
            A trailing slash is tolerated.
        model: Name of the model to run.
        prompt: Full prompt text.

    Returns:
        The value of the ``response`` field of the JSON reply.

    Exits:
        Prints a diagnostic and calls ``sys.exit(1)`` when the server cannot
        be reached, answers with an HTTP error status, the connection fails
        while the reply is being read, or the reply is not JSON containing a
        ``response`` field.
    """
    # Imported explicitly so the handlers below do not rely on urllib.request
    # exposing urllib.error as a side effect of its own imports.
    import urllib.error

    # Strip a trailing slash so ".../" does not produce "//api/generate".
    url = f"{base_url.rstrip('/')}/api/generate"
    payload = json.dumps({"model": model, "prompt": prompt, "stream": False}).encode()
    req = urllib.request.Request(url, data=payload, headers={"Content-Type": "application/json"})
    try:
        with urllib.request.urlopen(req, timeout=300) as r:
            raw = r.read()
    except urllib.error.HTTPError as e:
        # HTTPError subclasses URLError, so it must be handled first: here the
        # server WAS reached but rejected the request.
        try:
            detail = e.read().decode("utf-8", errors="replace").strip()
        except (AttributeError, OSError, ValueError):
            detail = ""
        print(f"\nERROR: Ollama at {base_url} answered with HTTP {e.code} ({e.reason})")
        print(f"If the model is not installed, try: ollama pull {model}")
        if detail:
            print(f"Details: {detail}")
        sys.exit(1)
    except urllib.error.URLError as e:
        print(f"\nERROR: Cannot connect to Ollama at {base_url}")
        print("Make sure Ollama is running: ollama serve")
        print(f"Details: {e}")
        sys.exit(1)
    except OSError as e:
        # Timeouts and resets raised while reading the body are plain OSErrors,
        # not URLErrors.
        print(f"\nERROR: Connection to Ollama at {base_url} failed while waiting for the response")
        print(f"Details: {e}")
        sys.exit(1)

    try:
        reply = json.loads(raw)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError both derive from ValueError
        print(f"\nERROR: Ollama at {base_url} returned a body that is not valid JSON")
        print(f"Details: {e}")
        sys.exit(1)

    if not isinstance(reply, dict) or "response" not in reply:
        print(f"\nERROR: Ollama reply from {base_url} has no 'response' field")
        print(f"Details: {str(reply)[:500]}")
        sys.exit(1)
    return reply["response"]

# Serialise both context blobs up front so the template below reads as text.
stats_json = json.dumps(stats, indent=2)
columns_json = json.dumps(col_descriptions, indent=2)

# Instruction prompt: the fixed heading list here must stay in step with the
# titles the section parser looks for.
prompt = f"""You are writing a clear, plain-English 2-page research report for a general audience.
The data comes from a qualitative analysis of patient transcripts about healthcare barriers and facilitators.

Dataset statistics:
{stats_json}

Column descriptions:
{columns_json}

Write the report with EXACTLY these 5 section headings on their own line:
1. OVERVIEW
2. DATA STRUCTURE
3. KEY FINDINGS
4. GROUP BREAKDOWN
5. CONCLUSION

Rules:
- Simple, clear, accessible language. No jargon.
- Flowing paragraphs only, no bullet points.
- Use the exact statistics provided.
- Total length: approximately 600-700 words.
- Do not add commentary before or after the report."""

print(f"Calling Ollama ({args.model}) at {args.ollama_url} ...")
narrative = call_ollama(base_url=args.ollama_url, model=args.model, prompt=prompt)
print("LLM done. Building Word document...")

# ── Parse sections ────────────────────────────────────────────────────────────
# Section headings the LLM is asked to produce, in document order.
section_titles = ["OVERVIEW", "DATA STRUCTURE", "KEY FINDINGS", "GROUP BREAKDOWN", "CONCLUSION"]

def extract_sections(text, titles=None):
    """Split the LLM narrative into ``{TITLE: body}`` using its heading lines.

    A heading is a line that consists only of one of ``titles``
    (case-insensitive), optionally preceded by Markdown ``#`` marks, a number
    such as ``1.`` or ``1)``, and optionally wrapped in ``*``/``_`` emphasis
    or followed by a colon. So ``1. OVERVIEW``, ``**1. OVERVIEW**``,
    ``## Overview`` and ``OVERVIEW:`` are all recognised.

    Args:
        text: Raw narrative returned by the model.
        titles: Heading names to look for; defaults to ``section_titles``.

    Returns:
        Dict mapping the upper-cased title to its stripped body text. Text
        before the first heading is ignored. A heading with no body is left
        out, so callers can fall back to a placeholder. If a title occurs
        more than once, the last non-empty body wins.
    """
    if titles is None:
        titles = section_titles
    if not text or not titles:
        return {}

    alternatives = "|".join(re.escape(t) for t in titles)
    heading_re = re.compile(
        r"(?:#+[ \t]*)?[*_]*(?:\d+[.)][ \t]*)?[*_]*(" + alternatives + r")[*_]*:?[*_]*",
        re.IGNORECASE,
    )

    sections = {}

    def store(title, lines):
        # Keep only sections that actually contain text.
        body = "\n".join(lines).strip()
        if title is not None and body:
            sections[title] = body

    current, collected = None, []
    # Normalise CRLF/CR so bodies contain plain "\n" and paragraph splitting
    # on blank lines works downstream.
    for line in text.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
        found = heading_re.fullmatch(line.strip())
        if found:
            store(current, collected)
            current, collected = found.group(1).upper(), []
        elif current is not None:
            collected.append(line)
    store(current, collected)
    return sections

# Upper-cased section title -> body text parsed from the LLM narrative. Titles
# the parser did not find are simply absent; the document builder looks them up
# with .get() and substitutes a placeholder.
sections = extract_sections(narrative)

# ── Helpers ───────────────────────────────────────────────────────────────────
def set_cell_bg(cell, hex_color):
    """Set the shading fill colour of a table cell.

    Appends a ``w:shd`` element (val ``clear``, color ``auto``, fill
    ``hex_color``) to the cell's ``w:tcPr`` properties element, creating the
    latter if needed.
    """
    shading = OxmlElement("w:shd")
    for attr, value in (("w:val", "clear"), ("w:color", "auto"), ("w:fill", hex_color)):
        shading.set(qn(attr), value)
    cell._tc.get_or_add_tcPr().append(shading)

def set_cell_border(cell):
    """Draw a thin light-grey single-line border on all four sides of a cell.

    Appends a ``w:tcBorders`` element holding ``top``, ``left``, ``bottom``
    and ``right`` children (val ``single``, sz ``4``, color ``CCCCCC``) to the
    cell's ``w:tcPr`` properties element.
    """
    edge_attrs = {"w:val": "single", "w:sz": "4", "w:color": "CCCCCC"}
    borders = OxmlElement("w:tcBorders")
    for edge in ("top", "left", "bottom", "right"):
        line = OxmlElement(f"w:{edge}")
        for attr, value in edge_attrs.items():
            line.set(qn(attr), value)
        borders.append(line)
    cell._tc.get_or_add_tcPr().append(borders)

def add_heading(doc, text, color_hex="1F4E79", size=14):
    """Append a bold, coloured Arial heading with a thin rule underneath.

    Args:
        doc: Document to append the paragraph to.
        text: Heading text.
        color_hex: Six hex digits ``RRGGBB`` for the text colour.
        size: Font size in points.

    Returns:
        The newly added paragraph.
    """
    heading = doc.add_paragraph()
    spacing = heading.paragraph_format
    spacing.space_before = Pt(14)
    spacing.space_after = Pt(6)

    label = heading.add_run(text)
    label.bold = True
    font = label.font
    font.size = Pt(size)
    font.name = "Arial"
    # Decode the RR, GG and BB hex pairs into the three colour channels.
    font.color.rgb = RGBColor(*(int(color_hex[k:k + 2], 16) for k in range(0, 6, 2)))

    # Underline rule: a bottom paragraph border appended to the paragraph properties.
    rule = OxmlElement("w:bottom")
    for attr, value in (("w:val", "single"), ("w:sz", "4"), ("w:color", "D0D8E4")):
        rule.set(qn(attr), value)
    borders = OxmlElement("w:pBdr")
    borders.append(rule)
    heading._p.get_or_add_pPr().append(borders)
    return heading

def add_body(doc, text):
    """Append ``text`` to ``doc`` as 11pt Arial paragraphs.

    The text is split on blank lines (``"\\n\\n"``); each non-empty, stripped
    chunk becomes one paragraph with 6pt spacing after it.
    """
    stripped_chunks = (chunk.strip() for chunk in text.split("\n\n"))
    for chunk in filter(None, stripped_chunks):
        paragraph = doc.add_paragraph(chunk)
        paragraph.paragraph_format.space_after = Pt(6)
        for run in paragraph.runs:
            font = run.font
            font.size = Pt(11)
            font.name = "Arial"

# ── Build document ────────────────────────────────────────────────────────────
doc = Document()

# Page setup: 8.5" x 11" pages with a uniform 0.75" margin on every side.
page_size = (("page_width", Inches(8.5)), ("page_height", Inches(11)))
margin_attrs = ("top_margin", "bottom_margin", "left_margin", "right_margin")
for page_setup in doc.sections:
    for attr, length in page_size:
        setattr(page_setup, attr, length)
    for attr in margin_attrs:
        setattr(page_setup, attr, Inches(0.75))

# Base font applied through the "Normal" style.
normal_font = doc.styles["Normal"].font
normal_font.name = "Arial"
normal_font.size = Pt(11)

# ── Title block ───────────────────────────────────────────────────────────────
def _centered_line(text, size, rgb, space_after, bold=False, italic=False):
    """Append one centred, single-run Arial paragraph to ``doc`` and return it.

    ``bold``/``italic`` are only written when requested, so an unset flag
    stays inherited instead of being forced off.
    """
    para = doc.add_paragraph()
    para.alignment = WD_ALIGN_PARAGRAPH.CENTER
    para.paragraph_format.space_after = Pt(space_after)
    run = para.add_run(text)
    if bold:
        run.bold = True
    if italic:
        run.italic = True
    run.font.size = Pt(size)
    run.font.name = "Arial"
    run.font.color.rgb = RGBColor(*rgb)
    return para


# Main title, subtitle, then a small provenance line naming the model used.
title = _centered_line(
    "Patient Transcript Analysis Report",
    size=20, rgb=(0x1F, 0x4E, 0x79), space_after=4, bold=True,
)
sub = _centered_line(
    "Barriers & Facilitators — Deductive Coding Study",
    size=12, rgb=(0x55, 0x55, 0x55), space_after=2,
)
gen = _centered_line(
    f"AI-Assisted Analysis Pipeline  |  Generated via Ollama ({args.model})",
    size=9, rgb=(0x88, 0x88, 0x88), space_after=14, italic=True,
)

# ── Stats banner table ────────────────────────────────────────────────────────
# Headline figures are taken from the computed `stats`, so the banner always
# matches the workbook that was loaded (and the numbers handed to the LLM)
# instead of a fixed set of literals.
banner_data = [
    (f"{stats['total_patients']:,}", "Patients"),
    (f"{stats['total_files']:,}",    "Source Files"),
    (f"{stats['total_groups']:,}",   "Groups"),
    (f"{stats['total_themes']:,}",   "Themes"),
    (f"{stats['total_codes']:,}",    "Unique Codes"),
    (f"{stats['total_records']:,}",  "Total Records"),
]
tbl = doc.add_table(rows=1, cols=len(banner_data))
tbl.style = "Table Grid"
tbl.autofit = False
col_w = Inches(1.167)  # six columns of this width span the 7" text area
for i, (val, lbl) in enumerate(banner_data):
    cell = tbl.rows[0].cells[i]
    cell.width = col_w
    set_cell_bg(cell, "1F4E79")
    set_cell_border(cell)
    cell.vertical_alignment = WD_ALIGN_VERTICAL.CENTER
    # Big white number: reuse the cell's default (empty) paragraph rather than
    # adding a new one and deleting the empty one afterwards.
    vp = cell.paragraphs[0]
    vp.alignment = WD_ALIGN_PARAGRAPH.CENTER
    vr = vp.add_run(val)
    vr.bold = True; vr.font.size = Pt(14); vr.font.name = "Arial"
    vr.font.color.rgb = RGBColor(0xFF, 0xFF, 0xFF)
    # Small pale-blue caption underneath.
    lp = cell.add_paragraph(lbl)
    lp.alignment = WD_ALIGN_PARAGRAPH.CENTER
    lr = lp.runs[0]; lr.font.size = Pt(8); lr.font.name = "Arial"
    lr.font.color.rgb = RGBColor(0xBB, 0xCC, 0xDD)

# Blank spacer paragraph between the banner and the narrative.
doc.add_paragraph()

# ── Narrative sections ────────────────────────────────────────────────────────
# Index of the section that opens page two (GROUP BREAKDOWN in section_titles).
page_break_before = 3
for position, heading in enumerate(section_titles):
    if position == page_break_before:
        doc.add_page_break()
    add_heading(doc, heading)
    # Fall back to a visible placeholder when the LLM omitted this section.
    add_body(doc, sections.get(heading, f"[{heading} content not generated by LLM]"))

# ── Summary table ─────────────────────────────────────────────────────────────
doc.add_paragraph()
add_heading(doc, "Summary by Group", size=12)

# One header row plus one row per group; the first column is wide for the name.
headers = ["Group", "Themes", "Codes", "Patients", "Records"]
col_widths = [Inches(2.5)] + [Inches(0.9)] * 4
grp_tbl = doc.add_table(rows=len(groups_data) + 1, cols=5)
grp_tbl.style = "Table Grid"
grp_tbl.autofit = False


def _write_cell(cell, text, width, fill, align, header=False):
    """Size, shade, border and fill one summary-table cell with 10pt Arial text.

    Header cells additionally get bold white text.
    """
    cell.width = width
    set_cell_bg(cell, fill)
    set_cell_border(cell)
    para = cell.paragraphs[0]
    para.alignment = align
    run = para.add_run(text)
    if header:
        run.bold = True
    run.font.size = Pt(10)
    run.font.name = "Arial"
    if header:
        run.font.color.rgb = RGBColor(0xFF, 0xFF, 0xFF)


# Header row: dark-blue cells, centred captions.
header_cells = grp_tbl.rows[0].cells
for col, (caption, width) in enumerate(zip(headers, col_widths)):
    _write_cell(header_cells[col], caption, width, "1F4E79",
                WD_ALIGN_PARAGRAPH.CENTER, header=True)

# Body rows: alternating light-blue / white stripes, name left-aligned and
# the numeric columns centred.
count_keys = ("themes", "codes", "patients", "records")
for table_row, group in enumerate(groups_data, start=1):
    stripe = "EBF3FB" if table_row % 2 == 1 else "FFFFFF"
    texts = [group["name"]] + [str(group[key]) for key in count_keys]
    row_cells = grp_tbl.rows[table_row].cells
    for col, (text, width) in enumerate(zip(texts, col_widths)):
        align = WD_ALIGN_PARAGRAPH.LEFT if col == 0 else WD_ALIGN_PARAGRAPH.CENTER
        _write_cell(row_cells[col], text, width, stripe, align)

doc.save(args.output)
print(f"Done -> {args.output}")