# piano_widget.py
# Requirements:
#   pip install PyQt6 numpy scipy sounddevice soundfile
#   pip install music21   (for MIDI parsing & chord naming)
#   pip install pygame    (for external MIDI controller support)
# Samples dir:
#   samples/<InstrumentName>/*.wav (e.g., samples/Piano/C4.wav)

import os
import sys
import math
import threading
import collections
import tempfile
import random
from typing import List, Tuple, Optional

from PyQt6.QtWidgets import (
    QWidget, QPushButton, QHBoxLayout, QVBoxLayout, QLabel, QSlider, QComboBox,
    QScrollArea, QSpinBox, QApplication, QCheckBox, QDialog,
    QListWidget, QListWidgetItem, QFileDialog, QMessageBox, QSizePolicy
)
from PyQt6.QtCore import pyqtSignal, Qt, QTimer, QThread, QObject
from PyQt6.QtGui import QDragEnterEvent, QDropEvent, QGuiApplication, QClipboard, QFont, QKeyEvent

import numpy as np
from scipy import signal
import sounddevice as sd

# ---- Optional deps: music21 ----
# music21 is optional: code that needs it must check HAS_MUSIC21 first
# (AudioEngine.load_dir does so before calling pitch.Pitch).
try:
    from music21 import note, stream, pitch, converter, chord, harmony, duration, meter, tempo
    HAS_MUSIC21 = True
except ImportError:
    print("music21 not found. MIDI parsing / chord naming will be limited.")
    HAS_MUSIC21 = False

# ---- Optional deps: pygame.midi for external MIDI controllers ----
# HAS_PYGAME_MIDI gates MIDI input (MidiInputThread.run returns immediately
# when it is False). The broad `except Exception` is presumably deliberate so
# that any import-time failure, not only ImportError, just disables the feature.
HAS_PYGAME_MIDI = False
try:
    import pygame
    import pygame.midi as pgmidi
    HAS_PYGAME_MIDI = True
except Exception:
    HAS_PYGAME_MIDI = False

# ---- Conservative audio defaults (we'll override if device demands) ----
# These assignments change sounddevice's process-wide defaults at import time.
# AudioEngine._try_open_stream reads blocksize and latency back from sd.default.
# 'high' requests the device's high suggested latency: fewer dropouts, less responsive.
sd.default.dtype = 'float32'
sd.default.blocksize = 512
sd.default.latency = 'high'

# Piano key geometry in pixels (PianoKey passes the W/H values to setFixedSize).
WHITE_KEY_W = 50
WHITE_KEY_H = 200
WHITE_GAP = 1  # NOTE(review): presumably the horizontal gap between white keys — confirm in the layout code
BLACK_KEY_W = 30
BLACK_KEY_H = 120
BLACK_OFFSETS = [0.65, 1.8, 3.65, 4.75, 5.85]  # C#, D#, F#, G#, A#
# NOTE(review): BLACK_OFFSETS look like x-positions in white-key widths from the octave's C — confirm in the layout code.

# =========================
# QSS (DPI-aware styling)
# =========================
def load_stylesheet() -> str:
    """Return the application-wide Qt style sheet (dark theme).

    Font sizes are given in points (pt) rather than pixels.
    """
    qss = r"""
QMainWindow, QDialog, QMessageBox {
    font-size: 10.5pt;
    background-color: #0d1117;
    color: #e8eef7;
    font-family: 'Segoe UI', 'Helvetica', 'Arial', sans-serif;
}
QMessageBox QPushButton { min-width: 5em; font-size: 9.8pt; }
QComboBox, QSpinBox {
    background-color: #0f1522;
    color: #e8eef7;
    border: 1px solid #30415d;
    padding: 2px 6px;
    border-radius: 6px;
}
QPushButton {
    background-color: #12233a;
    color: #e8eef7;
    border: 1px solid #2f4c7a;
    padding: 6px 8px;
    border-radius: 8px;
}
QPushButton:checked { background-color: #1b3b66; }
QPushButton:hover { background-color: #18365e; }
#sectionTitle { font-weight: 600; color: #9ecbff; }
"""
    return qss

# =========================
# DSP utilities
# =========================
def db_to_lin(db):
    """Convert a level in decibels to a linear amplitude factor (0 dB -> 1.0)."""
    exponent = db / 20.0
    return 10 ** exponent
def soft_clip(x, drive=1.0):
    """Saturate *x* through a tanh curve after scaling by *drive*; returns float32."""
    shaped = np.tanh(drive * x)
    return shaped.astype(np.float32)

def butter_filter(x, fs, cutoff, kind="low", order=4):
    """Apply a Butterworth low- or high-pass filter to a 1-D signal.

    Args:
        x: input samples.
        fs: sample rate in Hz.
        cutoff: corner frequency in Hz. A value <= 0 or >= Nyquist (fs / 2)
            disables the filter.
        kind: ``"low"`` or ``"high"`` (passed to ``scipy.signal.butter`` as btype).
        order: filter order.

    Returns:
        Filtered samples as float32 (just *x* as float32 when the filter is
        disabled or *x* is empty).
    """
    x = np.asarray(x, dtype=np.float32)
    if cutoff <= 0 or cutoff >= fs / 2 or x.size == 0:
        return x
    # Second-order sections stay numerically stable at very low normalized
    # cutoffs (e.g. a 1 Hz high-pass at 44.1 kHz), where the (b, a)
    # transfer-function form loses precision.
    sos = signal.butter(order, cutoff / (fs / 2), btype=kind, output="sos")
    return signal.sosfilt(sos, x).astype(np.float32)

def compressor(x, threshold_db=-18, ratio=4.0, makeup_db=3.0, attack_ms=5.0, release_ms=50.0, fs=44100):
    """Feed-forward peak compressor with make-up gain.

    An envelope follower tracks |x|. While the signal is above the envelope,
    the envelope rises toward it with the *attack_ms* time constant. Otherwise
    it decays with the *release_ms* time constant, never falling below the
    current |x|. Wherever the envelope exceeds the threshold, the sample is
    scaled by ``(env / thr) ** -(1 - 1 / ratio)``. *makeup_db* is applied
    afterwards.

    Args:
        x: 1-D input samples.
        threshold_db: threshold in dBFS.
        ratio: compression ratio (1.0 means no compression).
        makeup_db: gain applied after compression, in dB.
        attack_ms: envelope attack time; <= 0 means instantaneous.
        release_ms: envelope release time; <= 0 means instantaneous.
        fs: sample rate in Hz.

    Returns:
        Compressed float32 samples, same length as *x*.
    """
    thr = db_to_lin(threshold_db)
    # Zero/negative times mean "instant" instead of raising ZeroDivisionError.
    atk = math.exp(-1.0 / (fs * (attack_ms / 1000.0))) if attack_ms > 0 else 0.0
    rel = math.exp(-1.0 / (fs * (release_ms / 1000.0))) if release_ms > 0 else 0.0

    env_trace = np.empty(len(x), dtype=np.float64)
    env = 0.0
    for i, s in enumerate(x):
        a = abs(float(s))
        if a > env:
            # Smoothed attack. The previous max(a, env * atk) always chose `a`
            # here, so the attack time had no effect at all.
            env = atk * env + (1.0 - atk) * a
        else:
            # Exponential release, held at the current level.
            env = max(a, env * rel)
        env_trace[i] = env

    gain = np.ones(len(x), dtype=np.float32)
    over = env_trace > thr
    gain[over] = (env_trace[over] / thr) ** (-(1.0 - 1.0 / ratio))
    y = x * gain
    y *= db_to_lin(makeup_db)
    return y.astype(np.float32)

def make_delay(x, fs, tempo_label="1/4", feedback=0.5, mix=0.25):
    """Echo effect: mix *x* with copies delayed by k*T and scaled by feedback**k.

    T comes from *tempo_label* at a fixed 120 BPM ("1/2" = 1 s, "1/4" = 0.5 s,
    "1/8" = 0.25 s, "1/16" = 0.125 s, "1/8T" ~ 0.1667 s). Unknown labels use 0.5 s.
    *feedback* is clamped to [0, 0.85]. Echoes are truncated to len(x); no tail
    is appended. Echoes stop once they would start past the end of the buffer
    or their peak falls below 1e-5.

    Returns:
        ``(1 - mix) * x + mix * echoes`` as float32.
    """
    base_sec_map = {"1/2": 1.0, "1/4": 0.5, "1/8": 0.25, "1/16": 0.125, "1/8T": 0.1667}
    dtime = float(base_sec_map.get(tempo_label, 0.5))
    fb = max(0.0, min(0.85, float(feedback)))
    delay_samps = max(1, int(dtime * fs))

    src = np.asarray(x, dtype=np.float32)
    n = len(src)
    wet = np.zeros(n, dtype=np.float32)
    peak = float(np.max(np.abs(src))) if n else 0.0

    # The k-th echo starts at k * delay_samps with gain fb**k. The previous
    # version re-shifted the already-delayed echo by the growing offset each
    # pass, which put echoes at 1, 3, 6, 10... delays instead of 1, 2, 3, 4...
    gain = fb
    shift = delay_samps
    while shift < n and gain * peak >= 1e-5:
        wet[shift:] += src[:n - shift] * gain
        shift += delay_samps
        gain *= fb

    out = (1 - mix) * src + mix * wet
    return out.astype(np.float32)

def schroeder_reverb(x, fs, mix=0.2):
    """Classic Schroeder reverb: four parallel feedback combs, then two series all-passes.

    The reverb tail is truncated to len(x).

    Returns:
        ``(1 - mix) * x + mix * reverb`` as float32.
    """
    x = x.astype(np.float32)

    def _feedback(y, d, g):
        """In place: y[i] += g * y[i - d] for i >= d, processed d samples at a time.

        Every source slice lies entirely before the chunk being updated, so it
        is already final. This gives the same recursion as a per-sample loop,
        without a Python-level iteration per sample.
        """
        n = len(y)
        for start in range(d, n, d):
            end = min(start + d, n)
            y[start:end] += g * y[start - d:end - d]
        return y

    def comb(sig, delay_ms, feedback):
        # y[i] = sig[i] + feedback * y[i - d]
        d = max(1, int(delay_ms * fs / 1000))
        return _feedback(np.copy(sig), d, feedback)

    def allpass(sig, delay_ms, g):
        # y[i] = -g * sig[i] + sig[i - d] + g * y[i - d]
        d = max(1, int(delay_ms * fs / 1000))
        n = len(sig)
        y = (-g * sig).astype(np.float32)
        if d < n:
            y[d:] += sig[:n - d]
        return _feedback(y, d, g)

    c = (
        comb(x, 29.7, 0.773) +
        comb(x, 37.1, 0.802) +
        comb(x, 41.1, 0.753) +
        comb(x, 43.7, 0.733)
    ) * 0.25
    ap = allpass(c, 5.0, 0.7)
    ap = allpass(ap, 1.7, 0.7)
    return ((1 - mix) * x + mix * ap).astype(np.float32)

def chorus(x, fs, depth_ms=8.0, rate_hz=0.25, mix=0.3):
    """Simple chorus: average *x* with a copy delayed by a sine-modulated 0..depth_ms.

    Where the modulated delay reaches before the start of the buffer, the
    sample is left unchanged. Returns ``(1 - mix) * x + mix * chorused`` as
    float32, or *x* itself when the depth rounds to zero samples.
    """
    n = len(x)
    depth = int(depth_ms * fs / 1000.0)
    if depth <= 0:
        return x
    t = np.arange(n) / fs
    mod = (np.sin(2 * np.pi * rate_hz * t) + 1) * 0.5
    delay = (mod * depth).astype(int)
    y = np.copy(x)
    # Vectorized gather of x[i - delay[i]] wherever that index exists. This is
    # the same arithmetic as the former per-sample loop (x itself is never
    # modified), without a Python-level iteration per sample.
    src_idx = np.arange(n) - delay
    valid = src_idx >= 0
    y[valid] += x[src_idx[valid]]
    y *= 0.5
    return ((1 - mix) * x + mix * y).astype(np.float32)

def tremolo(x, fs, rate_hz=5.0, depth=0.5):
    """Amplitude-modulate *x* with a sine LFO at *rate_hz*.

    The gain swings between ``1 - depth`` and ``1.0``; the result is float32.
    """
    t = np.arange(len(x)) / fs
    lfo = 0.5 * (1 + np.sin(2 * np.pi * rate_hz * t))  # 0..1
    gain = (1.0 - depth) + depth * lfo
    modulated = x * gain.astype(np.float32)
    return modulated.astype(np.float32)

def normalize(x, peak=0.95):
    """Scale *x* so that its largest absolute sample equals *peak*; returns float32.

    An all-zero signal stays zero (the 1e-12 guard avoids dividing by zero).
    An empty signal is returned as an empty float32 array instead of making
    ``np.max`` raise ValueError.
    """
    x = np.asarray(x)
    if x.size == 0:
        return x.astype(np.float32)
    m = float(np.max(np.abs(x)) + 1e-12)
    return (x / m * peak).astype(np.float32)

