"""
Red Blood Cell Membrane Circadian Oscillator Model

This code implements the membrane-associated circadian oscillator described in:
"Autonomous circadian oscillators intrinsic to cell membranes"
Forlino et al., 2026

The model simulates circadian oscillations in anucleate red blood cells through a 
post-translational feedback loop involving:
- K+ ions activating pyruvate kinase (PK)
- PK driving glycolytic flux through metabolic intermediates
- Metabolites promoting PRX hyperoxidation (PRX-SO2/3)
- PRX-SO2/3 activating Gardos channels
- Gardos channels mediating K+ efflux, closing the feedback loop
"""

import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# Silence all NumPy floating-point warnings process-wide.
# NOTE(review): this also hides overflow / invalid-value warnings (e.g. NaN
# from PrxII**n_val if PrxII ever goes negative); consider a scoped
# `with np.errstate(...)` if such problems should surface.
np.seterr(all='ignore')

# ============================================================================
# GLOBAL PARAMETERS
# ============================================================================

# Hill function parameters for PRX-SO2/3 gating of Gardos channel
PRXII_50 = 1.0  # Half-maximal PRX-SO2/3 concentration
n_val = 2.5     # Hill coefficient (cooperativity)

# Production rates (synthesis rate constants, units: h^-1)
alpha_Met = 0.6344444370833332      # Production rate for metabolites
alpha_PrxII_c = 0.3172222185416666  # Production rate for PRX-SO2/3 (control value)
alpha_PK = 0.2537777748333333       # Production rate for active pyruvate kinase

# Degradation rates (units: h^-1)
lambda_PK = 0.6344444370833332      # Degradation rate for active PK
lambda_Met = 0.6344444370833332     # Degradation rate for metabolites
lambda_PrxII_c = 0.6344444370833332 # Degradation rate for PRX-SO2/3 (control value)

# Ion channel and membrane parameters
vK = -90.0      # Potassium reversal potential (mV)
g = 2.726       # Maximum Gardos channel conductance (mS/cm^2)
c = 0.1         # Current-to-concentration conversion factor (cm^2/nC)
K0 = 20         # Baseline intracellular K+ concentration (a.u.)

# Network topology
N = 6           # Number of metabolic nodes (Met1 ... MetN)
N_ADD = N - 1   # Metabolic nodes after Met1 (delay elements Met2 ... MetN)

# Simulation parameters
t_span = [0, 48]  # Time span: 0 to 48 hours
n_points = 5000   # Number of time points
t_points = np.linspace(t_span[0], t_span[1], n_points)

# Initial conditions for state variables
# Order: [PK, Met1, Met2, ..., MetN, PrxII]  ->  N + 2 entries in total
Y0 = np.array([4.80468981, 3.88829623, 2.99286998, 2.24234517, 1.70040853,
               1.38342661, 1.28386205, 0.68894482])

# Y0 is hard-coded for the current N. Fail fast with a clear message if N is
# changed without updating Y0. Otherwise the ODE right-hand side would index
# out of range or return a derivative of the wrong length.
if Y0.shape != (N + 2,):
    raise ValueError(
        f"Y0 has {Y0.size} entries but the model with N={N} metabolic nodes "
        f"needs {N + 2} (PK + N metabolites + PrxII)"
    )


# ============================================================================
# DIFFERENTIAL EQUATION SYSTEM
# ============================================================================

