#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Neuronal Membrane Circadian Oscillator Model

This code implements the neuronal membrane-associated circadian oscillator 
described in: "Autonomous circadian oscillators intrinsic to cell membranes"
Forlino et al., 2026

The model couples two timescales:
1. Fast timescale (milliseconds): Hodgkin-Huxley type action potential generation
2. Slow timescale (hours): Membrane-based circadian oscillator

The membrane oscillator operates through:
- Clock Channels (CC) mediate Ca2+ influx
- Ca2+ activates cAMP production (positive feedback)
- Ca2+ produces inhibitor Y (delayed negative feedback)
- Inhibitor Y inactivates CCs, closing the feedback loop
- This generates circadian modulation of neuronal excitability and firing rate

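The simulation runs for 4 days and writes its results to plain-text files: for
each day the downsampled slow variables (Ncc, Y, cAMP) and the spike times, plus
short voltage/Ca2+ windows at the minimum and maximum firing rates of the last day.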
"""

import numpy as np
from scipy.integrate import solve_ivp
from scipy.signal import find_peaks
# Silence NumPy floating-point warnings globally, such as a possible overflow of
# np.exp inside the gating functions, so the long integration is not flooded
# with RuntimeWarnings.
# NOTE(review): 'all' also hides 'invalid' (NaN-producing) warnings.
np.seterr(all='ignore')

# ============================================================================
# SLOW SYSTEM PARAMETERS (Membrane Circadian Oscillator)
# All rates in this section are in h^-1; they are converted to ms^-1 below.
# ============================================================================

# Clock Channel (CC) kinetics [h^-1]
alpha = 0.7         # recruitment/activation of CCs
beta = 0.9          # Y-dependent inactivation of CCs

# cAMP kinetics [h^-1]; cAMP is the ligand that opens CCs
alpha_cAMP = 2.87   # Ca2+-dependent production
gamma_cAMP = 2.87   # degradation

# Inhibitor Y kinetics [h^-1]; Y supplies the delayed negative feedback
alpha_Y = 0.2       # Ca2+-dependent production
gamma_Y = 0.2       # degradation

# Clock Channel properties
gs = 0.9            # single-channel conductance [mS/cm²]
Ncc_Tot = 2.22      # total clock-channel pool [a.u.]


# ============================================================================
# FAST SYSTEM PARAMETERS (Hodgkin-Huxley Neuronal Dynamics)
# ============================================================================

# Reversal potentials [mV]
vNa = 45            # Na+
vK = -105           # K+
vCa = 120           # Ca2+ (HVA channels)
vcc = 10            # clock channels
vL = -62.5          # leak

# Maximal conductances [mS/cm²]
gNa = 29            # transient Na+ (spike upstroke)
gK = 13             # delayed-rectifier K+ (spike repolarisation)
gNaP = 20           # persistent Na+ (spontaneous firing)
gHVA = 0.2          # high-voltage-activated Ca2+
gL = 3              # leak

# Gating parameters: theta = half-(in)activation voltage [mV],
# beta = Boltzmann slope factor [mV], tau = base time constant [ms].

# Transient Na+: activation only; its inactivation is taken as h = 1 - nK
theta_mNa = -25
beta_mNa = -6.5

# Delayed-rectifier K+
theta_nK = -26
beta_nK = -9
tau_nK = 10

# Persistent Na+ (activation m, inactivation h)
theta_mNaP = -40
beta_mNaP = -4
theta_hNaP = -54
beta_hNaP = 5
tau_hNaP = 500

# High-voltage-activated Ca2+
theta_mHVA = -10.0
beta_mHVA = -6.5

# Passive membrane and Ca2+-handling parameters
C = 2               # membrane capacitance [μF/cm²]
f = 0.0039          # current-to-concentration conversion factor [cm²/nC]
tau = 40            # Ca2+ clearance time constant [ms]
Ca0 = 0.37          # baseline Ca2+ concentration [a.u.]


# ============================================================================
# GATING FUNCTIONS (Steady-state activation/inactivation)
# ============================================================================

def nK_inf(V):
    """
    Steady-state activation of the delayed-rectifier K+ gate at voltage V (mV).

    Boltzmann curve with half-activation ``theta_nK`` and slope ``beta_nK``.
    The slope is negative, so activation rises with depolarisation.
    """
    exponent = (V - theta_nK) / beta_nK
    return 1 / (1 + np.exp(exponent))

def mNa_inf(V):
    """
    Steady-state activation of the transient Na+ gate at voltage V (mV).

    Boltzmann curve with half-activation ``theta_mNa`` and slope ``beta_mNa``.
    The slope is negative, so activation rises with depolarisation.
    """
    exponent = (V - theta_mNa) / beta_mNa
    return 1 / (1 + np.exp(exponent))

def mNaP_inf(V):
    """
    Steady-state activation of the persistent Na+ gate at voltage V (mV).

    Boltzmann curve with half-activation ``theta_mNaP`` and slope ``beta_mNaP``.
    The slope is negative, so activation rises with depolarisation.
    """
    exponent = (V - theta_mNaP) / beta_mNaP
    return 1 / (1 + np.exp(exponent))

def hNaP_inf(V):
    """
    Steady-state inactivation variable of the persistent Na+ current at V (mV).

    Boltzmann curve with half-inactivation ``theta_hNaP`` and slope ``beta_hNaP``.
    The slope is positive, so the value falls with depolarisation.
    """
    exponent = (V - theta_hNaP) / beta_hNaP
    return 1 / (1 + np.exp(exponent))

def mHVA_inf(V):
    """
    Steady-state activation of the high-voltage-activated Ca2+ gate at V (mV).

    Boltzmann curve with half-activation ``theta_mHVA`` and slope ``beta_mHVA``.
    The slope is negative, so activation rises with depolarisation.
    """
    exponent = (V - theta_mHVA) / beta_mHVA
    return 1 / (1 + np.exp(exponent))

def Hill(cAMP, K=1.47, n=3):
    """
    Hill function for cAMP-dependent CC activation.

    Parameters
    ----------
    cAMP : float or ndarray
        Intracellular cAMP concentration (a.u.). Expected to be positive.
    K : float, optional
        Half-activation constant K_1/2 (a.u.). Defaults to 1.47, the model value.
    n : float, optional
        Hill coefficient. Defaults to 3, the model value.

    Returns
    -------
    float or ndarray
        CC open probability, in [0, 1] for positive cAMP, equal to 0.5 at cAMP == K.

    Notes
    -----
    Evaluated as 1 / (1 + (K / cAMP)**n). This is algebraically equal to
    cAMP**n / (cAMP**n + K**n), and it is kept in this form so that results
    match the original implementation bit-for-bit. The steep (n=3)
    nonlinearity introduces the nonlinearity necessary for oscillations.
    """
    return 1 / (1 + (K / cAMP) ** n)


# ============================================================================
# TIME CONSTANT FUNCTIONS
# ============================================================================

def nK_tau(V):
    """
    Voltage-dependent time constant (ms) of the K+ activation gate.

    Bell-shaped in V: it equals ``tau_nK`` at V == theta_nK (cosh(0) == 1)
    and shrinks away from it.
    """
    scale = 2.0 * beta_nK
    return tau_nK / np.cosh((V - theta_nK) / scale)

def hNaP_tau(V):
    """
    Voltage-dependent time constant (ms) of persistent Na+ inactivation.

    Bell-shaped in V: it equals ``tau_hNaP`` at V == theta_hNaP (cosh(0) == 1)
    and shrinks away from it.
    """
    scale = 2.0 * beta_hNaP
    return tau_hNaP / np.cosh((V - theta_hNaP) / scale)


# ============================================================================
# CONVERT RATE CONSTANTS FROM HOURS TO MILLISECONDS
# ============================================================================
# The slow oscillator operates on hour timescale, but ODEs are solved in ms

# Number of milliseconds in one hour. Dividing a rate in h^-1 by this gives ms^-1.
_MS_PER_HOUR = 60 * 60 * 1000

alpha_ms = alpha / _MS_PER_HOUR
beta_ms = beta / _MS_PER_HOUR
alpha_cAMP_ms = alpha_cAMP / _MS_PER_HOUR
gamma_cAMP_ms = gamma_cAMP / _MS_PER_HOUR
alpha_Y_ms = alpha_Y / _MS_PER_HOUR
gamma_Y_ms = gamma_Y / _MS_PER_HOUR


# ============================================================================
# COUPLED SYSTEM OF DIFFERENTIAL EQUATIONS
# ============================================================================

def compute_derivatives(t0, y):
    """
    Right-hand side of the coupled neuron / membrane-oscillator ODE system.

    The slow membrane oscillator (Ncc, Y, cAMP; rates converted from h^-1 to
    ms^-1) and the fast electrophysiology (V, Ca, nK, hNaP; ms timescale) are
    integrated together. Ca2+ couples the two: it drives cAMP and Y
    production, and it is itself fed by the clock-channel and HVA currents.

    Parameters
    ----------
    t0 : float
        Current time in ms. It is not used because the system is autonomous,
        but ``solve_ivp`` requires the ``fun(t, y)`` signature.
    y : array_like
        State vector, in this order:
            y[0] Ncc   - active clock channels (a.u.)
            y[1] Y     - inhibitor concentration (a.u.)
            y[2] cAMP  - cyclic AMP concentration (a.u.)
            y[3] V     - membrane potential (mV)
            y[4] Ca    - intracellular Ca2+ concentration (a.u.)
            y[5] nK    - K+ activation (0-1)
            y[6] hNaP  - persistent Na+ inactivation (0-1)

    Returns
    -------
    ndarray of shape (7,)
        Time derivatives (per ms) of the state variables, in the same order.

    Notes
    -----
    NOTE(review): IHVA feeds the Ca2+ balance but is not part of the membrane
    current sum in dV/dt. Confirm that this matches the published model.
    """
    Ncc, Y, cAMP, V, Ca, nK, hNaP = y[:7]

    # --- Ionic currents; sign convention: positive = outward (dV/dt = -I/C) ---
    open_prob = Hill(cAMP)                       # cAMP-gated CC open probability
    Icc = Ncc * gs * open_prob * (V - vcc)       # clock-channel current
    IL = gL * (V - vL)                           # leak
    # Transient Na+; its inactivation gate is approximated as hNa = 1 - nK
    INa = gNa * (1 - nK) * (mNa_inf(V)**3) * (V - vNa)
    IK = gK * (nK**4) * (V - vK)                 # delayed-rectifier K+
    INaP = gNaP * mNaP_inf(V) * hNaP * (V - vNa) # persistent Na+
    IHVA = gHVA * mHVA_inf(V) * (V - vCa)        # high-voltage-activated Ca2+

    # --- Slow membrane oscillator ---
    # CCs are recruited towards Ncc_Tot and inactivated in proportion to Y
    dNcc = alpha_ms * (Ncc_Tot - Ncc) - beta_ms * Ncc * Y
    # Both Y (negative feedback) and cAMP (positive feedback) are Ca2+-driven
    dY = Ca * alpha_Y_ms - Y * gamma_Y_ms
    dcAMP = Ca * alpha_cAMP_ms - cAMP * gamma_cAMP_ms

    # --- Fast electrophysiology ---
    dV = -(Icc + IL + INa + IK + INaP) / C
    # Ca2+ influx: 30% of the CC current plus the HVA current; relaxes to Ca0
    dCa = -f * (0.3 * Icc + IHVA) + (Ca0 - Ca) / tau
    dnK = (nK_inf(V) - nK) / nK_tau(V)
    dhNaP = (hNaP_inf(V) - hNaP) / hNaP_tau(V)

    return np.array([dNcc, dY, dcAMP, dV, dCa, dnK, dhNaP], dtype=float)



# ============================================================================
# INITIAL CONDITIONS
# ============================================================================
# Starting state, ordered as in compute_derivatives:
# [Ncc, Y, cAMP, V, Ca, nK, hNaP]. These values are meant to be a point on the
# model's limit cycle, so the run starts without a long transient.
init_cond = np.array([
    8.16171670e-01,    # Ncc  - active clock channels (a.u.)
    1.40615160e+00,    # Y    - inhibitor (a.u.)
    1.91132360e+00,    # cAMP (a.u.)
    -5.09219970e+01,   # V    - membrane potential (mV)
    1.84736123e+00,    # Ca   (a.u.)
    5.85939821e-02,    # nK   - K+ activation
    3.53005281e-02,    # hNaP - persistent Na+ inactivation
])


# ============================================================================
# MAIN SIMULATION LOOP - 4 DAYS
# ============================================================================

print("Starting 4-day simulation of neuronal membrane oscillator...")
print("=" * 60)

N_DAYS = 4                      # number of consecutive one-day integrations
MS_PER_DAY = 1000 * 3600 * 24   # solver time unit is ms
SECONDS_PER_DAY = 24 * 3600
DOWNSAMPLE = 10000              # keep every DOWNSAMPLE-th solver point of slow variables

for j in range(N_DAYS):
    print(f"\nSimulating day {j+1}/{N_DAYS}...")

    # Integrate one day. Each day restarts from the previous day's final state.
    t_start = MS_PER_DAY * j
    t_end = MS_PER_DAY * (j + 1)
    single_system = solve_ivp(compute_derivatives, [t_start, t_end], init_cond)
    if not single_system.success:
        # A failed solve_ivp still returns the points computed so far. Without
        # this check the truncated day would be saved as if it were complete.
        raise RuntimeError(
            f"solve_ivp failed on day {j+1}: {single_system.message}"
        )

    t_all = single_system["t"]
    y_all = single_system["y"]

    # Full-resolution voltage, used for spike detection
    V = y_all[3, :]

    # Downsample the slow variables for storage. The solver steps are
    # adaptive, so this is every DOWNSAMPLE-th step, not a uniform time grid.
    T = t_all[::DOWNSAMPLE]
    Ncc = y_all[0, ::DOWNSAMPLE]
    Y = y_all[1, ::DOWNSAMPLE]
    cAMP = y_all[2, ::DOWNSAMPLE]

    # Spikes = local maxima of V above -20 mV
    peaks = find_peaks(V, height=-20)[0]
    spike_train = t_all[peaks]

    # Instantaneous firing rate in Hz (ISI is in ms)
    ISI = np.diff(spike_train)
    frequency = 1000 / ISI

    # Save this day's time series
    np.savetxt(f"time_{j}.txt", T)
    np.savetxt(f"Ncc_{j}.txt", Ncc)
    np.savetxt(f"Y_{j}.txt", Y)
    np.savetxt(f"cAMP_{j}.txt", cAMP)
    np.savetxt(f"spike_train_{j}.txt", spike_train)

    # Final state seeds the next day's integration
    init_cond = y_all[:, -1]

    print(f"  - Completed day {j+1}")
    print(f"  - Number of spikes: {len(peaks)}")
    print(f"  - Mean firing rate: {len(peaks)/SECONDS_PER_DAY:.3f} Hz")

print("\n" + "=" * 60)
print(f"{N_DAYS}-day simulation complete!")


# ============================================================================
# EXTRACT 1-SECOND WINDOWS FOR DETAILED ANALYSIS (Figures 2E and 2F)
# ============================================================================

print("\nExtracting 1-second time windows at min/max firing rates...")

# `frequency` is the instantaneous firing rate (1000 / ISI, in Hz) of the LAST
# simulated day only. Entry i belongs to the interval that begins at
# spike_train[i], so argmin/argmax are also indices into spike_train.
if len(frequency) == 0:
    # Fail with a clear message instead of numpy's opaque empty-argmin error
    raise RuntimeError(
        "Fewer than two spikes were detected on the last simulated day; "
        "cannot locate the minimum/maximum firing rate."
    )
argmin = np.argmin(frequency)
argmax = np.argmax(frequency)

def find_closest_indices(T, spike_train):
    """
    Find, for each spike time, the index of the closest entry in T.

    Parameters
    ----------
    T : array_like
        Time array, typically the downsampled solver times.
    spike_train : array_like
        Spike times in original (high-resolution) time.

    Returns
    -------
    indices : list of int
        For each spike, the index in T minimising |T - spike_time|. Ties go to
        the lower index, as with ``np.argmin``.

    Raises
    ------
    ValueError
        If T is empty while spike_train is not.

    Notes
    -----
    When T is strictly increasing (solver output is), the search uses
    ``np.searchsorted`` in O(n log m). Otherwise it falls back to one
    O(m) ``argmin`` scan per spike.
    """
    times = np.asarray(T)
    spikes = np.asarray(spike_train)
    if spikes.size == 0:
        return []

    if times.size < 2 or not np.all(np.diff(times) > 0):
        # General (unsorted / tiny / empty) case: exact brute-force scan
        return [int(np.argmin(np.abs(times - s))) for s in spikes]

    # The nearest neighbour is either the last point below each spike (left)
    # or the first point at/above it (right). Clipping handles spikes outside
    # [T[0], T[-1]].
    right = np.clip(np.searchsorted(times, spikes, side="left"), 1, times.size - 1)
    left = right - 1
    # '<=' resolves equal distances towards the lower index, matching np.argmin
    choose_left = (spikes - times[left]) <= (times[right] - spikes)
    return np.where(choose_left, left, right).tolist()

# Length of each extracted window, in ms of simulated time (1 s)
WINDOW_MS = 1000.0
t_full = single_system["t"]


def _extract_window(start_index, duration_ms=WINDOW_MS):
    """Return (t, V, Ca) for solver samples with t in [t[start], t[start] + duration_ms).

    Times are shifted so that the window starts at 0. solve_ivp output is
    adaptive, so the window is selected by time, not by a fixed sample count.
    """
    start_time = t_full[start_index]
    stop_index = np.searchsorted(t_full, start_time + duration_ms, side="left")
    t_window = t_full[start_index:stop_index] - start_time
    V_window = single_system["y"][3, start_index:stop_index]
    Ca_window = single_system["y"][4, start_index:stop_index]
    return t_window, V_window, Ca_window


# Spikes (last simulated day) that start the longest / shortest inter-spike
# interval, i.e. the lowest / highest instantaneous firing rate.
Tmin = spike_train[argmin]
Tmax = spike_train[argmax]

# Spike times are solver output times (spike_train = t[peaks]), so
# searchsorted recovers their exact index in the full-resolution arrays.
T_argmin = int(np.searchsorted(t_full, Tmin))
T_argmax = int(np.searchsorted(t_full, Tmax))

# Minimum firing rate window
T_series_min, V_series_min, Ca_series_min = _extract_window(T_argmin)

# Maximum firing rate window
T_series_max, V_series_max, Ca_series_max = _extract_window(T_argmax)

# Save 1-second window data
np.savetxt("T_series_min.txt", T_series_min)
np.savetxt("V_series_min.txt", V_series_min)
np.savetxt("Ca_series_min.txt", Ca_series_min)

np.savetxt("T_series_max.txt", T_series_max)
np.savetxt("V_series_max.txt", V_series_max)
np.savetxt("Ca_series_max.txt", Ca_series_max)

print("  - Saved 1-second windows at min and max firing rates")
print("\nAll data files saved successfully!")
print("=" * 60)


# ============================================================================
# END OF SIMULATION
# ============================================================================