#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Neuronal Membrane Circadian Oscillator - Plotting Script

This script visualizes the results from the neuronal membrane oscillator simulation.
It loads time series data generated by NEURON.py and creates Figure 2 from:
"Autonomous circadian oscillators intrinsic to cell membranes"
Forlino et al., 2026

The figure shows:
- Circadian oscillations in clock channel availability (Ncc)
- Circadian modulation of neuronal firing rate
- Phase space trajectories (limit cycles)
- Relationship between firing rate and cAMP/Ncc
- Voltage and calcium traces during high and low firing periods
"""

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
import numpy as np


# ============================================================================
# LOAD AND CONCATENATE DATA FROM 4-DAY SIMULATION
# ============================================================================

N_DAYS = 4  # Number of per-day output files (suffixes _0 .. _{N_DAYS-1})

print(f"Loading simulation data from {N_DAYS} days...")

# Per-day pieces are collected in lists and concatenated once after the loop,
# instead of re-growing the full arrays on every iteration.
T_parts = []       # Time points (downsampled)
Ncc_parts = []     # Clock channel availability
Y_parts = []       # Inhibitor concentration
cAMP_parts = []    # Cyclic AMP concentration
spike_parts = []   # Spike times (one per ISI)
ISI_parts = []     # Inter-spike intervals

for j in range(N_DAYS):
    # Slow variables (downsampled). ndmin=1 keeps a single-value file as a
    # 1-D array; a 0-d result cannot be concatenated.
    T_parts.append(np.loadtxt(f"time_{j}.txt", ndmin=1))
    Ncc_parts.append(np.loadtxt(f"Ncc_{j}.txt", ndmin=1))
    Y_parts.append(np.loadtxt(f"Y_{j}.txt", ndmin=1))
    cAMP_parts.append(np.loadtxt(f"cAMP_{j}.txt", ndmin=1))

    # Spike times -> ISIs. ndmin=1 also makes a day with a single spike
    # (otherwise a 0-d scalar that cannot be sliced) simply yield no ISIs.
    st = np.loadtxt(f"spike_train_{j}.txt", ndmin=1)
    st = st[1:]        # Discard the first recorded spike of each day
    isi = np.diff(st)  # isi[k] = st[k+1] - st[k]
    st = st[:-1]       # Time-stamp each ISI with the spike that opens it

    spike_parts.append(st)
    ISI_parts.append(isi)

T = np.concatenate(T_parts)
Ncc = np.concatenate(Ncc_parts)
Y = np.concatenate(Y_parts)
cAMP = np.concatenate(cAMP_parts)
spike_train = np.concatenate(spike_parts)
ISI = np.concatenate(ISI_parts)

# Convert ISI (ms) to instantaneous firing frequency (Hz)
frequency = 1000 / ISI

print(f"  - Total simulation time: {T[-1]/3600000:.2f} hours")
print(f"  - Total spikes: {len(spike_train)}")
print(f"  - Mean firing rate: {np.mean(frequency):.3f} Hz")


# ============================================================================
# MAP SPIKE TIMES TO SLOW VARIABLE TIME POINTS
# ============================================================================

def find_closest_indices(T, spike_train):
    """
    Find indices in downsampled time array closest to each spike time.

    This function maps high-resolution spike times to the downsampled
    time array used for slow variables (Ncc, cAMP, Y).

    When ``T`` is one-dimensional, non-decreasing and finite (and the spike
    times are finite), the lookup is vectorised with a binary search:
    O(M log N) for M spikes and N time points, instead of scanning all of
    ``T`` for every spike (O(M * N)). Any other input falls back to that
    exhaustive scan. Both paths pick the earliest index among equally close
    time points, exactly as ``np.argmin`` does. The only exception would be
    degenerate magnitudes where floating-point rounding makes distinct
    distances compare equal.

    Parameters
    ----------
    T : array_like
        Downsampled time array (ms)
    spike_train : array_like
        High-resolution spike times (ms)

    Returns
    -------
    indices : list of int
        Indices in T closest to each spike time

    Raises
    ------
    ValueError
        If ``T`` is empty while ``spike_train`` is not (from ``np.argmin``).
    """
    T = np.asarray(T, dtype=float)
    spikes = np.asarray(spike_train, dtype=float).ravel()
    if spikes.size == 0:
        return []

    sorted_and_finite = (
        T.ndim == 1
        and T.size > 0
        and np.all(np.isfinite(T))
        and np.all(np.isfinite(spikes))
        and np.all(np.diff(T) >= 0)
    )
    if not sorted_and_finite:
        # Exhaustive scan: same rule as the vectorised path, any input shape.
        return [int(np.argmin(np.abs(T - t))) for t in spikes]

    last = T.size - 1
    # pos = first index whose time is >= the spike, so T[pos - 1] is the
    # largest time strictly below it. Clip to stay in range at both ends.
    pos = np.searchsorted(T, spikes, side="left")
    lo = np.clip(pos - 1, 0, last)
    hi = np.clip(pos, 0, last)
    # On an exact tie prefer the lower neighbour: it appears first in T.
    nearest = np.where(np.abs(T[lo] - spikes) <= np.abs(T[hi] - spikes), lo, hi)
    # Duplicated time values resolve to their first occurrence, like argmin.
    first = np.searchsorted(T, T[nearest], side="left")
    return first.tolist()

# For every spike, read Ncc and cAMP at the nearest downsampled time point
closest_indices = find_closest_indices(T, spike_train)
Ncc_at_peaks, cAMP_at_peaks = Ncc[closest_indices], cAMP[closest_indices]


# ============================================================================
# SMOOTH DATA WITH MOVING AVERAGE
# ============================================================================

def simple_moving_average(data, window):
    """
    Calculate simple moving average (SMA) to smooth noisy time series.

    Parameters
    ----------
    data : array_like
        The time series data to smooth
    window : int
        Window size for the moving average (number of points); must satisfy
        1 <= window <= len(data)

    Returns
    -------
    sma : ndarray
        Smoothed data (length reduced by window-1)

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1 or larger than ``len(data)``.

    Notes
    -----
    Uses 'valid' mode, so output length is len(data) - window + 1.
    This removes high-frequency fluctuations while preserving circadian trends.
    """
    data = np.asarray(data)
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if window > len(data):
        # np.convolve(..., 'valid') would silently swap the operands and
        # return window - len(data) + 1 meaningless values instead.
        raise ValueError(
            f"window ({window}) exceeds the data length ({len(data)})"
        )
    sma = np.convolve(data, np.ones(window), 'valid') / window
    return sma

# Same 10000-point window for every series, so the four smoothed arrays
# have equal length and stay aligned sample-for-sample.
SMA_WINDOW = 10000
spike_train_smooth, frequency_smooth, Ncc_at_peaks_smooth, cAMP_at_peaks_smooth = (
    simple_moving_average(series, SMA_WINDOW)
    for series in (spike_train, frequency, Ncc_at_peaks, cAMP_at_peaks)
)

print("Data smoothing complete.")


# ============================================================================
# LOAD 1-SECOND WINDOW DATA (MIN/MAX FIRING RATES)
# ============================================================================

# 1-second windows of time, voltage and calcium traces. Per the file names,
# the "_min" set is the low-firing window and the "_max" set the high-firing one.
T_series_min, V_series_min, Ca_series_min = (
    np.loadtxt(path)
    for path in ("T_series_min.txt", "V_series_min.txt", "Ca_series_min.txt")
)
T_series_max, V_series_max, Ca_series_max = (
    np.loadtxt(path)
    for path in ("T_series_max.txt", "V_series_max.txt", "Ca_series_max.txt")
)


# ============================================================================
# CONFIGURE FIGURE AND SUBPLOT LAYOUT
# ============================================================================

print("\nGenerating Figure 2...")

# 12 x 8 inch canvas; constrained layout spaces the panels and their labels
fig = plt.figure(
    figsize=(12, 8),
    constrained_layout=True,
)

def format_axes(fig, labelsize=14):
    """
    Apply consistent formatting to all axes in the figure.

    Every axes in ``fig.axes`` gets major tick labels of size ``labelsize``
    on both axes, and its grid lines are switched on.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose axes are formatted in place.
    labelsize : float, optional
        Font size for major tick labels (default 14, the size used for the
        axis labels throughout this script).
    """
    for ax in fig.axes:
        ax.tick_params(axis='both', which='major', labelsize=labelsize)
        # Pass True explicitly: a bare ax.grid() *toggles* visibility, which
        # hides the grid if the active style already enables it.
        ax.grid(True)

# Main layout: 12 rows x 3 columns (rows 0-5 of columns 0-1 are left empty here)
gs = GridSpec(12, 3, figure=fig)

# Rows 6-8 of the first two columns are each split into a thin day/night
# bar (1 part) stacked above a time-series panel (20 parts)
gs00, gs01 = (
    GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[6:9, col],
                            height_ratios=[1, 20])
    for col in (0, 1)
)

# Bottom of columns 0-1: 1-second voltage and calcium windows
ax1 = fig.add_subplot(gs[9:12, 0])
ax2 = fig.add_subplot(gs[9:12, 1])
# Right column, top to bottom: cAMP vs Ncc, rate vs cAMP, rate vs Ncc
ax3 = fig.add_subplot(gs[0:4, 2])
ax4 = fig.add_subplot(gs[4:8, 2])
ax5 = fig.add_subplot(gs[8:12, 2])

# Day/night bars (ax6, ax8) and the time series beneath them (ax7, ax9)
ax6 = fig.add_subplot(gs00[0, 0])
ax7 = fig.add_subplot(gs00[1, 0])
ax8 = fig.add_subplot(gs01[0, 0])
ax9 = fig.add_subplot(gs01[1, 0])

# The indicator bars carry no ticks on either axis
for bar_ax in (ax6, ax8):
    bar_ax.set_xticks([])
    bar_ax.set_yticks([])


# ============================================================================
# PLOT CLOCK CHANNEL AVAILABILITY (Ncc) TIME SERIES
# ============================================================================

# Define day/night cycle parameters
night_start = 12  # ZT 12 (start of night)
night_end = 24    # ZT 24/0 (end of night)
total_duration = 24  # Hours per day


def _plot_daynight_timeseries(bar_ax, series_ax, hours, values, ylabel):
    """
    Plot one time series under a day/night indicator bar with a ZT x-axis.

    Parameters
    ----------
    bar_ax : matplotlib.axes.Axes
        Thin axis above the series. Each night (ZT 12-24 of every day) is
        drawn on it as a black rectangle.
    series_ax : matplotlib.axes.Axes
        Axis that receives the time series.
    hours : ndarray
        Sample times in hours since the start of the simulation, which is
        labelled as ZT 0.
    values : ndarray
        Samples plotted against ``hours``.
    ylabel : str
        Y-axis label for ``series_ax``.
    """
    t_min, t_max = np.min(hours), np.max(hours)

    # Day/night indicator bar spans the same time range as the series
    bar_ax.set_ylim([0, 1])
    bar_ax.set_xlim([t_min, t_max])

    series_ax.plot(hours, values, color='blue', zorder=10)
    # Capture the autoscaled y-range; it is pinned explicitly at the end
    ymin, ymax = series_ax.get_ylim()

    # One black rectangle per (possibly partial) day marks the night period
    for day in range(int(np.ceil(hours[-1] / total_duration))):
        bar_ax.add_patch(patches.Rectangle(
            (night_start + day * total_duration, 0),
            night_end - night_start,
            1,
            color='black'
        ))

    series_ax.set_xlabel('ZT', fontsize=14)
    series_ax.set_ylabel(ylabel, fontsize=14)

    # Ticks every 12 h from 0, labelled as Zeitgeber Time (hours modulo 24):
    # ZT 0 = lights on, ZT 12 = lights off
    maxtick = (t_max // 12 + 1) * 12
    numsteps = int(maxtick / 12 + 1)
    xticks = np.linspace(0, maxtick, numsteps)
    xtick_labels = (xticks % 24).astype(int)
    series_ax.set_xticks(xticks, xtick_labels)
    series_ax.set_xlim([t_min, t_max])
    series_ax.set_ylim([ymin, ymax])


# ============================================================================
# PLOT CLOCK CHANNEL AVAILABILITY (Ncc) TIME SERIES
# ============================================================================

# Time converted from ms to hours
_plot_daynight_timeseries(ax6, ax7, T / 3600000, Ncc, "Ncc (a.u.)")


# ============================================================================
# PLOT FIRING RATE TIME SERIES
# ============================================================================

# Smoothed spike times converted from ms to hours
_plot_daynight_timeseries(ax8, ax9, spike_train_smooth / 3600000,
                          frequency_smooth, "Firing rate (Hz)")


# ============================================================================
# PLOT PHASE SPACE TRAJECTORIES (LIMIT CYCLES)
# ============================================================================

# Right-column parametric panels, each given as
# (axis, x data, y data, x label, y label):
#   ax3 - cAMP against Ncc over the whole run (membrane-oscillator limit cycle)
#   ax4 - smoothed firing rate against cAMP sampled at spikes
#   ax5 - smoothed firing rate against Ncc sampled at spikes
ncc_label = r'$N_{\mathrm{CC}}$ (a.u.)'
phase_panels = (
    (ax3, Ncc, cAMP, ncc_label, "[cAMP] (a.u.)"),
    (ax4, cAMP_at_peaks_smooth, frequency_smooth, '[cAMP] (a.u.)', 'Firing rate (Hz)'),
    (ax5, Ncc_at_peaks_smooth, frequency_smooth, ncc_label, 'Firing rate (Hz)'),
)
for panel_ax, x_data, y_data, x_label, y_label in phase_panels:
    panel_ax.plot(x_data, y_data, linewidth=2.5)
    panel_ax.set_xlabel(x_label, fontsize=14)
    panel_ax.set_ylabel(y_label, fontsize=14)


# ============================================================================
# PLOT 1-SECOND WINDOWS: VOLTAGE AND CALCIUM
# ============================================================================

# Overlay the low-firing (green) and high-firing (red) 1-second windows for
# membrane voltage (ax1) and calcium (ax2). The traces carry labels, but this
# script never draws a legend.
for trace_ax, min_trace, max_trace, y_label in (
    (ax1, V_series_min, V_series_max, 'V (mV)'),
    (ax2, Ca_series_min, Ca_series_max, '[Ca] (a.u.)'),
):
    trace_ax.plot(T_series_min, min_trace, color="green", label="Night (low firing)")
    trace_ax.plot(T_series_max, max_trace, color="red", label="Day (high firing)")
    trace_ax.set_xlim([0, 1000])
    trace_ax.set_xlabel('Time (ms)', fontsize=14)
    trace_ax.set_ylabel(y_label, fontsize=14)


# ============================================================================
# FINALIZE AND SAVE FIGURE
# ============================================================================

format_axes(fig)  # Shared tick-label / grid formatting for every panel

# Write the figure to disk first, then open the interactive window
plt.savefig("Figure_2.svg", dpi=1200)
print("Figure saved as 'Figure_2.svg'")
plt.show()

print("\nPlotting complete!")


# ============================================================================
# END OF PLOTTING SCRIPT
# ============================================================================