def goodwin_oscillator_oscillating(t, y, V_val, lambda_PrxII, alpha_PrxII, K_BOOL):
    """
    Evaluate dy/dt for the membrane-coupled Goodwin oscillator of the RBC clock.

    The state forms a linear chain PK -> Met1 -> ... -> MetN -> PrxII. Every
    link is first-order: linear production from the upstream node and linear
    decay. The nonlinearity sits in the loop closure. PRX-SO2/3 opens the
    Gardos channel through a Hill function, the resulting K+ current lowers
    intracellular K+ (clamped at zero), and K+ sets the PK production rate.

    Parameters
    ----------
    t : float
        Current time (h). Not used; solve_ivp requires it in the signature.
    y : array_like
        State vector of length 3 + N_ADD:
        y[0]                PK, active pyruvate kinase
        y[1]                Met1, first metabolic intermediate
        y[2] .. y[1+N_ADD]  additional metabolic delay nodes
        y[2+N_ADD]          PrxII, hyperoxidized peroxiredoxin (PRX-SO2/3)
    V_val : float
        Membrane potential (mV).
    lambda_PrxII : float
        Degradation rate of PRX-SO2/3 (h^-1).
    alpha_PrxII : float
        Production rate of PRX-SO2/3 from the last metabolic node (h^-1).
    K_BOOL : int
        1 keeps the K+ term; 0 forces intracellular K+ to zero.

    Returns
    -------
    ndarray
        Derivatives of all 3 + N_ADD state variables, in the order of ``y``.
    """
    n_state = 3 + N_ADD
    prx_index = n_state - 1

    pk = y[0]
    prx = y[prx_index]

    # --- Loop closure: PRX-SO2/3 -> Gardos conductance -> intracellular K+ ---
    prx_pow = prx**n_val
    gardos_g = g * prx_pow / (PRXII_50**n_val + prx_pow)   # Hill activation
    gardos_i = gardos_g * (V_val - vK)                      # Ohm's law
    k_in = (K0 - gardos_i * c) * K_BOOL                     # efflux depletes K+
    if k_in < 0:
        k_in = 0                                            # no negative K+

    rates = np.zeros(n_state)

    # PK: produced in proportion to K+, decays linearly
    rates[0] = alpha_PK * k_in - lambda_PK * pk

    # Met1 ... MetN share identical kinetics. Node k is produced from node k-1
    # (PK feeds Met1) and decays linearly. This chain supplies the delay that
    # lets the Goodwin loop oscillate.
    for node in range(1, prx_index):
        rates[node] = alpha_Met * y[node - 1] - lambda_Met * y[node]

    # PRX-SO2/3: produced from the last metabolic node, decays linearly
    rates[prx_index] = alpha_PrxII * y[prx_index - 1] - lambda_PrxII * prx

    return rates


# ============================================================================
# SIMULATION RUNNER
# ============================================================================

def run_simulation(V_val, lambda_PrxII, alpha_PrxII, y0, K_BOOL):
    """
    Integrate the oscillator over ``t_span`` and derive Gardos and K+ traces.

    Parameters
    ----------
    V_val : float
        Membrane potential (mV)
    lambda_PrxII : float
        PRX-SO2/3 degradation rate (h^-1)
    alpha_PrxII : float
        PRX-SO2/3 production rate (h^-1)
    y0 : array_like
        Initial conditions, ordered [PK, Met1, ..., MetN, PrxII]
    K_BOOL : int
        Enable (1) or disable (0) K+ dynamics

    Returns
    -------
    sol : OdeResult
        Solution from solve_ivp, sampled at ``t_points``
    GGardos : ndarray
        Gardos channel conductance time series (mS/cm^2)
    K : ndarray
        Intracellular K+ concentration time series (a.u.), clamped at 0

    Raises
    ------
    RuntimeError
        If solve_ivp stops before the end of ``t_span``. A partial solution
        would have fewer than ``n_points`` samples, and the fixed-length
        slicing applied to the results would then fail with an unrelated
        shape error.
    """
    sol = solve_ivp(
        goodwin_oscillator_oscillating,
        t_span,
        y0,
        args=(V_val, lambda_PrxII, alpha_PrxII, K_BOOL),
        t_eval=t_points,
        method='RK45',
    )
    if not sol.success:
        raise RuntimeError(
            f"ODE integration failed for V={V_val} mV, "
            f"lambda_PrxII={lambda_PrxII}, alpha_PrxII={alpha_PrxII}, "
            f"K_BOOL={K_BOOL}: {sol.message}"
        )

    # PRX-SO2/3 is the last state variable
    PrxII = sol.y[-1]

    # Recompute the feedback quantities with the same formulas as the ODE
    # right-hand side: Hill-gated conductance, Ohmic current, K+ depletion.
    GGardos = g * PrxII**n_val / (PRXII_50**n_val + PrxII**n_val)
    IGardos = GGardos * (V_val - vK)
    K = (K0 - IGardos * c) * K_BOOL
    K[K < 0] = 0  # no negative concentrations

    return sol, GGardos, K


