"""
Group-level confusion matrix — count + row% + column% in every cell.
Just run this file; no input files are needed. The figure is saved to
confusion_matrix_group_with_pct.png and then shown on screen.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# ── Counts from your existing confusion matrix (row = original, col = LLM) ──
counts = np.array([
    [39,  0,  1,  6],   # barriers - beliefs & perceptions
    [ 0,  7,  5,  1],   # barriers - financial & access
    [ 1,  1, 31,  1],   # barriers - healthcare system
    [11,  2,  7, 87],   # facilitators
])

# Category names, in the same order as the rows/columns of `counts`.
# The embedded "\n" break long names over several lines in the tick labels.
labels = [
    "barriers -\nbeliefs &\nperceptions",
    "barriers -\nfinancial &\naccess",
    "barriers -\nhealthcare\nsystem",
    "facilitators",
]

# Fail early if the matrix and the labels disagree; otherwise the plot would
# leave rows/columns unannotated or fail with an IndexError part-way through.
if counts.shape != (len(labels), len(labels)):
    raise ValueError(
        f"counts has shape {counts.shape}, expected "
        f"({len(labels)}, {len(labels)}) to match labels"
    )


# ── Row % and Column % ──────────────────────────────────────────────────────
def percent_of_total(table, axis):
    """Return *table* expressed as percentages of its totals along *axis*.

    ``axis=1`` gives row percentages (each row sums to 100); ``axis=0`` gives
    column percentages (each column sums to 100). A row/column whose total is
    0 gets 0.0 in every cell instead of NaN, so annotations never read "nan%"
    and no divide-by-zero warning is emitted.
    """
    table = np.asarray(table, dtype=float)
    totals = table.sum(axis=axis, keepdims=True)
    fractions = np.divide(
        table, totals, out=np.zeros_like(table), where=totals != 0
    )
    return fractions * 100


row_pct = percent_of_total(counts, axis=1)
col_pct = percent_of_total(counts, axis=0)

# ── Plot ────────────────────────────────────────────────────────────────────
# Output file for the rendered figure.
OUTPUT_PATH = "confusion_matrix_group_with_pct.png"

fig, ax = plt.subplots(figsize=(13, 10))

cmap = LinearSegmentedColormap.from_list(
    "blue_custom", ["#f0f4f8", "#bcd4e6", "#6aaed6", "#2171b5", "#08306b"]
)
# Shade cells by raw count on a 0..max scale shared by the whole matrix.
im = ax.imshow(counts, cmap=cmap, aspect="auto", vmin=0, vmax=counts.max())

n = len(labels)

# White grid lines between cells. imshow centres cell (i, j) on integer
# coordinates, so the cell boundaries sit at the half-integers -0.5 .. n-0.5.
for i in range(n + 1):
    ax.axhline(i - 0.5, color="white", linewidth=2)
    ax.axvline(i - 0.5, color="white", linewidth=2)

# ── Annotate every cell with count + row% + col% ───────────────────────────
for i in range(n):
    for j in range(n):
        cnt = counts[i, j]
        rpct = row_pct[i, j]
        cpct = col_pct[i, j]

        # White text on dark cells, dark text on light cells. im.norm maps a
        # count onto the same 0..1 scale the colormap uses; unlike
        # cnt / counts.max() it yields 0 (not a divide-by-zero) when every
        # count is 0.
        txt_color = "white" if im.norm(cnt) > 0.45 else "#1a1a2e"

        # Count slightly above the cell centre, row% and col% stacked below
        # it (with imshow's default origin="upper", larger y is lower).
        ax.text(j, i - 0.13, f"{cnt}",
                ha="center", va="center",
                fontsize=16, fontweight="bold", color=txt_color)

        ax.text(j, i + 0.20, f"Row: {rpct:.1f}%",
                ha="center", va="center",
                fontsize=9.5, color=txt_color, style="italic")

        ax.text(j, i + 0.42, f"Col: {cpct:.1f}%",
                ha="center", va="center",
                fontsize=9.5, color=txt_color, style="italic")

# ── Axis labels & title ─────────────────────────────────────────────────────
ax.set_xticks(range(n))
# Rotate the column labels so the multi-line names do not collide.
ax.set_xticklabels(labels, fontsize=10, rotation=20, ha="right")
ax.set_yticks(range(n))
ax.set_yticklabels(labels, fontsize=10)
ax.set_xlabel("LLM Re-coding", fontsize=13, labelpad=14)
ax.set_ylabel("Original Coding", fontsize=13, labelpad=14)
ax.set_title("Confusion Matrix: Group Level Coding\n(Count  |  Row%  |  Col%)",
             fontsize=14, fontweight="bold", pad=16)

cbar = fig.colorbar(im, ax=ax, fraction=0.035, pad=0.03)
cbar.set_label("Count", fontsize=11)

# Footnote explaining the two percentages. It must sit inside the canvas
# (0 <= y <= 1): a figure-level text at y < 0 is drawn off-canvas, so it would
# appear only in the PNG (via bbox_inches="tight") and never in the
# plt.show() window.
fig.text(
    0.5, 0.01,
    "Row% = % of that original group re-coded into this LLM column   |   "
    "Col% = % of that LLM column that came from this original row",
    ha="center", va="bottom", fontsize=9, color="gray", style="italic"
)

# tight_layout only arranges the Axes, so reserve the bottom strip of the
# figure for the footnote explicitly.
fig.tight_layout(rect=(0, 0.04, 1, 1))
fig.savefig(OUTPUT_PATH, dpi=180, bbox_inches="tight")
print(f"Saved: {OUTPUT_PATH}")
plt.show()