def resample_if_needed(wav, src_fs, target_fs=44100):
    """Return ``(wav_as_float32_at_target_fs, target_fs)``.

    Resamples along axis 0 with polyphase filtering, and only when the rates
    differ. The up/down factors are the two rates divided by their GCD.
    """
    if src_fs != target_fs:
        common = math.gcd(src_fs, target_fs)
        wav = signal.resample_poly(wav, target_fs // common, src_fs // common, axis=0)
    return wav.astype(np.float32), target_fs

def apply_fade(x, fs, in_ms=5, out_ms=20):
    """Return a float32 copy of *x* with a linear fade-in and fade-out.

    Each ramp lasts at least one sample. If the two ramps would cover the whole
    buffer, a Hann window over the entire buffer is used instead. Buffers of
    two samples or fewer are returned unfaded. The input is never modified.
    """
    # Always work on a copy. The previous astype(copy=False) faded the caller's
    # array in place whenever it was already float32, but not in the Hann path.
    x = np.array(x, dtype=np.float32)
    n_in = max(1, int(fs * in_ms / 1000.0))
    n_out = max(1, int(fs * out_ms / 1000.0))
    n = len(x)
    if n <= 2:
        return x
    if n_in + n_out >= n:
        return x * np.hanning(n).astype(np.float32)
    x[:n_in] *= np.linspace(0.0, 1.0, n_in, dtype=np.float32)
    x[-n_out:] *= np.linspace(1.0, 0.0, n_out, dtype=np.float32)
    return x

def kill_dc_and_dither(x, dither_amp=1e-6):
    """Remove the DC offset from *x* and add uniform dither in ±dither_amp.

    The mean is accumulated in float64 for accuracy but subtracted as a float32
    scalar. Under NumPy 2 promotion rules, subtracting a float64 NumPy scalar
    would silently upcast the whole buffer to float64. Empty input is returned
    unchanged (as float32), without a mean-of-empty warning.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        return x
    x = x - np.float32(np.mean(x, dtype=np.float64))
    if dither_amp > 0:
        x = x + np.random.uniform(-dither_amp, dither_amp, size=x.shape).astype(np.float32)
    return x

# =========================
# Audio Engine
# =========================
class AudioEngine:
    """Sample-based mono playback engine.

    Loads ``*.wav`` files into ``self.samples`` (mono float32) and renders a
    note through the optional effect chain in :meth:`render_note`. Active
    voices are mixed in a persistent ``sounddevice.OutputStream`` callback.
    When no stream can be opened, :meth:`add_voice` falls back to ``sd.play``.

    Invariant: cached samples and rendered buffers are at ``self.fs``. If the
    stream opens at a different rate than requested, the cache is resampled to
    match (see :meth:`_adopt_sample_rate`).
    """

    def __init__(self, sample_rate=44100):
        self.fs = int(sample_rate)
        self.samples = {}   # key -> np.float32 mono @ self.fs
        self.lock = threading.Lock()  # guards self.voices, which the audio callback also reads
        self.master_gain = 0.9  # peak level render_note normalizes to

        self.stream = None
        self.output_device_index = None
        self._last_open_error = None
        self.voices = collections.deque()  # {"buf": float32 array, "pos": next sample index}

    def _list_output_devices(self):
        """Return ``[(device_index, device_info), ...]`` for devices with output channels."""
        try:
            devs = sd.query_devices()
        except Exception as e:
            print(f"[AudioEngine] sd.query_devices() failed: {e}")
            return []
        out = []
        for idx, d in enumerate(devs):
            if d.get('max_output_channels', 0) > 0:
                out.append((idx, d))
        return out

    def _try_open_stream(self, samplerate, device_index):
        """Open and start a mono float32 stream; return it, or None on failure.

        A stream that was created but failed to start is closed, so that its
        device handle is not leaked.
        """
        stream = None
        try:
            stream = sd.OutputStream(
                samplerate=samplerate,
                channels=1,
                dtype='float32',
                blocksize=sd.default.blocksize,
                latency=sd.default.latency,
                device=device_index,
                callback=self._callback,
            )
            stream.start()
            return stream
        except Exception as e:
            self._last_open_error = str(e)
            if stream is not None:
                try:
                    stream.close()
                except Exception:
                    pass
            return None

    def _close_stream(self):
        """Close and forget the current stream; errors from a dead device are only logged."""
        if self.stream is None:
            return
        try:
            self.stream.close()
        except Exception as e:
            print(f"[AudioEngine] Closing stale stream failed: {e}")
        self.stream = None

    def _adopt_sample_rate(self, rate):
        """Switch ``self.fs`` to *rate*, resampling cached samples so they stay at ``self.fs``."""
        rate = int(rate)
        if rate == self.fs:
            return
        for key, data in list(self.samples.items()):
            self.samples[key], _ = resample_if_needed(data, self.fs, rate)
        self.fs = rate

    def _ensure_stream(self):
        """Ensure a running persistent OutputStream; return True on success.

        The system default output device is tried first, then every other
        output device. Each device is tried at ``self.fs`` and then at one
        fallback rate (48000 Hz, or 44100 Hz when ``self.fs`` is already 48000).
        On success, the chosen rate and device also become sounddevice's
        process defaults.
        """
        if self.stream is not None and self.stream.active:
            return True
        # A stream that exists but is no longer active still owns a device
        # handle; close it instead of leaking it when a replacement is opened.
        self._close_stream()

        rate_candidates = [self.fs, 48000] if self.fs != 48000 else [48000, 44100]
        outputs = self._list_output_devices()
        if not outputs:
            print("[AudioEngine] No output-capable audio devices found.")
            return False

        preferred_indices = []
        try:
            default_in, default_out = sd.default.device
            if isinstance(default_out, (int, np.integer)) and default_out >= 0:
                preferred_indices.append(default_out)
        except Exception:
            pass

        for idx, _ in outputs:
            if idx not in preferred_indices:
                preferred_indices.append(idx)

        for dev_idx in preferred_indices:
            for rate in rate_candidates:
                st = self._try_open_stream(rate, dev_idx)
                if st is not None:
                    self.stream = st
                    # If the device only accepted the fallback rate, bring the
                    # cached samples along; otherwise they would play at the
                    # wrong speed and pitch.
                    self._adopt_sample_rate(rate)
                    self.output_device_index = dev_idx
                    sd.default.samplerate = self.fs
                    try:
                        sd.default.device = (sd.default.device[0], dev_idx)
                    except Exception:
                        sd.default.device = dev_idx
                    print(f"[AudioEngine] Opened stream on device {dev_idx} @ {self.fs} Hz")
                    return True

        print("[AudioEngine] Failed to open persistent OutputStream.")
        if self._last_open_error:
            print(f"[AudioEngine] Last error: {self._last_open_error}")
        return False

    def _callback(self, outdata, frames, time_info, status):
        """sounddevice callback: sum active voices into one block, soft-clip, write mono.

        Voices that have finished are dropped. ``status`` flags (e.g.
        underflows) are ignored.
        """
        out = np.zeros(frames, dtype=np.float32)
        with self.lock:
            new_voices = collections.deque()
            for v in self.voices:
                buf = v["buf"]; pos = v["pos"]
                remaining = len(buf) - pos
                if remaining <= 0:
                    continue
                take = min(frames, remaining)
                out[:take] += buf[pos:pos+take]
                pos += take
                v["pos"] = pos
                if pos < len(buf):
                    new_voices.append(v)
            self.voices = new_voices
        out = np.tanh(1.1 * out).astype(np.float32)
        outdata[:, 0] = out

    def load_dir(self, s_dir):
        """Replace the sample cache with every ``*.wav`` file in *s_dir*.

        Keys are file base names, normalized through ``music21.pitch.Pitch``
        when music21 is available and the name parses. Stereo files are
        averaged to mono and resampled to ``self.fs``. Files that fail to load
        are reported and skipped. A missing directory leaves the cache empty.
        """
        self.samples.clear()
        if not os.path.isdir(s_dir):
            return
        import soundfile as sf
        for fname in os.listdir(s_dir):
            if not fname.lower().endswith(".wav"):
                continue
            path = os.path.join(s_dir, fname)
            try:
                base = os.path.splitext(fname)[0]
                key_name = base
                if HAS_MUSIC21:
                    try:
                        p = pitch.Pitch(base)
                        key_name = p.nameWithOctave
                    except Exception:
                        pass
                data, fs = sf.read(path, dtype="float32", always_2d=False)
                if data.ndim == 2:
                    data = data.mean(axis=1)
                # Resample to self.fs, the rate render_note's DSP and the stream
                # use. sd.default.samplerate is None until a stream opens and
                # can disagree with self.fs.
                data, _ = resample_if_needed(data, fs, self.fs)
                self.samples[key_name] = data
            except Exception as e:
                print(f"Failed to load {path}: {e}")

    def render_note(self, key_name, chain_params, length_ms: Optional[float] = None):
        """Render *key_name*'s sample through the effects enabled in *chain_params*.

        Order: high-pass, low-pass, distortion, compressor, tremolo, chorus,
        delay, reverb. The result is then optionally truncated to *length_ms*
        (with short fades), soft-clipped, normalized to ``master_gain``, and
        DC-corrected. Returns a mono buffer at ``self.fs``, or None if there
        is no sample for *key_name*.
        """
        if key_name not in self.samples:
            return None
        x = np.copy(self.samples[key_name])

        lp_cut = int(chain_params.get("lp_cut", 0))
        hp_cut = int(chain_params.get("hp_cut", 0))
        if hp_cut > 0:
            x = butter_filter(x, self.fs, hp_cut, kind="high", order=4)
        if lp_cut > 0:
            x = butter_filter(x, self.fs, lp_cut, kind="low", order=4)

        if chain_params.get("dist_on", False):
            x = soft_clip(x, drive=float(chain_params.get("dist_drive", 1.0)))

        if chain_params.get("comp_on", False):
            x = compressor(
                x,
                threshold_db=int(chain_params.get("comp_thr", -18)),
                ratio=max(1.0, float(chain_params.get("comp_ratio", 4.0))),
                makeup_db=int(chain_params.get("comp_makeup", 3)),
                attack_ms=int(chain_params.get("comp_attack", 5)),
                release_ms=int(chain_params.get("comp_release", 50)),
                fs=self.fs,
            )

        if chain_params.get("trem_on", False):
            x = tremolo(
                x, self.fs,
                rate_hz=float(chain_params.get("trem_rate", 5.0)),
                depth=float(chain_params.get("trem_depth", 0.5))
            )

        if chain_params.get("chorus_on", False):
            x = chorus(
                x, self.fs,
                depth_ms=float(chain_params.get("chorus_depth", 8.0)),
                rate_hz=float(chain_params.get("chorus_rate", 0.25)),
                mix=float(chain_params.get("chorus_mix", 0.3))
            )

        if chain_params.get("delay_on", False):
            x = make_delay(
                x, self.fs,
                tempo_label=chain_params.get("delay_tempo", "1/4"),
                feedback=float(chain_params.get("delay_feedback", 0.5)),
                mix=float(chain_params.get("delay_mix", 0.25))
            )

        if chain_params.get("reverb_on", False):
            x = schroeder_reverb(x, self.fs, mix=float(chain_params.get("reverb_mix", 0.2)))

        if length_ms is not None:
            length_samps = max(1, int(self.fs * (length_ms / 1000.0)))
            x = x[:length_samps]
            x = apply_fade(x, self.fs, in_ms=3, out_ms=min(20, max(3, int(length_ms * 0.2))))

        x = np.tanh(1.2 * x).astype(np.float32)
        x = normalize(x, peak=float(getattr(self, "master_gain", 0.9)))
        x = kill_dc_and_dither(x, dither_amp=1e-6)
        return x

    def add_voice(self, buf: np.ndarray):
        """Queue *buf* (mono, at ``self.fs``) for playback; None or empty is ignored.

        Falls back to a one-shot ``sd.play`` when no stream can be opened.
        """
        if buf is None or len(buf) == 0:
            return
        rendered_fs = self.fs
        if not self._ensure_stream():
            try:
                sd.play(buf.astype(np.float32, copy=False), samplerate=self.fs, blocking=False)
            except Exception as e:
                print(f"[AudioEngine] Fallback sd.play failed: {e}")
            return
        buf = buf.astype(np.float32, copy=False)
        if self.fs != rendered_fs:
            # Opening the stream just switched to a fallback rate; this buffer
            # was rendered before that, so resample it to match.
            buf, _ = resample_if_needed(buf, rendered_fs, self.fs)
        with self.lock:
            self.voices.append({"buf": buf, "pos": 0})

    def play(self, key_name, chain_params, length_ms: Optional[float] = None):
        """Render *key_name* with *chain_params* and queue it for playback."""
        buf = self.render_note(key_name, chain_params, length_ms=length_ms)
        self.add_voice(buf)

# =========================
# Effect Pop-out Dialogs
# =========================
class BaseEffectDialog(QDialog):
    """Common base for the effect pop-out tool windows.

    Keeps the shared ``params`` dict and the ``setp_cb(key, value)`` callback
    that subclasses call (as ``self._setp``) to write changes back. Also
    provides helpers that build ``label | slider | value`` rows.
    """

    def __init__(self, title, params: dict, setp_cb, parent=None):
        super().__init__(parent)
        self.setWindowTitle(title)
        self.setWindowFlag(Qt.WindowType.Tool, True)
        self.params = params
        self._setp = setp_cb
        self.setMinimumWidth(360)
        self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred)

    def _row(self, label_text, slider_min, slider_max, value_int, on_change_int, postfix=""):
        """Build a row for an integer parameter.

        Returns ``(layout, slider)``. *on_change_int* receives each new int value.
        """
        row = QHBoxLayout()
        lbl = QLabel(label_text)
        s = QSlider(Qt.Orientation.Horizontal)
        s.setRange(int(slider_min), int(slider_max))
        s.setValue(int(value_int))
        # Take the initial text from the slider, which clamps to its range,
        # so the label and slider agree even for out-of-range stored values.
        val_lbl = QLabel(str(s.value()) + postfix)

        def _chg(v):
            val_lbl.setText(str(int(v)) + postfix)
            on_change_int(int(v))

        s.valueChanged.connect(_chg)
        row.addWidget(lbl); row.addWidget(s); row.addWidget(val_lbl)
        return row, s

    def _row_float(self, label_text, slider_min, slider_max, value_float, scale=100.0, on_change_float=None, postfix=""):
        """Build a row for a float parameter shown on an integer slider.

        The slider holds ``round(value * scale)``, and the callback receives
        ``slider_value / scale``. Returns ``(layout, slider)``.
        """
        row = QHBoxLayout()
        lbl = QLabel(label_text)
        s = QSlider(Qt.Orientation.Horizontal)
        s.setRange(int(slider_min), int(slider_max))
        s.setValue(int(round(value_float * scale)))
        # Same as _row: show what the (clamped, quantized) slider actually holds.
        val_lbl = QLabel(f"{s.value() / scale:.2f}{postfix}")

        def _chg(v):
            fval = v / scale
            val_lbl.setText(f"{fval:.2f}{postfix}")
            if on_change_float:
                on_change_float(fval)

        s.valueChanged.connect(_chg)
        row.addWidget(lbl); row.addWidget(s); row.addWidget(val_lbl)
        return row, s

class ReverbDialog(BaseEffectDialog):
    """Pop-out for the reverb: an enable toggle and the wet/dry mix (0.00-1.00)."""

    def __init__(self, params, setp_cb, parent=None):
        super().__init__("Reverb", params, setp_cb, parent)
        layout = QVBoxLayout(self)

        enable_box = QCheckBox("Enable Reverb")
        enable_box.setChecked(bool(params.get("reverb_on", False)))
        enable_box.toggled.connect(lambda checked: self._setp("reverb_on", bool(checked)))
        layout.addWidget(enable_box)

        mix_row, _ = self._row_float(
            "Mix", 0, 100, float(params.get("reverb_mix", 0.2)),
            scale=100.0,
            on_change_float=lambda value: self._setp("reverb_mix", value),
        )
        layout.addLayout(mix_row)
        layout.addStretch(1)

class DelayDialog(BaseEffectDialog):
    """Pop-out for the delay: enable, note length, feedback (0.00-0.95) and mix (0.00-1.00)."""

    def __init__(self, params, setp_cb, parent=None):
        super().__init__("Delay", params, setp_cb, parent)
        layout = QVBoxLayout(self)

        enable_box = QCheckBox("Enable Delay")
        enable_box.setChecked(bool(params.get("delay_on", False)))
        enable_box.toggled.connect(lambda checked: self._setp("delay_on", bool(checked)))
        layout.addWidget(enable_box)

        tempo_combo = QComboBox()
        tempo_combo.addItems(["1/2", "1/4", "1/8", "1/16", "1/8T"])
        tempo_combo.setCurrentText(params.get("delay_tempo", "1/4"))
        tempo_combo.currentTextChanged.connect(lambda text: self._setp("delay_tempo", text))
        tempo_row = QHBoxLayout()
        tempo_row.addWidget(QLabel("Tempo"))
        tempo_row.addWidget(tempo_combo)
        layout.addLayout(tempo_row)

        # (caption, slider max, param key, default)
        for caption, hi, key, default in (
            ("Feedback", 95, "delay_feedback", 0.5),
            ("Mix", 100, "delay_mix", 0.25),
        ):
            row, _ = self._row_float(
                caption, 0, hi, float(params.get(key, default)), scale=100.0,
                on_change_float=lambda value, k=key: self._setp(k, value),
            )
            layout.addLayout(row)
        layout.addStretch(1)

class ChorusDialog(BaseEffectDialog):
    """Pop-out for the chorus: enable, depth (1-20 ms), rate (0.01-2.00 Hz), mix (0.00-1.00)."""

    def __init__(self, params, setp_cb, parent=None):
        super().__init__("Chorus", params, setp_cb, parent)
        layout = QVBoxLayout(self)

        enable_box = QCheckBox("Enable Chorus")
        enable_box.setChecked(bool(params.get("chorus_on", False)))
        enable_box.toggled.connect(lambda checked: self._setp("chorus_on", bool(checked)))
        layout.addWidget(enable_box)

        depth_ms = int(round(float(params.get("chorus_depth", 8.0))))
        depth_row, _ = self._row(
            "Depth (ms)", 1, 20, depth_ms,
            on_change_int=lambda value: self._setp("chorus_depth", float(value)),
        )
        layout.addLayout(depth_row)

        # (caption, slider min, slider max, param key, default)
        for caption, lo, hi, key, default in (
            ("Rate (Hz)", 1, 200, "chorus_rate", 0.25),
            ("Mix", 0, 100, "chorus_mix", 0.3),
        ):
            row, _ = self._row_float(
                caption, lo, hi, float(params.get(key, default)), scale=100.0,
                on_change_float=lambda value, k=key: self._setp(k, value),
            )
            layout.addLayout(row)
        layout.addStretch(1)

class TremoloDialog(BaseEffectDialog):
    """Pop-out for the tremolo: enable, rate (0.01-12.00 Hz) and depth (0.00-1.00)."""

    def __init__(self, params, setp_cb, parent=None):
        super().__init__("Tremolo", params, setp_cb, parent)
        layout = QVBoxLayout(self)

        enable_box = QCheckBox("Enable Tremolo")
        enable_box.setChecked(bool(params.get("trem_on", False)))
        enable_box.toggled.connect(lambda checked: self._setp("trem_on", bool(checked)))
        layout.addWidget(enable_box)

        # (caption, slider min, slider max, param key, default)
        for caption, lo, hi, key, default in (
            ("Rate (Hz)", 1, 1200, "trem_rate", 5.0),
            ("Depth", 0, 100, "trem_depth", 0.5),
        ):
            row, _ = self._row_float(
                caption, lo, hi, float(params.get(key, default)), scale=100.0,
                on_change_float=lambda value, k=key: self._setp(k, value),
            )
            layout.addLayout(row)
        layout.addStretch(1)

class DistortionDialog(BaseEffectDialog):
    """Pop-out for the tanh distortion: enable and drive (0.10-4.00)."""

    def __init__(self, params, setp_cb, parent=None):
        super().__init__("Distortion", params, setp_cb, parent)
        layout = QVBoxLayout(self)

        enable_box = QCheckBox("Enable Distortion")
        enable_box.setChecked(bool(params.get("dist_on", False)))
        enable_box.toggled.connect(lambda checked: self._setp("dist_on", bool(checked)))
        layout.addWidget(enable_box)

        drive_row, _ = self._row_float(
            "Drive", 10, 400, float(params.get("dist_drive", 1.0)),
            scale=100.0,
            on_change_float=lambda value: self._setp("dist_drive", value),
        )
        layout.addLayout(drive_row)
        layout.addStretch(1)

class FilterDialog(BaseEffectDialog):
    """Pop-out for the high-pass (0-5000 Hz) and low-pass (0-20000 Hz) cutoffs.

    A cutoff of 0 leaves that filter off (render_note only filters when > 0).
    """

    def __init__(self, params, setp_cb, parent=None):
        super().__init__("Filters", params, setp_cb, parent)
        layout = QVBoxLayout(self)
        # (caption, slider max, param key)
        for caption, hi, key in (
            ("High-Pass (Hz)", 5000, "hp_cut"),
            ("Low-Pass (Hz)", 20000, "lp_cut"),
        ):
            cutoff = int(round(float(params.get(key, 0))))
            row, _ = self._row(
                caption, 0, hi, cutoff,
                on_change_int=lambda value, k=key: self._setp(k, int(value)),
            )
            layout.addLayout(row)
        layout.addStretch(1)

class CompressorDialog(BaseEffectDialog):
    """Pop-out for the compressor: enable, threshold, ratio, make-up, attack and release.

    The ratio slider stores ten times the ratio (10-200 -> 1.0:1-20.0:1).
    """

    def __init__(self, params, setp_cb, parent=None):
        super().__init__("Compressor", params, setp_cb, parent)
        layout = QVBoxLayout(self)

        enable_box = QCheckBox("Enable Compressor")
        enable_box.setChecked(bool(params.get("comp_on", False)))
        enable_box.toggled.connect(lambda checked: self._setp("comp_on", bool(checked)))
        layout.addWidget(enable_box)

        def stored(key, default):
            return int(round(float(params.get(key, default))))

        ratio = float(params.get("comp_ratio", 4.0))
        # (caption, slider min, slider max, initial slider value, param key, slider -> param)
        specs = (
            ("Threshold (dB)", -40, 0, stored("comp_thr", -18), "comp_thr", int),
            ("Ratio x:1 (x10)", 10, 200, int(round(ratio * 10)), "comp_ratio",
             lambda v: max(1.0, v / 10.0)),
            ("Makeup (dB)", 0, 24, stored("comp_makeup", 3.0), "comp_makeup", int),
            ("Attack (ms)", 1, 100, stored("comp_attack", 5.0), "comp_attack", int),
            ("Release (ms)", 10, 500, stored("comp_release", 50.0), "comp_release", int),
        )
        for caption, lo, hi, initial, key, convert in specs:
            row, _ = self._row(
                caption, lo, hi, initial,
                on_change_int=lambda value, k=key, conv=convert: self._setp(k, conv(value)),
            )
            layout.addLayout(row)
        layout.addStretch(1)

# =========================
# MIDI Clip item & drop list
# =========================
class MidiClipItem(QListWidgetItem):
    """List entry for a dropped MIDI file or chord name.

    The tooltip lists whichever of path, chord name and note names are known,
    one per line, or just "Chord" when none are.
    """

    def __init__(self, file_path: Optional[str], label: str,
                 chord_name: Optional[str] = None, pitch_names: Optional[List[str]] = None):
        super().__init__(label)
        self.file_path = file_path
        self.chord_name = chord_name
        self.pitch_names = pitch_names or []
        self.setToolTip(self._make_tooltip())

    def _make_tooltip(self):
        candidates = (
            self.file_path or None,
            f"Chord: {self.chord_name}" if self.chord_name else None,
            "Notes: " + ", ".join(self.pitch_names) if self.pitch_names else None,
        )
        lines = [text for text in candidates if text]
        return "\n".join(lines) or "Chord"

class MidiDropList(QListWidget):
    """List accepting dropped MIDI files (``.mid``/``.midi``) and chord-name text.

    It also supports internal drag-reordering of its items.

    ``parse_cb(path)`` and ``chord_parse_cb(text)`` each return
    ``(label, chord_name, note_names)`` for a new item. Either may be None.
    If a callback fails, the item is still added using the file name or the
    raw text.
    """

    MIDI_EXTENSIONS = (".mid", ".midi")
    CHORD_MIME = "application/x-chord-name"

    def __init__(self, parent=None, parse_cb=None, chord_parse_cb=None):
        super().__init__(parent)
        self.setAcceptDrops(True)
        self.setDropIndicatorShown(True)
        self.setDragDropMode(QListWidget.DragDropMode.InternalMove)
        self._parse_cb = parse_cb
        self._chord_parse_cb = chord_parse_cb

    def _is_external_payload(self, md) -> bool:
        """True for drags this widget handles itself: local MIDI files or chord text."""
        if md.hasUrls() and any(
            u.isLocalFile() and u.toLocalFile().lower().endswith(self.MIDI_EXTENSIONS)
            for u in md.urls()
        ):
            return True
        return md.hasFormat(self.CHORD_MIME) or md.hasText()

    def dragEnterEvent(self, e: QDragEnterEvent):
        if self._is_external_payload(e.mimeData()):
            e.acceptProposedAction()
            return
        super().dragEnterEvent(e)

    def dragMoveEvent(self, e):
        # e is a QDragMoveEvent. The previous version forwarded it to
        # dragEnterEvent, whose base-class fallback rejects a QDragMoveEvent
        # argument with a TypeError, which broke internal reordering.
        if self._is_external_payload(e.mimeData()):
            e.acceptProposedAction()
            return
        super().dragMoveEvent(e)

    def _drop_midi_files(self, md) -> bool:
        """Add one item per dropped, existing local MIDI file; True if any were added."""
        if not md.hasUrls():
            return False
        added = False
        for u in md.urls():
            if not u.isLocalFile():
                continue
            p = u.toLocalFile()
            if not (p.lower().endswith(self.MIDI_EXTENSIONS) and os.path.exists(p)):
                continue
            label, chord_name, note_names = os.path.basename(p), None, []
            if callable(self._parse_cb):
                try:
                    label, chord_name, note_names = self._parse_cb(p)
                except Exception:
                    pass  # best effort: keep the file-name fallback
            self.addItem(MidiClipItem(p, label, chord_name, note_names))
            added = True
        return added

    def _drop_chord_text(self, md) -> bool:
        """Add one chord item from the custom MIME payload or plain text; True if added."""
        try:
            if md.hasFormat(self.CHORD_MIME):
                txt = bytes(md.data(self.CHORD_MIME)).decode("utf-8")
            else:
                txt = md.text()
            txt = (txt or "").strip()
            if not txt:
                return False
            label, chord_name, note_names = txt, txt, []
            if callable(self._chord_parse_cb):
                try:
                    label, chord_name, note_names = self._chord_parse_cb(txt)
                except Exception:
                    # An unparseable chord is still added under its raw text
                    # (previously the whole drop was discarded).
                    label, chord_name, note_names = txt, txt, []
            self.addItem(MidiClipItem(None, label, chord_name, note_names))
            return True
        except Exception as exc:
            print(f"[MidiDropList] Chord drop ignored: {exc}")
            return False

    def dropEvent(self, e: QDropEvent):
        md = e.mimeData()
        handled = self._drop_midi_files(md)
        if not handled and (md.hasFormat(self.CHORD_MIME) or md.hasText()):
            handled = self._drop_chord_text(md)

        if handled:
            e.acceptProposedAction()
        else:
            # Internal reorder moves (and anything else) go to QListWidget.
            super().dropEvent(e)

# =========================
# MIDI INPUT THREAD (pygame.midi)
# =========================
class MidiNoteEvent(QObject):
    """Signal carrier for MidiInputThread.

    MidiInputThread.run emits these from its polling loop. A note-on with
    velocity 0 arrives as noteOff(note, 0).
    """
    noteOn = pyqtSignal(int, int)   # (midi_note, velocity)
    noteOff = pyqtSignal(int, int)  # (midi_note, velocity)

class MidiInputThread(QThread):
    """Polls a pygame.midi input device and re-emits note events via ``self.signals``.

    Only note-on (0x9n) and note-off (0x8n) messages are forwarded, on any
    channel. A note-on with velocity 0 is reported as ``noteOff(note, 0)``.
    ``stop()`` only requests shutdown; the loop exits within one poll interval.
    """

    def __init__(self, device_id: int, parent=None):
        super().__init__(parent)
        self.device_id = device_id
        self.running = False
        self.signals = MidiNoteEvent()
        self._inp = None
        # Separate stop request. run() sets `running` to True after opening the
        # device, which used to overwrite a stop() issued right after start()
        # and leave the thread polling forever.
        self._stop_requested = False

    def start(self, *args, **kwargs):
        """Clear any earlier stop request, then start the thread (see QThread.start)."""
        self._stop_requested = False
        super().start(*args, **kwargs)

    def run(self):
        if not HAS_PYGAME_MIDI:
            return
        try:
            if not pygame.get_init():
                pygame.init()
            if not pgmidi.get_init():
                pgmidi.init()
            self._inp = pgmidi.Input(self.device_id)
            self.running = True
            # Checking _stop_requested as well honors a stop() that raced with
            # the assignment above.
            while self.running and not self._stop_requested:
                if self._inp.poll():
                    for data, _ts in self._inp.read(16):
                        self._dispatch(data)
                self.msleep(2)
        except Exception as e:
            print(f"[MIDI] Error: {e}")
        finally:
            self.running = False
            try:
                if self._inp is not None:
                    self._inp.close()
            except Exception:
                pass
            self._inp = None

    def _dispatch(self, data):
        """Translate one raw ``[status, data1, data2, ...]`` message into a signal."""
        status = data[0] & 0xF0  # drop the channel nibble
        note_num = int(data[1])
        velocity = int(data[2])
        if status == 0x90:  # note on
            if velocity > 0:
                self.signals.noteOn.emit(note_num, velocity)
            else:
                self.signals.noteOff.emit(note_num, 0)
        elif status == 0x80:  # note off
            self.signals.noteOff.emit(note_num, velocity)

    def stop(self):
        """Ask the polling loop to exit; returns immediately."""
        self._stop_requested = True
        self.running = False

# =========================
# Piano UI + MIDI parsing + Chord playback + ARPEGGIATOR
# =========================
class PianoKey(QPushButton):
    """One clickable piano key.

    White keys are labeled with the first character of *pitch_name*; black
    keys are smaller and unlabeled. ``flash(ms)`` highlights the key briefly.
    """

    def __init__(self, pitch_name, is_black=False, parent=None):
        super().__init__(parent)
        self.pitch_name = pitch_name
        self.is_black = is_black
        if is_black:
            my_style = (
                "QPushButton{background-color:#333;border:1px solid #555;border-radius:4px}"
                "QPushButton:pressed{background-color:#007acc}"
            )
        else:
            my_style = (
                "QPushButton{background-color:#fdfdfd;border:1px solid #ccc;border-radius:4px;"
                "color:#555;font-weight:bold;padding:0}"
                "QPushButton:pressed{background-color:#00aaff}"
            )
        self._base_style = my_style
        self.setStyleSheet(my_style)
        # A plain if/else, rather than a conditional expression evaluated only
        # for its side effect.
        if is_black:
            self.setFixedSize(BLACK_KEY_W, BLACK_KEY_H)
        else:
            self.setFixedSize(WHITE_KEY_W, WHITE_KEY_H)
            self.setText(self.pitch_name[0])

        # One restartable single-shot timer owned by the key. Overlapping
        # flashes extend the highlight instead of an earlier timer cutting a
        # later flash short. As a child object, the timer is destroyed with
        # the key and cannot fire afterwards.
        self._flash_timer = QTimer(self)
        self._flash_timer.setSingleShot(True)
        self._flash_timer.timeout.connect(self._restore_style)

    def flash(self, ms=200):
        """Show the highlight style for *ms* milliseconds, then restore the base style."""
        if self.is_black:
            self.setStyleSheet("QPushButton{background-color:#555;border:1px solid #66a;border-radius:4px}")
        else:
            self.setStyleSheet(
                "QPushButton{background-color:#bfe9ff;border:1px solid #66a;border-radius:4px;"
                "color:#333;font-weight:bold;padding:0}"
            )
        self._flash_timer.start(int(ms))

    def _restore_style(self):
        """Return to the key's normal style sheet."""
        self.setStyleSheet(self._base_style)

class ArpDialog(QDialog):
    """Tool window for the arpeggiator settings.

    ``get_params_cb()`` returns a dict with the current values (keys
    ``arp_pattern``, ``arp_subdiv``, ``arp_gate``, ``arp_swing``,
    ``arp_octaves``, ``arp_on``). ``set_params_cb(key, value)`` receives every
    change the user makes. Both callbacks are optional.
    """

    def __init__(self, parent=None, get_params_cb=None, set_params_cb=None):
        super().__init__(parent)
        self.setWindowTitle("Arpeggiator")
        self.setWindowFlag(Qt.WindowType.Tool, True)
        self.setMinimumWidth(380)
        self.getp = get_params_cb
        self.setp = set_params_cb

        # Initial values come from the owner when available. Otherwise (or for
        # missing keys) these defaults apply, so the widgets and labels always
        # start out consistent.
        p = self.getp() if callable(self.getp) else {}

        root = QVBoxLayout(self)

        row = QHBoxLayout()
        row.addWidget(QLabel("Pattern"))
        self.pattern_cb = QComboBox()
        self.pattern_cb.addItems(["Up", "Down", "UpDown", "Random", "As-Listed"])
        self.pattern_cb.setCurrentText(p.get("arp_pattern", "Up"))
        row.addWidget(self.pattern_cb)
        root.addLayout(row)

        row = QHBoxLayout()
        row.addWidget(QLabel("Subdivision"))
        self.subdiv_cb = QComboBox()
        self.subdiv_cb.addItems(["1/4", "1/8", "1/8T", "1/16", "1/32"])
        self.subdiv_cb.setCurrentText(p.get("arp_subdiv", "1/8"))
        row.addWidget(self.subdiv_cb)
        root.addLayout(row)

        row = QHBoxLayout()
        row.addWidget(QLabel("Gate %"))
        self.gate_slider = QSlider(Qt.Orientation.Horizontal); self.gate_slider.setRange(10, 100)
        self.gate_slider.setValue(int(p.get("arp_gate", 70)))
        row.addWidget(self.gate_slider)
        # Label text comes from the slider itself, so it matches even when the
        # value was clamped (previously "70%" was shown next to a slider at 10).
        self.gate_val = QLabel(f"{self.gate_slider.value()}%"); row.addWidget(self.gate_val)
        root.addLayout(row)

        row = QHBoxLayout()
        row.addWidget(QLabel("Swing %"))
        self.swing_slider = QSlider(Qt.Orientation.Horizontal); self.swing_slider.setRange(0, 75)
        self.swing_slider.setValue(int(p.get("arp_swing", 0)))
        row.addWidget(self.swing_slider)
        self.swing_val = QLabel(f"{self.swing_slider.value()}%"); row.addWidget(self.swing_val)
        root.addLayout(row)

        row = QHBoxLayout()
        row.addWidget(QLabel("Octave Range"))
        self.octaves_spin = QSpinBox(); self.octaves_spin.setRange(1, 4)
        self.octaves_spin.setValue(int(p.get("arp_octaves", 1)))
        row.addWidget(self.octaves_spin)
        root.addLayout(row)

        enabled = QCheckBox("Enable Arpeggiator")
        enabled.setChecked(bool(p.get("arp_on", False)))
        root.addWidget(enabled)

        root.addStretch(1)

        # Wire signals only after the initial values are in place, so that
        # constructing the dialog does not echo values back through
        # set_params_cb.
        self.pattern_cb.currentTextChanged.connect(lambda t: self._set("arp_pattern", t))
        self.subdiv_cb.currentTextChanged.connect(lambda t: self._set("arp_subdiv", t))
        self.gate_slider.valueChanged.connect(lambda v: self._set("arp_gate", int(v)))
        self.gate_slider.valueChanged.connect(lambda v: self.gate_val.setText(f"{v}%"))
        self.swing_slider.valueChanged.connect(lambda v: self._set("arp_swing", int(v)))
        self.swing_slider.valueChanged.connect(lambda v: self.swing_val.setText(f"{v}%"))
        self.octaves_spin.valueChanged.connect(lambda v: self._set("arp_octaves", int(v)))
        enabled.toggled.connect(lambda c: self._set("arp_on", bool(c)))

    def _set(self, k, v):
        """Forward a changed setting to set_params_cb, if one was given."""
        if callable(self.setp):
            self.setp(k, v)

# =========================
# Piano Widget (MIDI + keyboard input)
# =========================
class PianoWidget(QWidget):
    notePlayed = pyqtSignal(object)

    def __init__(self, start_octave=3, num_octaves=2, parent=None):
        """Create the piano widget, its audio engine and every child control.

        Args:
            start_octave: Octave number of the left-most C drawn on the keyboard.
            num_octaves: Number of octaves of keys to draw.
            parent: Optional Qt parent widget.
        """
        super().__init__(parent)
        self.setObjectName("centralWidget")
        self.engine = AudioEngine(sample_rate=44100)
        self.start_octave = start_octave
        self.num_octaves = num_octaves

        # Effect + arpeggiator settings; shared with the dialogs and passed to engine.play().
        self.params = dict(
            reverb_on=False, reverb_mix=0.2,
            delay_on=False, delay_tempo="1/4", delay_feedback=0.5, delay_mix=0.25,
            chorus_on=False, chorus_depth=8.0, chorus_rate=0.25, chorus_mix=0.3,
            trem_on=False, trem_rate=5.0, trem_depth=0.5,
            dist_on=False, dist_drive=1.0,
            comp_on=False, comp_thr=-18, comp_ratio=4.0, comp_makeup=3.0,
            comp_attack=5.0, comp_release=50.0,
            lp_cut=0, hp_cut=0,
            arp_on=False, arp_pattern="Up", arp_subdiv="1/8", arp_gate=70, arp_swing=0, arp_octaves=1,
        )

        # Open effect dialogs, keyed by launcher id (None once closed).
        self._dialogs = {}
        # Key widgets: white keys, (key, octave index, slot) triples for black keys,
        # and a pitch-name -> key lookup.
        self.white_buttons = []
        self.black_buttons = []
        self.key_map = {}

        # Arpeggiator state: each timer timeout plays one step.
        self._arp_timer = QTimer(self)
        self._arp_timer.timeout.connect(self._arp_tick)
        self._arp_running = False
        self._arp_seq: List[str] = []
        self._arp_index = 0
        self._arp_step_ms = 125.0

        # External MIDI input state.
        self._midi_thread: Optional[MidiInputThread] = None
        self._midi_connected = False

        # Computer-keyboard layout: home row for white keys, the row above for black keys.
        self._kb_white = list("ASDFGHJKL") + [";", "'", "]"]
        self._kb_black = list("WETYUOP")
        self._kb_pressed = set()

        self._init_ui()
        self.setFocusPolicy(Qt.FocusPolicy.StrongFocus)
        self._update_arp_step_ms()

    # -------- Parameter & effect-dialog helpers --------
    
    def _set_param(self, key, value):
        """Store *value* under *key* in ``self.params``.

        Arpeggiator keys (``arp_*``) and ``bpm`` also refresh the arp step length.
        """
        self.params[key] = value
        affects_arp_timing = key == "bpm" or key.startswith("arp_")
        if affects_arp_timing:
            self._update_arp_step_ms()

    def _open_dialog(self, key: str, btn: QPushButton):
        """Open, raise or close the effect dialog behind a checkable launcher button.

        The launcher's checked state drives two things:

        * the effect's ``*_on`` parameter, for keys that have one ("filters" does not);
        * the dialog's visibility: checked creates (or raises) the dialog, unchecked
          closes it.

        Closing the dialog itself (e.g. its window X) unchecks the launcher and turns
        the effect off again.

        Args:
            key: Launcher id: "reverb", "delay", "chorus", "tremolo", "distortion",
                "filters", "compressor" or "arp".
            btn: The launcher button that was clicked.
        """
        # 1. Sync the effect's on/off parameter with the button.
        on_key = {
            "reverb": "reverb_on", "delay": "delay_on", "chorus": "chorus_on",
            "tremolo": "trem_on", "distortion": "dist_on", "compressor": "comp_on",
            "arp": "arp_on",
        }.get(key)
        if on_key:
            self._set_param(on_key, btn.isChecked())

        existing = self._dialogs.get(key)

        # 2. Button turned OFF: close the dialog if one is open.
        if not btn.isChecked():
            if existing:
                try:
                    # For a visible QDialog this emits `finished`, running on_finish below.
                    existing.close()
                except RuntimeError:
                    pass  # the underlying Qt object is already gone
                self._dialogs[key] = None
            return

        # Button turned ON with a dialog already open: just bring it forward.
        if existing:
            try:
                existing.raise_()
                existing.activateWindow()
                return
            except RuntimeError:
                self._dialogs[key] = None  # stale wrapper; create a fresh dialog below

        # 3. Map the launcher key to its dialog class.
        dialog_map = {
            "reverb": ReverbDialog,
            "delay": DelayDialog,
            "chorus": ChorusDialog,
            "tremolo": TremoloDialog,
            "distortion": DistortionDialog,
            "filters": FilterDialog,
            "compressor": CompressorDialog,
            "arp": ArpDialog,
        }
        Klass = dialog_map.get(key)
        if not Klass:
            btn.setChecked(False)  # unknown launcher id
            if on_key:
                self._set_param(on_key, False)
            return

        # 4. Create and show a new dialog instance.
        try:
            if key == "arp":
                dialog = ArpDialog(parent=self,
                                   get_params_cb=lambda: self.params,
                                   set_params_cb=self._set_param)
            else:
                dialog = Klass(params=self.params,
                               setp_cb=self._set_param,
                               parent=self)

            self._dialogs[key] = dialog

            # 5. When the dialog finishes (window X, Esc, or our own close()),
            #    uncheck the launcher and switch the effect off.
            def on_finish(_result_code):
                btn.setChecked(False)
                if on_key:
                    self._set_param(on_key, False)
                # Only clear the slot if it still refers to *this* dialog.
                if self._dialogs.get(key) is dialog:
                    self._dialogs[key] = None
                # close() merely hides a QDialog. Because it is parented to self it
                # would stay alive, and every reopen creates a new instance.
                dialog.deleteLater()

            dialog.finished.connect(on_finish)

            # Tick the dialog's first checkbox (assumed to be its enable switch)
            # to match the launcher, which is checked at this point.
            if on_key:
                cb = dialog.findChild(QCheckBox)
                if cb:
                    cb.setChecked(True)

            dialog.show()
            dialog.raise_()
            dialog.activateWindow()

        except Exception as e:
            print(f"Error opening dialog {key}: {e}")
            self._dialogs[key] = None
            btn.setChecked(False)
            if on_key:
                self._set_param(on_key, False)
            
    # -------- end of parameter & effect-dialog helpers --------

    # -------- chord helpers --------
    def _infer_chord_name_from_pitches(self, m21_pitches) -> Optional[str]:
        """Heuristically name a chord from a collection of music21 pitches.

        Used as a fallback when music21's chord-symbol detection yields nothing.
        The root is taken from ``chord.Chord(...).root()`` when that succeeds,
        otherwise it is the lowest pitch class present. The quality is then
        derived from the set of intervals (semitones mod 12) above that root.

        Args:
            m21_pitches: music21 Pitch objects (anything with ``pitchClass``
                and ``name``).

        Returns:
            A name such as "C", "Am7", "Gsus4" or "Cadd9". When no triad (or
            third) is recognised, it is the root name followed by all pitch
            names. Returns None for empty input, or if that listing fails.
        """
        if not m21_pitches: return None
        try:
            c = chord.Chord(m21_pitches)
            root_name = c.root().name
            root_pc = pitch.Pitch(root_name).pitchClass
        except Exception:
            # root() failed: fall back to the lowest pitch class present.
            pcs = sorted(set(int(p.pitchClass) for p in m21_pitches))
            root_pc = pcs[0]
            # MIDI numbers 0-11 share pitch classes 0-11, so this names root_pc.
            root_name = pitch.Pitch(midi=root_pc).name
        pcs_full = sorted(set(int(p.pitchClass) for p in m21_pitches))
        intervals = sorted(((pc - root_pc) % 12) for pc in pcs_full)
        s = set(intervals); name = root_name
        # Triad quality; the first matching rule wins. Sus chords return immediately,
        # so they never get 7th/add9 suffixes.
        if {0, 4, 7} <= s: qual = ""
        elif {0, 3, 7} <= s: qual = "m"
        elif {0, 3, 6} <= s: qual = "dim"
        elif {0, 4, 8} <= s: qual = "aug"
        elif {0, 2, 7} <= s: return name + "sus2"
        elif {0, 5, 7} <= s: return name + "sus4"
        else:
            # No full triad: accept a major/minor third without a fifth.
            if {0, 4} <= s: qual = "(no5)"
            elif {0, 3} <= s: qual = "m(no5)"
            else:
                # Unrecognised set: root name followed by every pitch name.
                try: return name + " " + " ".join(p.name for p in m21_pitches)
                except Exception: return None
        # Seventh extensions; the first matching rule wins. Note that "m7",
        # "m(maj7)", "ø7" and "dim7" replace the triad quality rather than append to it.
        if 10 in s and {0, 4, 7} <= s: qual += "7"
        elif 11 in s and {0, 4, 7} <= s: qual += "maj7"
        elif 10 in s and {0, 3, 7} <= s: qual = "m7"
        elif 11 in s and {0, 3, 7} <= s: qual = "m(maj7)"
        elif {0, 3, 6, 10} <= s: qual = "ø7"
        elif {0, 3, 6, 9} <= s: qual = "dim7"
        # A major 2nd (with no minor 2nd or minor 3rd present) is reported as "add9".
        # NOTE(review): no quality string assigned above contains "9", so this
        # expression always appends "add9" (e.g. "C7add9" rather than "C9").
        # The `3 in s` guard also suppresses minor add9 chords. Confirm intent.
        if 2 in s and not (1 in s or 3 in s):
            qual += "add9" if "7" not in qual else ("" if "9" in qual else "add9")
        return name + qual

    def _parse_midi_clip(self, file_path: str) -> Tuple[str, Optional[str], List[str]]:
        """Describe a dropped MIDI file for the clip list.

        The chord is read from the pitches sounding at the file's earliest onset
        (notes and chords alike). If none are found, it falls back to the first
        chord of a chordified reduction. The name comes from music21's
        chord-symbol detection, or from ``_infer_chord_name_from_pitches`` when
        that yields nothing.

        Args:
            file_path: Path of the MIDI file.

        Returns:
            ``(label, chord_name, note_names)``. ``label`` is the file's base name,
            suffixed with the chord name (or the note names) when available.
            Without music21, or when parsing fails, only the base name is guaranteed.
        """
        label = os.path.basename(file_path)
        chord_name: Optional[str] = None
        note_names: List[str] = []
        if not HAS_MUSIC21:
            return label, chord_name, note_names
        try:
            score = converter.parse(file_path)
            m21_pitches = self._first_onset_pitches(score)
            if not m21_pitches:
                chords = score.chordify().flatten().getElementsByClass('Chord')
                if chords:
                    m21_pitches = list(chords[0].pitches)
            if m21_pitches:
                note_names = [p.nameWithOctave for p in m21_pitches]
                chord_name = (self._chord_symbol_name(chord.Chord(m21_pitches))
                              or self._infer_chord_name_from_pitches(m21_pitches))
        except Exception as e:
            # Best effort: an unreadable file still gets a plain label.
            print(f"[MIDI] could not analyse {file_path}: {e}")
        if chord_name:
            label = f"{label} — {chord_name}"
        elif note_names:
            label = f"{label} — {' '.join(note_names)}"
        return label, chord_name, note_names

    @staticmethod
    def _first_onset_pitches(score) -> list:
        """Return the pitches of every note/chord that starts at *score*'s earliest onset.

        Offsets come from a flattened copy, so they are measured from the start of
        the piece rather than from the start of each measure or part.
        """
        flat = score.flatten()
        onsets = sorted(flat.notes, key=lambda el: float(flat.elementOffset(el)))
        first_offset = None
        found = []
        for el in onsets:
            offset = float(flat.elementOffset(el))
            if first_offset is None:
                first_offset = offset
            elif offset > first_offset + 1e-6:
                break
            # Notes carry one pitch and chords several; unpitched elements contribute none.
            found.extend(getattr(el, "pitches", ()))
        return found

    @staticmethod
    def _chord_symbol_name(c) -> Optional[str]:
        """Return music21's chord-symbol figure for *c*, or None if it can't name it."""
        try:
            figure = harmony.chordSymbolFromChord(c).figure
        except Exception:
            return None
        # music21 reports an unidentifiable chord with a placeholder figure
        # (presumably "Chord Symbol Cannot Be Identified"); treat it as no name.
        if not figure or "Cannot Be Identified" in figure:
            return None
        return figure

    def _parse_chord_text(self, text: str) -> Tuple[str, Optional[str], List[str]]:
        """Turn a typed or dropped chord symbol (e.g. "Cmaj7") into a clip-list entry.

        Pitches are resolved by ``_derive_pitches_from_chord_name`` in the octave
        chosen in the "Chord Octave" spin box. (The former local fallback repeated
        that helper's exact music21 calls, so it could never succeed where the
        helper failed; it has been removed.)

        Returns:
            ``(label, chord_name, pitches)``. ``chord_name`` is the stripped text when
            pitches were found, otherwise None. ``label`` appends the pitches when known.
            Blank input returns the text unchanged (``""`` for None).
        """
        chord_txt = (text or "").strip()
        if not chord_txt:
            return text or "", None, []
        pitches = self._derive_pitches_from_chord_name(chord_txt, self.chord_oct_spin.value())
        if not pitches:
            return chord_txt, None, []
        return f"{chord_txt} — {' '.join(pitches)}", chord_txt, pitches

    # -------- UI --------
    def _init_ui(self):
        """Build all child widgets and layouts, then create the keys.

        Layout, top to bottom:

        * a toolbar: instrument, optional MIDI input, octave range, BPM and
          master volume;
        * the horizontally scrollable keyboard;
        * a row of checkable effect-dialog launchers;
        * the MIDI-clip list with its action buttons.
        """
        wrapper = QVBoxLayout(self)
        wrapper.setContentsMargins(10, 10, 10, 10)
        wrapper.setSpacing(8)
        self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)

        # ---- Top toolbar ----
        top_bar = QHBoxLayout()
        top_bar.addWidget(QLabel("Instrument:"))
        self.instrument_combo = QComboBox()
        # One entry per sub-directory of ./samples, in os.listdir order.
        if os.path.isdir('samples'):
            self.instrument_combo.addItems([d for d in os.listdir('samples') if os.path.isdir(os.path.join('samples', d))])
        # Connected after addItems, so filling the combo does not load an instrument;
        # the initial load happens explicitly at the end of this method.
        self.instrument_combo.currentTextChanged.connect(self.switch_instrument)
        top_bar.addWidget(self.instrument_combo)

        # --- MIDI controls (optional) ---
        self.midi_controls_container = QWidget()
        midi_layout = QHBoxLayout(self.midi_controls_container); midi_layout.setContentsMargins(8,0,0,0)
        midi_layout.addWidget(QLabel("MIDI In:"))
        self.midi_combo = QComboBox()
        self.midi_refresh_btn = QPushButton("Refresh")
        self.midi_connect_btn = QPushButton("Connect")
        self.midi_disconnect_btn = QPushButton("Disc"); self.midi_disconnect_btn.setEnabled(False)
        self.midi_status = QLabel("⏺ Not connected")
        midi_layout.addWidget(self.midi_combo)
        midi_layout.addWidget(self.midi_refresh_btn)
        midi_layout.addWidget(self.midi_connect_btn)
        midi_layout.addWidget(self.midi_disconnect_btn)
        midi_layout.addWidget(self.midi_status)
        if HAS_PYGAME_MIDI:
            self._populate_midi_devices()
            self.midi_refresh_btn.clicked.connect(self._populate_midi_devices)
            self.midi_connect_btn.clicked.connect(self._connect_midi)
            self.midi_disconnect_btn.clicked.connect(self._disconnect_midi)
            top_bar.addWidget(self.midi_controls_container)
        else:
            # Without pygame.midi the container is hidden and never added to the toolbar.
            self.midi_controls_container.setVisible(False)

        # Keyboard range: either spin box rebuilds all keys.
        top_bar.addSpacing(14)
        top_bar.addWidget(QLabel("Start Octave:"))
        self.start_oct_spin = QSpinBox(); self.start_oct_spin.setRange(0, 8)
        self.start_oct_spin.setValue(self.start_octave)
        self.start_oct_spin.valueChanged.connect(self._on_range_changed)
        top_bar.addWidget(self.start_oct_spin)

        top_bar.addWidget(QLabel("Octaves:"))
        self.num_oct_spin = QSpinBox(); self.num_oct_spin.setRange(1, 7)
        self.num_oct_spin.setValue(self.num_octaves)
        self.num_oct_spin.valueChanged.connect(self._on_range_changed)
        top_bar.addWidget(self.num_oct_spin)

        # BPM changes recompute the arpeggiator step length.
        top_bar.addSpacing(14)
        top_bar.addWidget(QLabel("BPM:"))
        self.bpm_spin = QSpinBox(); self.bpm_spin.setRange(20, 300); self.bpm_spin.setValue(120)
        self.bpm_spin.valueChanged.connect(lambda _: self._update_arp_step_ms())
        top_bar.addWidget(self.bpm_spin)

        # The slider is in percent (10-100); engine.master_gain receives value / 100.
        top_bar.addSpacing(14)
        top_bar.addWidget(QLabel("Master Vol"))
        self.master_slider = QSlider(Qt.Orientation.Horizontal)
        self.master_slider.setRange(10, 100)
        self.master_slider.setValue(int(round(self.engine.master_gain * 100)))
        self.master_slider.valueChanged.connect(lambda v: setattr(self.engine, "master_gain", v / 100.0))
        self.master_slider.setFixedWidth(140)
        top_bar.addWidget(self.master_slider)
        top_bar.addStretch()
        wrapper.addLayout(top_bar)

        # ---- Keyboard ----
        # Keys are placed at fixed pixel positions inside keys_container (see
        # _build_keys), so the scroll area must not resize it.
        self.scroll = QScrollArea()
        self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
        self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
        self.scroll.setWidgetResizable(False)
        self.keys_container = QWidget()
        self.keys_container.setMinimumHeight(WHITE_KEY_H)
        self.scroll.setWidget(self.keys_container)
        wrapper.addWidget(self.scroll)

        # ---- Effect-dialog launchers ----
        launcher = QHBoxLayout()
        launcher.setSpacing(5)
        def add_button(text, key):
            # `key` and `btn` are locals of this call, so each lambda binds its own pair.
            btn = QPushButton(text)
            btn.setCheckable(True)
            btn.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
            btn.clicked.connect(lambda: self._open_dialog(key, btn))
            # Initial checked state mirrors the effect's "*_on" param ("filters" has none).
            if key in ("reverb","delay","chorus","tremolo","distortion","compressor","arp"):
                map_on = {"reverb":"reverb_on","delay":"delay_on","chorus":"chorus_on",
                          "tremolo":"trem_on","distortion":"dist_on","compressor":"comp_on","arp":"arp_on"}
                on_key = map_on.get(key)
                if on_key:
                    btn.setChecked(bool(self.params.get(on_key, False)))
            launcher.addWidget(btn)
        add_button("Reverb…", "reverb")
        add_button("Delay…", "delay")
        add_button("Chorus…", "chorus")
        add_button("Tremolo…", "tremolo")
        add_button("Distortion…", "distortion")
        add_button("Filters…", "filters")
        add_button("Compressor…", "compressor")
        add_button("Arp…", "arp")
        wrapper.addLayout(launcher)

        # ---- MIDI clip list ----
        midi_box = QVBoxLayout()
        title_row = QHBoxLayout()
        title = QLabel("MIDI CLIPS (drop .mid or chord name here)")
        title.setObjectName("sectionTitle")
        title_row.addWidget(title)

        # Octave used when turning typed chord names into pitches.
        title_row.addSpacing(20)
        title_row.addWidget(QLabel("Chord Octave"))
        self.chord_oct_spin = QSpinBox()
        self.chord_oct_spin.setRange(0, 8)
        self.chord_oct_spin.setValue(4)
        title_row.addWidget(self.chord_oct_spin)
        title_row.addStretch()
        midi_box.addLayout(title_row)

        # The list delegates parsing of dropped files / typed chords back to this widget.
        self.midi_list = MidiDropList(parse_cb=self._parse_midi_clip, chord_parse_cb=self._parse_chord_text)
        self.midi_list.setMinimumHeight(160)
        self.midi_list.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
        midi_box.addWidget(self.midi_list)

        # Actions on the selected clip.
        ctrl = QHBoxLayout()
        play_midi_btn = QPushButton("Play MIDI (system)"); play_midi_btn.clicked.connect(self._play_selected_midi)
        play_chord_btn = QPushButton("Play Chord (Piano)"); play_chord_btn.clicked.connect(self._play_selected_chord)
        copy_name_btn = QPushButton("Copy Chord Name"); copy_name_btn.clicked.connect(self._copy_selected_chord_name)
        save_btn = QPushButton("Save Selected As…"); save_btn.clicked.connect(self._save_selected_midi_as)
        arp_start_btn = QPushButton("Start Arp"); arp_start_btn.clicked.connect(self._start_arp_from_selection)
        arp_stop_btn = QPushButton("Stop Arp"); arp_stop_btn.clicked.connect(self.stop_arp)
        arp_save_btn = QPushButton("Save Arp as MIDI…"); arp_save_btn.clicked.connect(self._save_arp_as_midi)
        for b in (play_midi_btn, play_chord_btn, copy_name_btn, save_btn, arp_start_btn, arp_stop_btn, arp_save_btn):
            b.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred)
        ctrl.addWidget(play_midi_btn); ctrl.addWidget(play_chord_btn); ctrl.addWidget(copy_name_btn)
        ctrl.addSpacing(10); ctrl.addWidget(save_btn); ctrl.addSpacing(10)
        ctrl.addWidget(arp_start_btn); ctrl.addWidget(arp_stop_btn); ctrl.addWidget(arp_save_btn)
        ctrl.addStretch()
        midi_box.addLayout(ctrl)

        wrapper.addLayout(midi_box)

        # Create the keys, then load the instrument initially selected in the combo.
        self._build_keys()
        if self.instrument_combo.count() > 0:
            self.switch_instrument(self.instrument_combo.currentText())

    # ---- MIDI device helpers ----
    def _populate_midi_devices(self):
        """(Re)scan pygame.midi for input devices and list them in ``midi_combo``.

        Each entry reads "<id>: <name>", plus " (busy)" when the device is reported
        as already opened. The PortMidi device id is stored as the item's data.
        Errors are printed and leave the combo as it is.
        """
        if not HAS_PYGAME_MIDI:
            return
        try:
            if not pygame.get_init():
                pygame.init()
            if pgmidi.get_init() and not self._midi_connected:
                # PortMidi builds its device list once, at initialisation, so a plain
                # "Refresh" would never see a controller plugged in later. Restart the
                # module to rescan, but never while our own input port is open.
                pgmidi.quit()
            if not pgmidi.get_init():
                pgmidi.init()
        except Exception as e:
            print(f"[MIDI] init failed: {e}")
            return
        self.midi_combo.clear()
        try:
            for i in range(pgmidi.get_count()):
                info = pgmidi.get_device_info(i)  # (interf, name, input, output, opened)
                if not info:
                    continue  # pygame returns None for an invalid device id
                raw_name = info[1]
                name = raw_name.decode(errors="ignore") if isinstance(raw_name, (bytes, bytearray)) else str(raw_name)
                if bool(info[2]):  # inputs only
                    busy = " (busy)" if bool(info[4]) else ""
                    self.midi_combo.addItem(f"{i}: {name}{busy}", i)
        except Exception as e:
            print(f"[MIDI] enumerate failed: {e}")

    def _connect_midi(self):
        """Open the MIDI input selected in ``midi_combo`` on a background thread.

        The device is pre-checked for direction and busy state. Its noteOn/noteOff
        signals are routed to the note handlers. The thread is started 50 ms later,
        because a short delay after init/refresh helps some hosts.
        """
        if self._midi_connected or not HAS_PYGAME_MIDI:
            return
        if self.midi_combo.count() == 0:
            QMessageBox.information(self, "MIDI", "No MIDI input devices found.")
            return
        device_id = int(self.midi_combo.currentData())

        # Pre-check busy & direction (best effort: unknown ids are left to the open attempt).
        try:
            info = pgmidi.get_device_info(device_id)
            is_input = bool(info[2]); opened = bool(info[4])
            if not is_input:
                QMessageBox.warning(self, "MIDI", "Selected device is not an input.")
                return
            if opened:
                QMessageBox.warning(self, "MIDI", "That device is busy (already opened by another app).")
                return
        except Exception:
            pass

        # _start runs 50 ms from now. Disable Connect immediately, otherwise a second
        # click inside that window would schedule a second thread on the same device.
        self.midi_connect_btn.setEnabled(False)

        def _start():
            if self._midi_connected:
                return
            try:
                worker = MidiInputThread(device_id, self)
                worker.signals.noteOn.connect(self._on_midi_note_on)
                worker.signals.noteOff.connect(self._on_midi_note_off)
                worker.start()
            except Exception as e:
                self._midi_thread = None
                self.midi_connect_btn.setEnabled(True)
                QMessageBox.critical(self, "MIDI Error",
                                     f"Could not open MIDI input:\n{e}\n"
                                     "Close other apps using MIDI and try again.")
                return
            self._midi_thread = worker
            self._midi_connected = True
            self.midi_status.setText("✅ Connected")
            self.midi_connect_btn.setEnabled(False)
            self.midi_disconnect_btn.setEnabled(True)
        QTimer.singleShot(50, _start)

    def _disconnect_midi(self):
        """Stop the MIDI input thread, if any, and reset the connection UI.

        The thread's note signals are disconnected first, so a worker that is slow
        to stop can no longer trigger notes. The thread is then given up to 500 ms
        to exit. The port is force-closed only once the thread has actually
        finished; closing it while the worker may still be polling it is unsafe.
        """
        if not self._midi_connected:
            return
        thread = self._midi_thread
        if thread is not None:
            try:
                thread.signals.noteOn.disconnect(self._on_midi_note_on)
                thread.signals.noteOff.disconnect(self._on_midi_note_off)
            except (AttributeError, TypeError, RuntimeError):
                pass  # already disconnected or object gone
            finished = False
            try:
                thread.stop()
                finished = bool(thread.wait(500))
            except Exception as e:
                print(f"[MIDI] stopping input thread failed: {e}")
            inp = getattr(thread, "_inp", None)
            if finished and inp:
                # Safety net: the thread exited without closing its port.
                try:
                    inp.close()
                except Exception:
                    pass
            elif not finished:
                print("[MIDI] input thread still running after 500 ms; leaving its port to the thread")
        self._midi_thread = None
        self._midi_connected = False
        self.midi_status.setText("⏺ Not connected")
        self.midi_connect_btn.setEnabled(True)
        self.midi_disconnect_btn.setEnabled(False)

    def _on_midi_note_on(self, midi_note: int, velocity: int):
        """Flash and sound the on-screen key for an incoming MIDI note-on.

        MIDI defines note-on with velocity 0 as a note-off, and many controllers
        send that instead of a real note-off. Such messages are ignored so that
        releasing a key does not retrigger the sample. Notes outside the visible
        range are folded onto it by ``_midi_note_to_key_name``.
        """
        if velocity <= 0:
            return
        name = self._midi_note_to_key_name(midi_note)
        if not name:
            return
        if name in self.key_map:
            self.key_map[name].flash(140)
        self.engine.play(name, self.params)

    def _on_midi_note_off(self, midi_note: int, velocity: int):
        """Ignore MIDI note-off: samples are one-shots, so there is no voice to release."""
        return None

    def _midi_note_to_key_name(self, mn: int) -> Optional[str]:
        """Map a MIDI note number to the name of a visible key, e.g. 60 -> "C4".

        A note outside the drawn range is transposed by whole octaves to the
        nearest visible octave, so its pitch class is always preserved.

        This replaces the old nearest-semitone search, which could land on a
        different pitch class. It also produced flat spellings ("E-3") that never
        match the sharp-named ``key_map``.

        Returns:
            The matching key name, or None if no key with that pitch class exists.
        """
        names = ['C','C#','D','D#','E','F','F#','G','G#','A','A#','B']
        base = names[mn % 12]
        octave = (mn // 12) - 1
        name = f"{base}{octave}"
        if name in self.key_map:
            return name
        first = self.start_octave
        last = self.start_octave + self.num_octaves - 1
        for oc in sorted(range(first, last + 1), key=lambda o: abs(o - octave)):
            cand = f"{base}{oc}"
            if cand in self.key_map:
                return cand
        return None

    # -------- keys --------
    def _clear_keys(self):
        """Detach and schedule deletion of every key widget, then reset the bookkeeping."""
        doomed = list(self.white_buttons) + [key for key, _oct, _slot in self.black_buttons]
        for widget in doomed:
            widget.setParent(None)
            widget.deleteLater()
        self.white_buttons.clear()
        self.black_buttons.clear()
        self.key_map.clear()

    def _build_keys(self):
        """(Re)create the key widgets for the current octave range.

        White keys are laid out left to right with a fixed stride. Black keys are
        created afterwards and then positioned and raised by
        ``reposition_black_keys``.
        """
        self._clear_keys()
        chromatic = ['C','C#','D','D#','E','F','F#','G','G#','A','A#','B']
        accidental_slots = {1, 3, 6, 8, 10}
        stride = WHITE_KEY_W + WHITE_GAP
        width = 7 * self.num_octaves * stride
        self.keys_container.setMinimumWidth(width)
        self.keys_container.resize(width, WHITE_KEY_H)
        octaves = range(self.start_octave, self.start_octave + self.num_octaves)

        naturals = [f"{nm}{o}" for o in octaves
                    for idx, nm in enumerate(chromatic) if idx not in accidental_slots]
        for slot, pname in enumerate(naturals):
            key = PianoKey(pname, False, self.keys_container)
            key.move(slot * stride, 0)
            key.clicked.connect(lambda _, p=pname: self._on_key_clicked(p))
            key.show()
            self.white_buttons.append(key)
            self.key_map[pname] = key

        for oct_idx, o in enumerate(octaves):
            for pos_idx, sharp in enumerate(('C#', 'D#', 'F#', 'G#', 'A#')):
                pname = f"{sharp}{o}"
                key = PianoKey(pname, True, self.keys_container)
                key.clicked.connect(lambda _, p=pname: self._on_key_clicked(p))
                key.show()
                self.black_buttons.append((key, oct_idx, pos_idx))
                self.key_map[pname] = key
        self.reposition_black_keys()

    def resizeEvent(self, event):
        """Let Qt handle the resize, then re-place the black keys over the white ones."""
        super().resizeEvent(event)
        self.reposition_black_keys()

    def reposition_black_keys(self):
        """Place each black key at its BLACK_OFFSETS slot in its octave and keep it on top."""
        stride = WHITE_KEY_W + WHITE_GAP
        octave_width = 7 * stride
        for key, oct_idx, pos_idx in self.black_buttons:
            left = int(oct_idx * octave_width + BLACK_OFFSETS[pos_idx] * stride)
            key.move(left, 0)
            key.raise_()

    def switch_instrument(self, name):
        """Load ``samples/<name>`` into the engine at the default output sample rate.

        Falls back to 44100 Hz when sounddevice has no default sample rate.
        """
        self.engine.fs = int(sd.default.samplerate or 44100)
        self.engine.load_dir(os.path.join('samples', name))

    def _on_range_changed(self, _):
        """Adopt the octave spin boxes' values and rebuild the keyboard."""
        self.start_octave, self.num_octaves = self.start_oct_spin.value(), self.num_oct_spin.value()
        self._build_keys()

    def _on_key_clicked(self, p_name):
        """Handle a mouse click on a key: emit notePlayed (music21 only), flash and sound it."""
        if HAS_MUSIC21:
            try:
                clicked_note = note.Note(p_name)
                self.notePlayed.emit(clicked_note)
            except Exception:
                pass
        key = self.key_map.get(p_name)
        if key is not None:
            key.flash(150)
        self.engine.play(p_name, self.params)

    # -------- Computer Keyboard Handling --------
    def keyPressEvent(self, e: QKeyEvent):
        """Play the visible key bound to a computer-keyboard character.

        The characters in ``_kb_white`` walk the white keys, and those in
        ``_kb_black`` walk the black keys, from the lowest visible octave upward.
        A held key sounds once. Characters that map to no visible key go to
        ``QWidget.keyPressEvent`` so Qt can propagate them to the parent, as the
        Qt docs require for keys a reimplementation does not act on.
        """
        ch = e.text().upper()

        names = ['C','C#','D','D#','E','F','F#','G','G#','A','A#','B']
        accidental_slots = {1, 3, 6, 8, 10}
        octaves = range(self.start_octave, self.start_octave + self.num_octaves)
        white_names = [f"{nm}{oc}" for oc in octaves for i, nm in enumerate(names) if i not in accidental_slots]
        black_names = [f"{nm}{oc}" for oc in octaves for i, nm in enumerate(names) if i in accidental_slots]
        white_chars = [c.upper() for c in self._kb_white]
        black_chars = [c.upper() for c in self._kb_black]

        target = None
        if ch and ch in white_chars:
            idx = white_chars.index(ch)
            if idx < len(white_names):
                target = white_names[idx]
        elif ch and ch in black_chars:
            idx = black_chars.index(ch)
            if idx < len(black_names):
                target = black_names[idx]

        if target is None:
            super().keyPressEvent(e)  # not a piano key: let Qt propagate it
            return
        if e.isAutoRepeat() or ch in self._kb_pressed:
            return  # still held; it already sounded on the first press
        self._kb_pressed.add(ch)
        if target in self.key_map:
            self.key_map[target].flash(120)
            self.engine.play(target, self.params)

    def keyReleaseEvent(self, e: QKeyEvent):
        """Forget a released piano character so it can sound again; pass other keys on.

        Auto-repeat releases of a held piano key are swallowed. Releases of keys
        this widget is not tracking go to ``QWidget.keyReleaseEvent``.
        """
        ch = e.text().upper()
        if ch not in self._kb_pressed:
            super().keyReleaseEvent(e)
            return
        if not e.isAutoRepeat():
            self._kb_pressed.remove(ch)

    # -------- Arpeggiator --------
    def _update_arp_step_ms(self):
        """Recompute the arpeggiator step length in ms from the BPM and the ``arp_subdiv`` param.

        The subdivision is expressed in beats: "1/8" is half a beat and "1/8T" a
        third of one. Unknown values fall back to 0.5. If the arp is running, the
        timer interval is updated immediately.
        """
        beats_per_step = {
            "1/4": 1.0, "1/8": 0.5, "1/8T": 1/3.0, "1/16": 0.25, "1/32": 0.125,
        }.get(self.params.get("arp_subdiv", "1/8"), 0.5)
        quarter_ms = 60000.0 / max(1, int(self.bpm_spin.value()))
        self._arp_step_ms = quarter_ms * beats_per_step
        if self._arp_running:
            self._arp_timer.setInterval(int(self._arp_step_ms))

    def _build_arp_sequence(self, base_notes: List[str]) -> List[str]:
        """Build the note order the arpeggiator cycles through.

        The chord is fitted to the keyboard, repeated ``arp_octaves`` times (each
        repeat an octave higher, when music21 is available), and reduced to notes
        that have a visible key. It is then ordered by ``arp_pattern``:

        * "Up" / "Down": sorted by MIDI number;
        * "UpDown": up, then back down without repeating the end points;
        * "Random": shuffled once;
        * anything else (including "As-Listed"): fitted order.

        Returns an empty list when nothing can be played.
        """
        if not base_notes:
            return []
        fitted = self._fit_notes_to_keyboard(base_notes, self.chord_oct_spin.value())
        if not fitted:
            return []

        def sort_key(note_name):
            # Without music21 every note sorts equal, so stable sorts keep input order.
            try:
                return pitch.Pitch(note_name).midi if HAS_MUSIC21 else 60
            except Exception:
                return 60

        octave_span = max(1, int(self.params.get("arp_octaves", 1)))
        candidates: List[str] = []
        for shift in range(octave_span):
            for note_name in fitted:
                if not HAS_MUSIC21:
                    candidates.append(note_name)
                    continue
                shifted = pitch.Pitch(note_name)
                shifted.octave = shifted.octave + shift
                candidates.append(shifted.nameWithOctave)
        pool = [n for n in candidates if n in self.key_map]

        mode = self.params.get("arp_pattern", "Up")
        if mode == "Up":
            return sorted(pool, key=sort_key)
        if mode == "Down":
            return sorted(pool, key=sort_key, reverse=True)
        if mode == "UpDown":
            rising = sorted(pool, key=sort_key)
            if len(rising) <= 1:
                return rising
            return rising + rising[::-1][1:-1]
        if mode == "Random":
            shuffled = pool[:]
            random.shuffle(shuffled)
            return shuffled
        return pool

    def _start_arp_from_selection(self):
        """Start the arpeggiator on the selected clip's notes, or tell the user to pick one."""
        chord_notes = self._get_notes_from_selection()
        if chord_notes:
            self.start_arp(chord_notes)
            return
        QMessageBox.information(self, "No notes", "Select a MIDI chord (or type a chord name) first.")

    def start_arp(self, chord_notes: List[str]):
        """Start (or restart from the first step) the arpeggiator on *chord_notes*.

        Shows a message and does nothing if no playable sequence can be built.
        The timer uses Qt's PreciseTimer: the default CoarseTimer may fire up to
        5% off its interval, which is audible as jitter in an arpeggio.
        """
        seq = self._build_arp_sequence(chord_notes)
        if not seq:
            QMessageBox.information(self, "Arp", "Could not build arpeggio sequence from the chosen chord.")
            return
        self._arp_seq = seq
        self._arp_index = 0
        self._update_arp_step_ms()
        self._arp_timer.setTimerType(Qt.TimerType.PreciseTimer)
        # round() rather than int(): truncation shortens every step (e.g. 166.67 -> 166 ms).
        self._arp_timer.setInterval(max(1, round(self._arp_step_ms)))
        self._arp_running = True
        self._arp_timer.start()

    def stop_arp(self):
        """Halt the arpeggiator; the current sequence is kept until the next start."""
        self._arp_timer.stop()
        self._arp_running = False

    def _arp_tick(self):
        """Play the next arpeggio note and schedule the following tick.

        Gate: the note lasts ``arp_gate``% of a step (clamped to 5-100%, at
        least 10 ms).

        Swing: each off-beat (odd-index) note is pushed late by ``arp_swing``%
        (0-75) of half a step. The interval after it is shortened by the same
        amount, so on-beats stay on the grid and the tempo does not drift.
        Previously the delay was added after every off-beat and never recovered,
        which slowed the pattern instead of swinging it.
        """
        if not self._arp_running or not self._arp_seq:
            self.stop_arp()
            return
        nname = self._arp_seq[self._arp_index % len(self._arp_seq)]
        step_ms = float(self._arp_step_ms)
        gate_pct = max(5, min(100, int(self.params.get("arp_gate", 70))))
        length_ms = max(10.0, step_ms * (gate_pct / 100.0))

        swing_pct = max(0, min(75, int(self.params.get("arp_swing", 0))))
        swing_ms = (swing_pct / 100.0) * (step_ms * 0.5)
        is_off = (self._arp_index % 2 == 1)

        if nname in self.key_map:
            self.key_map[nname].flash(int(min(240, length_ms)))
        self.engine.play(nname, self.params, length_ms=length_ms)

        # On-beat now -> the next (off-beat) note arrives late: step + swing.
        # Off-beat now (already late) -> the next on-beat snaps back: step - swing.
        next_interval = step_ms - swing_ms if is_off else step_ms + swing_ms
        self._arp_timer.setInterval(max(1, round(next_interval)))
        self._arp_index += 1

    # -------- Save Arp as MIDI --------
    def _save_arp_as_midi(self):
        """Export one pass of the arpeggio for the selected chord as a MIDI file.

        The export uses the same step length, gate and swing as live playback:
        each off-beat note is delayed by swing% of half a step. Offsets are
        computed from the step count rather than accumulated, so float error
        does not build up. Requires music21; the user picks the target path.
        """
        if not HAS_MUSIC21:
            QMessageBox.warning(self, "music21 missing", "Install music21 to export MIDI.")
            return

        notes = self._get_notes_from_selection()
        if not notes:
            QMessageBox.information(self, "No notes", "Select a MIDI chord (or type a chord name) first.")
            return

        seq = self._build_arp_sequence(notes)
        if not seq:
            QMessageBox.information(self, "Arp", "Could not build arpeggio sequence from the chosen chord.")
            return

        bpm = max(1, int(self.bpm_spin.value()))
        step_ms = float(self._arp_step_ms)
        qlen = (step_ms / 1000.0) * (bpm / 60.0)  # quarterLength per step

        gate_pct = max(5, min(100, int(self.params.get("arp_gate", 70))))
        gate_ql = qlen * (gate_pct / 100.0)

        swing_pct = max(0, min(75, int(self.params.get("arp_swing", 0))))
        swing_ql = (swing_pct / 100.0) * (qlen * 0.5)

        st = stream.Stream()
        st.append(tempo.MetronomeMark(number=bpm))
        st.append(meter.TimeSignature('4/4'))

        step_no = 0  # advances only for notes actually written, as before
        for nm in seq:
            try:
                nobj = note.Note(pitch.Pitch(nm))
                nobj.duration = duration.Duration(gate_ql)
                offset = step_no * qlen + (swing_ql if step_no % 2 == 1 else 0.0)
                st.insert(offset, nobj)
                step_no += 1
            except Exception:
                continue

        target, _ = QFileDialog.getSaveFileName(self, "Save Arp as MIDI", "arp.mid", "MIDI Files (*.mid)")
        if not target:
            return
        try:
            st.write('midi', fp=target)
            QMessageBox.information(self, "Saved", f"Arp MIDI saved to:\n{target}")
        except Exception as e:
            QMessageBox.critical(self, "Save Error", f"Could not save arp:\n{e}")

    # -------- MIDI/Chord helpers & save --------
    def _get_selected_midi_item(self) -> Optional[MidiClipItem]:
        """Return the MIDI list's current item when it is a ``MidiClipItem``, else None."""
        current = self.midi_list.currentItem()
        if isinstance(current, MidiClipItem):
            return current
        return None

    def _play_selected_midi(self):
        """Preview the selected item's MIDI file through music21's ``show('midi')``.

        A missing selection, a chord-name-only item (no ``file_path``), or a
        missing music21 install is reported with a message box. Parse or
        playback errors are shown in a critical message box.
        """
        selected = self._get_selected_midi_item()
        if not selected:
            QMessageBox.information(self, "No Selection", "Select a MIDI item first.")
            return
        if not selected.file_path:
            QMessageBox.information(self, "No MIDI File", "This item only contains a chord name. Use 'Play Chord (Piano)'.")
            return
        if not HAS_MUSIC21:
            QMessageBox.warning(self, "music21 not installed", "Install music21 to preview MIDI.")
            return
        try:
            parsed = converter.parse(selected.file_path)
            parsed.show('midi')
        except Exception as err:
            QMessageBox.critical(self, "MIDI Error", f"Failed to play MIDI:\n{err}")

    def _derive_pitches_from_chord_name(self, chord_name: str, default_oct: int) -> List[str]:
        """Spell a chord symbol (e.g. ``"Cmaj7"``) as note names in one octave.

        Each chord tone of music21's ``ChordSymbol`` is copied and forced to
        ``default_oct``, so the result is a same-octave collapse of the chord,
        not an ascending voicing. Returns ``[]`` when music21 is unavailable
        or the symbol cannot be parsed.
        """
        if not HAS_MUSIC21:
            return []

        def _in_octave(tone) -> str:
            # Work on a copy so the ChordSymbol's own pitches are untouched.
            moved = pitch.Pitch(tone)
            moved.octave = default_oct
            return moved.nameWithOctave

        try:
            return [_in_octave(tone) for tone in harmony.ChordSymbol(chord_name).pitches]
        except Exception:
            return []

    def _play_selected_chord(self):
        """Sound the selected chord on the built-in engine, briefly flashing each key.

        Chord tones come from ``_get_notes_from_selection`` and are mapped onto
        the visible keyboard near the octave chosen in ``chord_oct_spin``.
        """
        chord_tones = self._get_notes_from_selection()
        if not chord_tones:
            QMessageBox.information(self, "No notes found", "Could not determine chord notes for playback.")
            return
        octave = self.chord_oct_spin.value()
        for key_name in self._fit_notes_to_keyboard(chord_tones, prefer_oct=octave):
            if key_name in self.key_map:
                self.key_map[key_name].flash(180)
            self.engine.play(key_name, self.params)

    def _copy_selected_chord_name(self):
        """Put the selected item's chord name on the system clipboard and confirm it."""
        selected = self._get_selected_midi_item()
        chord_label = selected.chord_name if selected else None
        if not chord_label:
            QMessageBox.information(self, "No Chord Name", "Select an item with a chord name.")
            return
        board = QGuiApplication.clipboard()
        board.setText(chord_label, mode=QClipboard.Mode.Clipboard)
        QMessageBox.information(self, "Copied", f"Copied chord name: {chord_label}")

    def _save_selected_midi_as(self):
        """Copy the selected item's MIDI file to a location chosen by the user.

        For a chord-name-only item (no ``file_path``), music21 first renders the
        chord symbol as one whole-note chord into the system temp directory. The
        item's ``file_path`` is then updated to point at that temp file. All
        outcomes are reported through message boxes; nothing is raised.

        Notes:
        - The temp file name comes from a sanitized chord name. Symbols such as
          "C/E" would otherwise contain a path separator and point into a
          non-existent directory.
        - Choosing the source file itself as the target is treated as
          "already saved". A manual read/write copy would open the target with
          "wb" first and truncate the source before reading it.
        """
        import shutil  # local import: only needed for the copy step

        item = self._get_selected_midi_item()
        if not item:
            QMessageBox.information(self, "No Selection", "Select a MIDI item first.")
            return
        if not item.file_path:
            if not HAS_MUSIC21 or not item.chord_name:
                QMessageBox.information(self, "Nothing to save", "This item has no MIDI file.")
                return
            try:
                cs = harmony.ChordSymbol(item.chord_name)
                st = stream.Stream()
                ch = chord.Chord(cs.pitches)
                ch.duration = duration.Duration('whole')
                st.append(ch)
                # Keep only characters that are safe in a file name on every OS;
                # "/" (slash chords) and similar become "_".
                safe_stem = "".join(
                    c if c.isalnum() or c in "#+-_" else "_" for c in item.chord_name
                ) or "chord"
                tmp = os.path.join(tempfile.gettempdir(), f"{safe_stem}.mid")
                st.write('midi', fp=tmp)
                item.file_path = tmp
            except Exception as e:
                QMessageBox.critical(self, "Save Error", f"Could not create MIDI from chord:\n{e}")
                return
        target, _ = QFileDialog.getSaveFileName(self, "Save MIDI As", os.path.basename(item.file_path), "MIDI Files (*.mid)")
        if not target:
            return
        try:
            shutil.copyfile(item.file_path, target)
        except shutil.SameFileError:
            # The user picked the source file itself: it is already in place.
            pass
        except OSError as e:
            QMessageBox.critical(self, "Save Error", f"Could not save file:\n{e}")
            return
        QMessageBox.information(self, "Saved", f"Saved to:\n{target}")

    def _fit_notes_to_keyboard(self, names: List[str], prefer_oct: int) -> List[str]:
        """Map note names onto keys that exist in ``self.key_map``.

        Each name is moved to ``prefer_oct``, clamped into the visible octave
        range ``[start_octave, start_octave + num_octaves - 1]``. If no key
        matches there, octaves ``prefer_oct -/+ 1 .. 4`` are tried, lower first.
        Notes that cannot be placed, or that fail to parse, are dropped. Without
        music21 nothing can be parsed, so ``[]`` is returned.

        Each candidate is first tried in music21's own spelling. music21 writes
        flats with "-" (e.g. "E-4"), which a sharp-spelled key map never
        contains, so the sharp spelling of the same MIDI pitch ("D#4") is tried
        next. This stops flat chord tones from being silently dropped.
        NOTE(review): key_map's naming is defined outside this chunk; the extra
        spelling only adds matches and never removes one.
        """
        if not names or not HAS_MUSIC21:
            return []

        sharp_pcs = ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')

        def playable_name(p) -> Optional[str]:
            # Prefer music21's spelling; fall back to the sharp spelling of the
            # same sounding pitch (E-4 -> D#4, B#3 -> C4, C-4 -> B3).
            spelled = p.nameWithOctave
            if spelled in self.key_map:
                return spelled
            midi = p.midi
            sharp = f"{sharp_pcs[midi % 12]}{midi // 12 - 1}"
            return sharp if sharp in self.key_map else None

        min_oct = self.start_octave
        max_oct = self.start_octave + self.num_octaves - 1
        out: List[str] = []
        for n in names:
            try:
                p = pitch.Pitch(n)
                # Clamp prefer_oct into [min_oct, max_oct] (min_oct wins if the
                # range is empty).
                p.octave = max(min_oct, min(prefer_oct, max_oct))
                cand = playable_name(p)
                if cand is None:
                    for delta in range(1, 5):
                        for octv in (prefer_oct - delta, prefer_oct + delta):
                            p.octave = octv
                            cand = playable_name(p)
                            if cand is not None:
                                break
                        if cand is not None:
                            break
                if cand is not None:
                    out.append(cand)
            except Exception:
                # Best effort per note: an unparsable name is skipped rather
                # than aborting the whole chord.
                continue
        return out

    def _get_notes_from_selection(self) -> List[str]:
        """Collect note names for the selected MIDI item, trying sources in order.

        1. The item's stored ``pitch_names``.
        2. Pitches derived from its ``chord_name`` at ``chord_oct_spin``'s octave.
        3. Pitches derived from the MIDI file's base name, read as a chord
           symbol (only when music21 is available).

        Returns ``[]`` when there is no selection or no source yields notes.
        """
        selected = self._get_selected_midi_item()
        if not selected:
            return []

        found = list(selected.pitch_names) if selected.pitch_names else []
        if found:
            return found

        if selected.chord_name:
            found = self._derive_pitches_from_chord_name(selected.chord_name, self.chord_oct_spin.value())
            if found:
                return found

        if selected.file_path and HAS_MUSIC21:
            try:
                stem = os.path.splitext(os.path.basename(selected.file_path))[0]
                found = self._derive_pitches_from_chord_name(stem, self.chord_oct_spin.value())
            except Exception:
                pass
        return found

    # -------- Cleanup --------
    def closeEvent(self, e):
        """Release MIDI resources, then let Qt close the widget.

        Each cleanup step is attempted on its own. A failure while disconnecting
        the controller therefore no longer prevents ``pygame.midi`` from being
        shut down; with a single shared ``try``, the first error skipped the
        rest. Failures are printed rather than silently discarded, and the close
        always proceeds.
        """
        try:
            self._disconnect_midi()
        except Exception as err:
            print(f"MIDI disconnect failed during close: {err}")
        try:
            if HAS_PYGAME_MIDI and pgmidi.get_init():
                pgmidi.quit()  # release PortMidi host cleanly
        except Exception as err:
            print(f"pygame.midi shutdown failed during close: {err}")
        super().closeEvent(e)

# =========================
# Window wrapper
# =========================
class PianoWindow(QWidget):
    """Stand-alone top-level window hosting a single ``PianoWidget``.

    When ``note_callback`` is given, it is connected to the widget's
    ``notePlayed`` signal.
    """

    def __init__(self, start_octave=3, num_octaves=4, parent=None, note_callback=None):
        super().__init__(parent, flags=Qt.WindowType.Window)
        self.setObjectName("mainWindow")
        self.setWindowTitle("Interactive Piano")
        self.setMinimumSize(900, 800)
        grow = QSizePolicy.Policy.Expanding
        self.setSizePolicy(grow, grow)

        self._piano = PianoWidget(start_octave, num_octaves, self)
        if note_callback:
            self._piano.notePlayed.connect(note_callback)

        outer = QVBoxLayout(self)
        outer.addWidget(self._piano)

    def show_and_raise(self):
        """Show the window, raise it above its siblings and activate it."""
        self.show()
        self.raise_()
        self.activateWindow()

# =========================
# Quick run
# =========================
if __name__ == "__main__":
    # Qt 6 always enables high-DPI scaling and high-DPI pixmaps. The old
    # AA_EnableHighDpiScaling / AA_UseHighDpiPixmaps attributes are deprecated
    # no-ops there, only ever took effect when set *before* QApplication was
    # created, and may be missing from PyQt6's enum. They are therefore
    # intentionally not set.
    app = QApplication(sys.argv)
    app.setFont(QFont('Segoe UI', 10))
    app.setStyleSheet(load_stylesheet())
    window = PianoWindow(start_octave=3, num_octaves=3)
    window.show()
    sys.exit(app.exec())