# ============================================================================
# RUN SIMULATIONS FOR DIFFERENT CONDITIONS
# ============================================================================

# ----------------------------------------------------------------------------
# Control condition: V = -10 mV
# ----------------------------------------------------------------------------
sol_V10, Geff_V10, K_V10 = run_simulation(
    V_val=-10.0,
    lambda_PrxII=lambda_PrxII_c,
    alpha_PrxII=alpha_PrxII_c,
    y0=Y0,
    K_BOOL=1
)

# ----------------------------------------------------------------------------
# Hyperpolarized membrane: V = -30 mV (mimics valinomycin treatment)
# Effect: Reduces driving force for K+ efflux, lengthens period
# ----------------------------------------------------------------------------
sol_V30, Geff_V30, K_V30 = run_simulation(
    V_val=-30.0,
    lambda_PrxII=lambda_PrxII_c,
    alpha_PrxII=alpha_PrxII_c,
    y0=Y0,
    K_BOOL=1
)

# ----------------------------------------------------------------------------
# MG132 treatment: Proteasomal inhibition (lambda_PrxII = 0)
# Effect: Prevents PRX-SO2/3 degradation, abolishes oscillations
# ----------------------------------------------------------------------------
sol_MG132, Geff_MG132, K_MG132 = run_simulation(
    V_val=-10.0,
    lambda_PrxII=0,
    alpha_PrxII=alpha_PrxII_c,
    y0=Y0,
    K_BOOL=1
)
Met_MG132 = sol_MG132.y[1]

# ----------------------------------------------------------------------------
# K+ removal experiment: Simulate transfer to K+-free medium at t=24h
# Effect: Abolishes oscillations when K+ is unavailable
# ----------------------------------------------------------------------------

# The pre-removal phase uses exactly the control inputs (V=-10 mV, control
# rates, Y0, K_BOOL=1), so reuse the control run instead of integrating the
# same system a second time.
sol_yesK, Geff_yesK, K_yesK = sol_V10, Geff_V10, K_V10

# The state at the midpoint sample (t_points[mid_index], ~24 h) seeds the
# K+-free run. It is copied so the restart cannot alias the reused control array.
mid_index = n_points // 2
y0_noK = sol_yesK.y[:, mid_index].copy()

# Run the K+-free phase from that state (K_BOOL=0). Its clock restarts at 0;
# only its first n_points - mid_index samples are used below.
sol_noK, Geff_noK, K_noK = run_simulation(
    V_val=-10.0,
    lambda_PrxII=lambda_PrxII_c,
    alpha_PrxII=alpha_PrxII_c,
    y0=y0_noK,
    K_BOOL=0
)

# Stitch the two phases: control samples [0, mid_index), then the K+-free
# samples. t_points is uniformly spaced, so sample k of the restarted run
# lines up with t_points[mid_index + k].
n_tail = n_points - mid_index
sol_Kremov = np.concatenate(
    [sol_yesK.y[:, :mid_index], sol_noK.y[:, :n_tail]], axis=1
)
Geff_Kremov = np.concatenate([Geff_yesK[:mid_index], Geff_noK[:n_tail]])
K_Kremov = np.concatenate([K_yesK[:mid_index], K_noK[:n_tail]])

# ----------------------------------------------------------------------------
# Conoidin A treatment: Partial PRX inhibition
# Effect: Reduces PRX-SO2/3 production rate, dampens oscillations
# ----------------------------------------------------------------------------
sol_conA, Geff_conA, K_conA = run_simulation(
    V_val=-10.0,
    lambda_PrxII=lambda_PrxII_c,
    alpha_PrxII=alpha_PrxII_c / 4,  # Reduce production to 25%
    y0=Y0,
    K_BOOL=1
)


# ============================================================================
# PLOTTING
# ============================================================================

# Short aliases for the hyperpolarized (V = -30 mV) run.
# NOTE(review): in the code visible here, none of sol, PK, Met, PrxII, K or
# GGardos is read again. The figures below index sol_V10 / sol_V30 /
# Geff_* / K_* directly, so these look like leftovers. Confirm before
# relying on them, or remove them.
sol = sol_V30
PK = sol.y[0]
Met = sol.y[1]
PrxII = sol.y[-1]
K = K_V30
GGardos = Geff_V30

# Time-axis tick positions: every 12 h from 0 up to t_span[1]
t_ticks = np.arange(0, t_span[1] + 1, 12)


# ----------------------------------------------------------------------------
# Figure 1: Main results
# ----------------------------------------------------------------------------

def format_axes(fig, *, labelsize=8):
    """
    Apply the shared tick-label size and dashed grid to every Axes of *fig*.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose ``fig.axes`` are formatted in place.
    labelsize : float, optional
        Font size of the major tick labels on both axes. Defaults to 8,
        the size used in Figure 1.
    """
    for panel in fig.axes:
        panel.tick_params(axis='both', which='major', labelsize=labelsize)
        panel.grid(True, linestyle='--', alpha=0.6)


# Figure 1: 176 mm wide, with the sizes given in mm and converted to inches
fig = plt.figure(layout="constrained", figsize=(176/25.4, 80/25.4*(4/3)))
gs = GridSpec(4, 3, figure=fig)

# Create subplots. Grid cells gs[0:2, 0:2] are not populated by this script
# (presumably reserved for a schematic added outside Python).
ax1 = fig.add_subplot(gs[2, 0])  # Hill repression function
ax2 = fig.add_subplot(gs[2, 1])  # Gardos conductance time series
ax3 = fig.add_subplot(gs[0:2, 2])  # Phase space (limit cycle)
ax7 = fig.add_subplot(gs[2, 2])  # Gardos comparison (V=-10 vs V=-30)
ax5 = fig.add_subplot(gs[3, 0])  # Hill activation function
ax6 = fig.add_subplot(gs[3, 1])  # K+ time series
ax4 = fig.add_subplot(gs[3, 2])  # MG132 treatment

# Gardos conductance time series (control)
ax2.plot(sol_V10.t, Geff_V10, label=r'$G_{Gardos}$ ($V=-10$ mV)')
ax2.set_xticks(t_ticks)
ax2.set_xlabel('Time (hours)', fontsize=8)
# y-label added for consistency with the other time-series panels (ax6, ax7)
ax2.set_ylabel(r'$G_{Gardos} (a.u.)$', fontsize=8)

# K+ concentration time series (control)
ax6.plot(sol_V10.t, K_V10, label=r'$[K]$ ($V=-10$ mV)', color='green')
ax6.set_xticks(t_ticks)
ax6.set_xlabel('Time (hours)', fontsize=8)
ax6.set_ylabel(r'$[K^+] (a.u.)$', fontsize=8)

# Gardos conductance comparison (different membrane potentials)
ax7.plot(sol_V10.t, Geff_V10, label=r'$G_{Gardos}$ ($V=-10$ mV)')
ax7.plot(sol_V30.t, Geff_V30, label=r'$G_{Gardos}$ ($V=-30$ mV)')
ax7.set_xticks(t_ticks)
ax7.set_xlabel('Time (hours)', fontsize=8)
ax7.set_ylabel(r'$G_{Gardos} (a.u.)$', fontsize=8)
# Two overlaid curves need a legend to tell the membrane potentials apart.
# The font size matches the MG132 panel legend below.
ax7.legend(fontsize=7)

# Phase space trajectory (Met1 vs K+, control run, including the transient)
ax3.plot(sol_V10.y[1], K_V10, linewidth=1.5, label='Limit Cycle', color="purple")
ax3.set_xlabel(r'$Met_{1}$ (a.u.)', fontsize=8)
ax3.set_ylabel(r'$[K^+] (a.u.)$', fontsize=8)

# MG132 treatment: G_Gardos on the left axis, Met1 on a twin right axis
color_geff = 'tab:blue'
ax4.plot(sol_MG132.t, Geff_MG132, color=color_geff, label=r'$G_{Gardos}$')
ax4.set_xlabel('Time (hours)', fontsize=8)
ax4.set_ylabel(r'$G_{Gardos}$ (a.u.)', fontsize=8, color=color_geff)
ax4.tick_params(axis='y', labelcolor=color_geff)
ax4.set_xticks(t_ticks)
ax4.spines['left'].set_color(color_geff)

ax4_twin = ax4.twinx()
color_met = 'red'
ax4_twin.plot(sol_MG132.t, Met_MG132, color=color_met, linestyle='--', label=r'$Met_{1}$')
ax4_twin.set_ylabel(r'$Met_{1}$ (a.u.)', fontsize=8, color=color_met)
ax4_twin.tick_params(axis='y', labelcolor=color_met)
ax4_twin.spines['right'].set_color(color_met)

# Merge the handles of both y-axes into one legend
lines, labels = ax4.get_legend_handles_labels()
lines2, labels2 = ax4_twin.get_legend_handles_labels()
ax4.legend(lines + lines2, labels + labels2, loc='upper right', fontsize=7)

# Generic Hill curves for the schematic panels. k is the half-saturation
# point and h the Hill exponent (steepness), shared by both curves.
k = 100
h = 4


def hill1(x):
    """Return the decreasing Hill curve k^h / (k^h + x^h) (repression, nuclear TTFL)."""
    k_pow = k**h
    return k_pow / (k_pow + x**h)


def hill2(x):
    """Return the increasing Hill curve x^h / (k^h + x^h) (activation, membrane oscillator)."""
    x_pow = x**h
    return x_pow / (k**h + x_pow)

# Sample both Hill curves on a log-spaced grid from 10^0 to 10^4
z = np.logspace(0, 4)
H1 = hill1(z)
H2 = hill2(z)

# Schematic panels on a log10 x-axis with all ticks and tick labels removed:
# ax1 shows repression (for comparison with the nuclear TTFL) and ax5 shows
# activation (Gardos channel gating).
log_z = np.log10(z)
for schematic_ax, curve, x_text, y_text in (
    (ax1, H1, "Inhibitor concentration ([R])", "mRNA production"),
    (ax5, H2, "Ligand concentration ([Z])", "Conductance"),
):
    schematic_ax.plot(log_z, curve, color="r", linewidth=3)
    schematic_ax.spines[['right', 'top']].set_visible(False)
    schematic_ax.tick_params(left=False, right=False, labelleft=False,
                             labelbottom=False, bottom=False)
    schematic_ax.set_xlabel(x_text, fontsize=8)
    schematic_ax.set_ylabel(y_text, fontsize=8)

# format_axes sets the tick-label size and switches a dashed grid on. The
# grid is then switched off again on every panel of the figure (all seven
# subplots plus the MG132 twin axis).
format_axes(fig)
for panel in (ax1, ax5, ax2, ax3, ax4, ax4_twin, ax6, ax7):
    panel.grid(False)

plt.savefig("Figure_1.svg")
plt.show()


# ----------------------------------------------------------------------------
# Figure 2: Supplementary multi-condition comparison
# ----------------------------------------------------------------------------

# Supplementary figure layout: N_ROWS logical rows (one per plotted quantity)
# by N_COLUMNS columns (one per simulated condition).
N_ROWS = 5
N_COLUMNS = 4

# Logical row 2 (PRX-SO2/3) is a broken y-axis built from two stacked Axes,
# so the GridSpec has N_ROWS + 1 physical rows. Height ratios 0.2 / 0.8 give
# the thin upper and taller lower halves of that broken row.
# NOTE(review): no layout engine is passed to plt.figure here, which
# presumably is why the manual set_position calls below stick. Confirm
# before adding layout="constrained".
fig = plt.figure(figsize=(15, 12))
gs = GridSpec(N_ROWS + 1, N_COLUMNS, figure=fig, 
                       height_ratios=[1, 1, 0.2, 0.8, 1, 1],
                       wspace=0.2)

# ax[row][col]: one Axes per panel, except ax[2][col], which is replaced
# below by a [lower, upper] pair for the broken axis.
ax = [[None for _ in range(N_COLUMNS)] for _ in range(N_ROWS)]

# Create top two rows. Every panel shares x with ax[0][0]; within a row,
# columns 1.. share y with column 0 so conditions are directly comparable.
for j in range(N_COLUMNS):
    if j == 0:
        ax[0][j] = fig.add_subplot(gs[0, j])
    else:
        ax[0][j] = fig.add_subplot(gs[0, j], sharex=ax[0][0], sharey=ax[0][0])

for j in range(N_COLUMNS):
    if j == 0:
        ax[1][j] = fig.add_subplot(gs[1, j], sharex=ax[0][0])
    else:
        ax[1][j] = fig.add_subplot(gs[1, j], sharex=ax[0][0], sharey=ax[1][0])

# Create broken axis for PRX-SO2/3 (logical row 2 = GridSpec rows 2 and 3)
ax_broken_top = [None] * N_COLUMNS
ax_broken_bottom = [None] * N_COLUMNS

for j in range(N_COLUMNS):
    # Upper half (GridSpec row 2); y shared across columns
    if j == 0:
        ax_broken_top[j] = fig.add_subplot(gs[2, j], sharex=ax[0][0])
    else:
        ax_broken_top[j] = fig.add_subplot(gs[2, j], sharex=ax[0][0], sharey=ax_broken_top[0])
    
    # Lower half (GridSpec row 3); y shared across columns
    if j == 0:
        ax_broken_bottom[j] = fig.add_subplot(gs[3, j], sharex=ax[0][0])
    else:
        ax_broken_bottom[j] = fig.add_subplot(gs[3, j], sharex=ax[0][0], sharey=ax_broken_bottom[0])
    
    # Stored as [lower, upper]; index 0 is the lower half
    ax[2][j] = [ax_broken_bottom[j], ax_broken_top[j]]
    
    # Set y-limits for broken axis: 0-2.5 below the break, 12-13 above it
    # (presumably to show the accumulating PRX-SO2/3 of the MG132 run)
    ax_broken_top[j].set_ylim(12.0, 13)
    ax_broken_bottom[j].set_ylim(0, 2.5)
    
    # Hide x-axis ticks and labels on the upper panel; ticks only at the bottom of the lower one
    ax_broken_top[j].tick_params(axis='x', bottom=False, top=False,
                                 labelbottom=False, labeltop=False)
    ax_broken_bottom[j].xaxis.tick_bottom()
    
    # Draw break indicators: short diagonal strokes, in axes-fraction
    # coordinates, at the bottom corners of the upper panel and the top
    # corners of the lower panel. clip_on=False lets them extend past the frame.
    d = .015
    kwargs = dict(transform=ax_broken_top[j].transAxes, color='k', clip_on=False, linewidth=1)
    ax_broken_top[j].plot((-d, +d), (-d, +d), **kwargs)
    ax_broken_top[j].plot((1 - d, 1 + d), (-d, +d), **kwargs)
    
    kwargs.update(transform=ax_broken_bottom[j].transAxes)
    ax_broken_bottom[j].plot((-d, +d), (1 - d, 1 + d), **kwargs)
    ax_broken_bottom[j].plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
    
    # Move the upper panel so its bottom sits `gap` (figure fraction) above
    # the lower panel's top, keeping its own x0, width and height.
    gap = 0.008
    bottom_pos = ax_broken_bottom[j].get_position()
    top_pos = ax_broken_top[j].get_position()
    ax_broken_top[j].set_position([
        top_pos.x0,
        bottom_pos.y1 + gap,
        top_pos.width,
        top_pos.height
    ])

# Shift rows 0 and 1 down by `gap` (figure fraction)
gap = 0.01
for i in [0, 1]:
    for j in range(N_COLUMNS):
        pos = ax[i][j].get_position()
        ax[i][j].set_position([pos.x0, pos.y0 - gap, pos.width, pos.height])

# Create bottom two rows (GridSpec rows 4 and 5 = logical rows 3 and 4)
for j in range(N_COLUMNS):
    if j == 0:
        ax[3][j] = fig.add_subplot(gs[4, j], sharex=ax[0][0])
    else:
        ax[3][j] = fig.add_subplot(gs[4, j], sharex=ax[0][0], sharey=ax[3][0])

for j in range(N_COLUMNS):
    if j == 0:
        ax[4][j] = fig.add_subplot(gs[5, j], sharex=ax[0][0])
    else:
        ax[4][j] = fig.add_subplot(gs[5, j], sharex=ax[0][0], sharey=ax[4][0])

# Plot all conditions. Each entry gives (column, state matrix with one row
# per state variable, Gardos conductance, K+ trace). Column 0 overlays the
# control (V=-10 mV) and hyperpolarized (V=-30 mV) runs, in that order.
panel_series = [
    (0, sol_V10.y, Geff_V10, K_V10),
    (0, sol_V30.y, Geff_V30, K_V30),
    (1, sol_conA.y, Geff_conA, K_conA),     # Conoidin A
    (2, sol_MG132.y, Geff_MG132, K_MG132),  # MG132
    (3, sol_Kremov, Geff_Kremov, K_Kremov), # K+ removal
]

for col, states, conductance, potassium in panel_series:
    prx_bottom, prx_top = ax[2][col]
    ax[0][col].plot(t_points, states[0], linewidth=4)    # PK
    ax[1][col].plot(t_points, states[3], linewidth=4)    # Met3
    prx_bottom.plot(t_points, states[-1], linewidth=4)   # PRX-SO2/3, lower range
    prx_top.plot(t_points, states[-1], linewidth=4)      # PRX-SO2/3, upper range
    ax[3][col].plot(t_points, conductance, linewidth=4)  # G_Gardos
    ax[4][col].plot(t_points, potassium, linewidth=4)    # [K+]

# Mark the switch to K+-free medium (t = 24 h) on every panel of column 3,
# including both halves of the broken PRX-SO2/3 axis.
removal_axes = []
for row in range(N_ROWS):
    if row == 2:
        removal_axes.extend(ax[row][3])
    else:
        removal_axes.append(ax[row][3])
for target in removal_axes:
    target.axvline(x=24, color='r', linestyle='--', linewidth=2, zorder=5)

# Axis labels: time under every bottom-row panel, quantity names on column 0
for bottom_panel in ax[-1]:
    bottom_panel.set_xlabel('Time (hours)', fontsize=16)

row_labels = (
    (ax[0][0], 'PK (a.u.)'),
    (ax[1][0], r'$Met_{3}$ (a.u.)'),
    (ax[2][0][0], 'PRXO (a.u.)'),
    (ax[3][0], r'$G_{Gardos}$ (a.u.)'),
    (ax[4][0], r'$[K^+]$ (a.u.)'),
)
for labelled_panel, label_text in row_labels:
    labelled_panel.set_ylabel(label_text, fontsize=16)

# Ticks for the ordinary rows. Row 2 is the broken-axis pair and is handled
# separately below. x tick labels appear only on the bottom row and y tick
# labels only in the first column.
last_row = N_ROWS - 1
for i in range(N_ROWS):
    if i == 2:
        continue
    for j in range(N_COLUMNS):
        panel = ax[i][j]
        panel.set_xticks(t_ticks)
        if i == last_row:
            panel.tick_params(axis='x', labelsize=14)
        else:
            panel.tick_params(axis='x', labelbottom=False)
        if j == 0:
            panel.tick_params(axis='y', labelsize=14)
        else:
            panel.tick_params(axis='y', labelleft=False)

# Broken-axis row: neither half shows x tick labels; only column 0 keeps y labels
for j, (lower, upper) in enumerate(ax[2]):
    lower.set_xticks(t_ticks)
    lower.tick_params(axis='x', labelbottom=False)
    upper.tick_params(axis='x', labelbottom=False)
    if j == 0:
        y_style = dict(axis='y', labelsize=14)
    else:
        y_style = dict(axis='y', labelleft=False)
    lower.tick_params(**y_style)
    upper.tick_params(**y_style)

plt.savefig("Sup.svg")
plt.show()


# ============================================================================
# END OF CODE
# ============================================================================