# Defines Next Reaction Method class used to exactly simulate the dynamics of the system 
# Created by Matthew Asker 
import random
import numpy as np
from numba.experimental import jitclass
from numba import int64, float64, typeof    

# specify variable types for the numba compiler
# Each entry is (attribute name, numba type); this list is passed to
# @jitclass for the `strain` class defined below.
strainspec = [
    ('fitness', float64),   # fitness value of the strain
    ('population', int64)   # number of individuals of the strain
    ]

# class to define and track / modify the population of the two strains
@jitclass(strainspec)
class strain():
    """State of a single strain: its current fitness and its head count.

    Attribute types are fixed by `strainspec` (fitness: float64,
    population: int64).
    """

    def __init__(self, fitness, population):
        # Record the starting state of the strain.
        self.population = population
        self.fitness = fitness

    def pop_change(self, change):
        """Add `change` to the population (a negative value shrinks it)."""
        self.population = self.population + change

    def fitness_change(self, new_fitness):
        """Replace the current fitness with `new_fitness`."""
        self.fitness = new_fitness

# specify variable types for the numba compiler
# Each entry is (attribute name, numba type); this list is passed to
# @jitclass for the `next_reaction` class defined below.
spec = [
    ('s', float64),                   # resistant strain fitness is 1 - s
    ('a', float64),                   # sensitive strain fitness becomes 1 - a when the resistant fraction is below x_th
    ('nu_minus', float64),            # switching rate used while the carrying capacity is K_m
    ('nu_plus', float64),             # switching rate used while the carrying capacity is K_p
    ('K_p', int64),                   # one of the two carrying-capacity values
    ('K_m', int64),                   # the other carrying-capacity value
    ('x_th', float64),                # threshold on the resistant fraction; also sets the initial composition
    ('simulations', int64),           # number of independent runs performed by run()
    ('r_fixation_count', int64),      # runs in which the resistant strain made up the whole population
    ('total_fixation_count', int64),  # runs in which either strain made up the whole population
    ('coexist_count', int64),         # set to 0 in __init__; not updated in the code shown here
    ('average_pop', float64),         # set to 0 in __init__; not updated in the code shown here
    ('average_time', float64),        # per-run stop/fixation times, summed then divided by `simulations` in run()
    ('average_x', float64),           # per-run time-averaged resistant fraction, summed then divided by `simulations`
    ('x_coexist', float64),           # set to 0 in __init__; not updated in the code shown here
    ('relaxation_time', float64),     # set to 0 in __init__; not updated in the code shown here
    ('coexist_prob', float64),        # count of runs passing max_time without fixation, divided by `simulations`
    ('final_N_R', float64),           # resistant population when a run passes max_time, summed then divided by `simulations`
    ('final_N_S', float64)            # sensitive population when a run passes max_time, summed then divided by `simulations`
]

# Next Reaction Method defined by the system given in the main text of the
# paper. For the constant environment case, an object of type next_reaction
# should be initialised with K_p=K_m=K_0 and nu_plus=nu_minus=0.0.
@jitclass(spec)
class next_reaction():
    """Next Reaction Method simulation of a resistant and a sensitive strain.

    Model as implemented in `run`:

    * The resistant strain (rates prefixed ``c_``) has fitness ``1 - s``.
      The sensitive strain (rates prefixed ``d_``) has fitness ``1`` while
      the resistant fraction ``x`` is at or above ``x_th`` and ``1 - a``
      once an event takes ``x`` below ``x_th``.
    * A strain's birth rate is ``(fitness / average_fitness) * N * fraction``
      and its death rate is ``N**2 * fraction / K``, where ``N`` is the total
      population and ``K`` the current carrying capacity.
    * ``K`` switches between ``K_p`` and ``K_m``; the switching rate is
      ``nu_plus`` while ``K == K_p`` and ``nu_minus`` while ``K == K_m``.

    Parameters
    ----------
    s, a : float
        Fitness parameters described above.
    nu_plus, nu_minus : float
        Carrying-capacity switching rates.
    K_p, K_m : int
        The two carrying-capacity values.
    x_th : float
        Threshold on the resistant fraction; also fixes the initial split of
        the population between the two strains.
    simulations : int
        Number of independent runs performed by `run`.

    Results are accumulated on the instance by `run` (see its docstring).
    """

    def __init__(self, s, a, nu_plus, nu_minus, K_p, K_m, x_th, simulations):
        self.s = s
        self.a = a
        self.nu_plus = nu_plus
        self.nu_minus = nu_minus
        self.K_p = K_p
        self.K_m = K_m
        self.x_th = x_th
        self.simulations = simulations
        self.r_fixation_count = 0
        self.total_fixation_count = 0
        self.coexist_count = 0
        self.final_N_R = 0.0
        self.final_N_S = 0.0
        # self.N_R_track = list()  # Not used unless necessary as cause
        # self.N_S_track = list()  # large memory usage. These track the comp-
        # self.K_track = list()    # osition of the system through time.
        # self.t_track = list()
        self.average_pop = 0.0
        self.average_time = 0.0
        self.average_x = 0.0
        self.x_coexist = 0.0
        self.relaxation_time = 0.0
        self.coexist_prob = 0.0

    # method to run the simulation itself
    def run(self):
        """Run `self.simulations` independent realisations and store averages.

        Each realisation stops at the first of: one strain making up the
        whole population (fixation), or the simulated time passing
        ``max_time = 2*K_0*(1 + gamma*delta)``.

        On return the instance holds:

        * ``r_fixation_count`` / ``total_fixation_count``: raw counts of runs
          ending in resistant / any fixation.
        * ``coexist_prob``: fraction of runs that passed ``max_time`` without
          fixation.
        * ``average_time``: mean over runs of the fixation time (``max_time``
          for runs that never fixated); each run contributes exactly once.
        * ``average_x``: mean over runs of the time-averaged resistant
          fraction.
        * ``final_N_R`` / ``final_N_S``: populations recorded when a run
          passes ``max_time``, summed and divided by ``simulations``.

        Counters are not reset and the averages are normalised in place, so
        this method is meant to be called once per instance.
        """
        K_0 = (self.K_p + self.K_m) / 2
        nu = (self.nu_plus + self.nu_minus) / 2
        gamma = (self.K_p - self.K_m) / (2 * K_0)
        if nu != 0.0:
            delta = (self.nu_minus - self.nu_plus) / (2 * nu)
        else:
            delta = 0.0
        # Time horizon after which a run is stopped; identical for every run.
        max_time = 2 * K_0 * (1 + gamma * delta)

        for _ in range(self.simulations):
            # Draw the initial environmental state: K_m with probability
            # (1 - delta)/2, otherwise K_p.
            rand = random.uniform(0, 1)
            if rand < (1 - delta) / 2:
                K_switch_rate = self.nu_minus
                carrying_capacity = self.K_m
            else:
                K_switch_rate = self.nu_plus
                carrying_capacity = self.K_p
            total_pop = carrying_capacity
            s_strain = strain(1, np.round(total_pop * (1 - self.x_th)))  # sensitive strain
            r_strain = strain(1 - self.s, total_pop - s_strain.population)  # resistant strain

            time = 0.0
            running = True
            fixated = False
            time_fixated = -1.0
            average_x = 0.0  # running integral of x over time for this run
            x = r_strain.population / total_pop

            # self.N_R_track.append(r_strain.population)
            # self.N_S_track.append(s_strain.population)
            # self.K_track.append(carrying_capacity)
            # self.t_track.append(time)

            average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
            gamma_c = r_strain.fitness / average_fitness
            gamma_d = s_strain.fitness / average_fitness
            c_birth = gamma_c * total_pop * x
            d_birth = gamma_d * total_pop * (1 - x)
            c_death = total_pop**2 * x / carrying_capacity
            d_death = total_pop**2 * (1 - x) / carrying_capacity

            # Initial putative firing time of every reaction. A zero switching
            # rate (switching bias delta = +/- 1) means the environment never
            # switches, so no random number is drawn for it.
            if K_switch_rate == 0.0:
                K_switch_time = np.inf
            else:
                K_switch_time = np.log(1 / random.uniform(0, 1)) / K_switch_rate
            c_birth_time = np.log(1 / random.uniform(0, 1)) / c_birth
            c_death_time = np.log(1 / random.uniform(0, 1)) / c_death
            d_birth_time = np.log(1 / random.uniform(0, 1)) / d_birth
            d_death_time = np.log(1 / random.uniform(0, 1)) / d_death

            # In every branch below the reaction with the smallest putative
            # time fires; its clock is redrawn, and every other clock whose
            # rate changed is rescaled by old_rate / new_rate.
            while running:
                if r_strain.population != total_pop and s_strain.population != total_pop:
                    # Both strains present: all five reactions are possible.
                    # (Two events are not possible once fixated and would
                    # break the next-reaction method, hence the other cases.)
                    if K_switch_time < c_birth_time and K_switch_time < c_death_time and K_switch_time < d_birth_time and K_switch_time < d_death_time:
                        # Environmental switch: only the death rates depend on K.
                        if carrying_capacity == self.K_m:
                            carrying_capacity = self.K_p
                            K_switch_rate = self.nu_plus
                        else:
                            carrying_capacity = self.K_m
                            K_switch_rate = self.nu_minus
                        tau = K_switch_time - time  # time increment
                        time = K_switch_time
                        average_x += x * tau

                        c_death_bar = total_pop**2 * x / carrying_capacity
                        d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                        c_death_time = (c_death / c_death_bar) * (c_death_time - time) + time
                        d_death_time = (d_death / d_death_bar) * (d_death_time - time) + time
                        K_switch_time = np.log(1 / random.uniform(0, 1)) / K_switch_rate + time

                        c_death = c_death_bar
                        d_death = d_death_bar
                    elif c_birth_time < c_death_time and c_birth_time < d_birth_time and c_birth_time < d_death_time:
                        # Resistant birth: x rises, may reach the threshold.
                        r_strain.pop_change(1)
                        if r_strain.population / (r_strain.population + s_strain.population) >= self.x_th:
                            s_strain.fitness_change(1)
                        tau = c_birth_time - time
                        time = c_birth_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        if x != 0 and x != 1:
                            average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                            gamma_c = r_strain.fitness / average_fitness
                            gamma_d = s_strain.fitness / average_fitness

                            c_birth_bar = gamma_c * total_pop * x
                            d_birth_bar = gamma_d * total_pop * (1 - x)
                            c_death_bar = total_pop**2 * x / carrying_capacity
                            d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                            c_birth_time = np.log(1 / random.uniform(0, 1)) / c_birth_bar + time
                            c_death_time = (c_death / c_death_bar) * (c_death_time - time) + time
                            d_birth_time = (d_birth / d_birth_bar) * (d_birth_time - time) + time
                            d_death_time = (d_death / d_death_bar) * (d_death_time - time) + time

                            c_birth = c_birth_bar
                            d_birth = d_birth_bar
                            c_death = c_death_bar
                            d_death = d_death_bar

                    elif c_death_time < d_birth_time and c_death_time < d_death_time:
                        # Resistant death: x falls, may drop below the threshold.
                        r_strain.pop_change(-1)
                        if r_strain.population / (r_strain.population + s_strain.population) < self.x_th:
                            s_strain.fitness_change(1 - self.a)
                        tau = c_death_time - time
                        time = c_death_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        # The guard avoids dividing by a zero rate for an
                        # extinct strain.
                        # NOTE(review): when x hits 0 here the sensitive
                        # strain's rates and clocks are left at their
                        # pre-event values; this only matters if the run is
                        # continued past fixation -- confirm intent.
                        if x != 0 and x != 1:
                            average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                            gamma_c = r_strain.fitness / average_fitness
                            gamma_d = s_strain.fitness / average_fitness

                            c_birth_bar = gamma_c * total_pop * x
                            d_birth_bar = gamma_d * total_pop * (1 - x)
                            c_death_bar = total_pop**2 * x / carrying_capacity
                            d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                            c_birth_time = (c_birth / c_birth_bar) * (c_birth_time - time) + time
                            c_death_time = np.log(1 / random.uniform(0, 1)) / c_death_bar + time
                            d_birth_time = (d_birth / d_birth_bar) * (d_birth_time - time) + time
                            d_death_time = (d_death / d_death_bar) * (d_death_time - time) + time

                            c_birth = c_birth_bar
                            d_birth = d_birth_bar
                            c_death = c_death_bar
                            d_death = d_death_bar
                    elif d_birth_time < d_death_time:
                        # Sensitive birth: x falls, may drop below the threshold.
                        s_strain.pop_change(1)
                        if r_strain.population / (r_strain.population + s_strain.population) < self.x_th:
                            s_strain.fitness_change(1 - self.a)
                        tau = d_birth_time - time
                        time = d_birth_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        if x != 0 and x != 1:
                            average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                            gamma_c = r_strain.fitness / average_fitness
                            gamma_d = s_strain.fitness / average_fitness

                            c_birth_bar = gamma_c * total_pop * x
                            d_birth_bar = gamma_d * total_pop * (1 - x)
                            c_death_bar = total_pop**2 * x / carrying_capacity
                            d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                            c_birth_time = (c_birth / c_birth_bar) * (c_birth_time - time) + time
                            c_death_time = (c_death / c_death_bar) * (c_death_time - time) + time
                            d_birth_time = np.log(1 / random.uniform(0, 1)) / d_birth_bar + time
                            d_death_time = (d_death / d_death_bar) * (d_death_time - time) + time

                            c_birth = c_birth_bar
                            d_birth = d_birth_bar
                            c_death = c_death_bar
                            d_death = d_death_bar
                    else:
                        # Sensitive death: x rises, may reach the threshold.
                        s_strain.pop_change(-1)
                        if r_strain.population / (r_strain.population + s_strain.population) >= self.x_th:
                            s_strain.fitness_change(1)
                        tau = d_death_time - time
                        time = d_death_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        # NOTE(review): when x hits 1 here the resistant
                        # strain's rates and clocks are left at their
                        # pre-event values; this only matters if the run is
                        # continued past fixation -- confirm intent.
                        if x != 0 and x != 1:
                            average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                            gamma_c = r_strain.fitness / average_fitness
                            gamma_d = s_strain.fitness / average_fitness

                            c_birth_bar = gamma_c * total_pop * x
                            d_birth_bar = gamma_d * total_pop * (1 - x)
                            c_death_bar = total_pop**2 * x / carrying_capacity
                            d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                            c_birth_time = (c_birth / c_birth_bar) * (c_birth_time - time) + time
                            c_death_time = (c_death / c_death_bar) * (c_death_time - time) + time
                            d_birth_time = (d_birth / d_birth_bar) * (d_birth_time - time) + time
                            d_death_time = np.log(1 / random.uniform(0, 1)) / d_death_bar + time

                            c_birth = c_birth_bar
                            d_birth = d_birth_bar
                            c_death = c_death_bar
                            d_death = d_death_bar
                elif r_strain.population == total_pop:
                    # Only the resistant strain is left: three reactions remain.
                    if K_switch_time < c_birth_time and K_switch_time < c_death_time:
                        if carrying_capacity == self.K_m:
                            carrying_capacity = self.K_p
                            K_switch_rate = self.nu_plus
                        else:
                            carrying_capacity = self.K_m
                            K_switch_rate = self.nu_minus
                        tau = K_switch_time - time
                        time = K_switch_time
                        average_x += x * tau

                        # gamma_c is whatever was last computed for it.
                        c_birth_bar = gamma_c * total_pop * x
                        c_death_bar = total_pop**2 * x / carrying_capacity

                        c_birth_time = (c_birth / c_birth_bar) * (c_birth_time - time) + time
                        c_death_time = (c_death / c_death_bar) * (c_death_time - time) + time
                        K_switch_time = np.log(1 / random.uniform(0, 1)) / K_switch_rate + time

                        c_death = c_death_bar
                        c_birth = c_birth_bar
                    elif c_birth_time < c_death_time:
                        r_strain.pop_change(1)

                        tau = c_birth_time - time
                        time = c_birth_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                        gamma_c = r_strain.fitness / average_fitness

                        c_birth_bar = gamma_c * total_pop * x
                        c_death_bar = total_pop**2 * x / carrying_capacity

                        c_birth_time = np.log(1 / random.uniform(0, 1)) / c_birth_bar + time
                        c_death_time = (c_death / c_death_bar) * (c_death_time - time) + time

                        c_birth = c_birth_bar
                        c_death = c_death_bar

                    else:
                        r_strain.pop_change(-1)

                        tau = c_death_time - time
                        time = c_death_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                        gamma_c = r_strain.fitness / average_fitness

                        c_birth_bar = gamma_c * total_pop * x
                        c_death_bar = total_pop**2 * x / carrying_capacity

                        c_birth_time = (c_birth / c_birth_bar) * (c_birth_time - time) + time
                        c_death_time = np.log(1 / random.uniform(0, 1)) / c_death_bar + time

                        c_birth = c_birth_bar
                        c_death = c_death_bar
                else:
                    # Only the sensitive strain is left: three reactions remain.
                    if K_switch_time < d_birth_time and K_switch_time < d_death_time:
                        if carrying_capacity == self.K_m:
                            carrying_capacity = self.K_p
                            K_switch_rate = self.nu_plus
                        else:
                            carrying_capacity = self.K_m
                            K_switch_rate = self.nu_minus
                        tau = K_switch_time - time
                        time = K_switch_time
                        average_x += x * tau

                        # gamma_d is whatever was last computed for it.
                        d_birth_bar = gamma_d * total_pop * (1 - x)
                        d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                        d_birth_time = (d_birth / d_birth_bar) * (d_birth_time - time) + time
                        d_death_time = (d_death / d_death_bar) * (d_death_time - time) + time
                        K_switch_time = np.log(1 / random.uniform(0, 1)) / K_switch_rate + time

                        d_death = d_death_bar
                        d_birth = d_birth_bar
                    elif d_birth_time < d_death_time:
                        s_strain.pop_change(1)
                        tau = d_birth_time - time
                        time = d_birth_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                        gamma_d = s_strain.fitness / average_fitness

                        d_birth_bar = gamma_d * total_pop * (1 - x)
                        d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                        d_birth_time = np.log(1 / random.uniform(0, 1)) / d_birth_bar + time
                        d_death_time = (d_death / d_death_bar) * (d_death_time - time) + time

                        d_birth = d_birth_bar
                        d_death = d_death_bar
                    else:
                        s_strain.pop_change(-1)
                        tau = d_death_time - time
                        time = d_death_time
                        average_x += x * tau
                        total_pop = s_strain.population + r_strain.population

                        x = r_strain.population / total_pop

                        average_fitness = x * r_strain.fitness + s_strain.fitness * (1 - x)
                        gamma_c = r_strain.fitness / average_fitness
                        gamma_d = s_strain.fitness / average_fitness

                        d_birth_bar = gamma_d * total_pop * (1 - x)
                        d_death_bar = total_pop**2 * (1 - x) / carrying_capacity

                        d_birth_time = (d_birth / d_birth_bar) * (d_birth_time - time) + time
                        d_death_time = np.log(1 / random.uniform(0, 1)) / d_death_bar + time

                        d_birth = d_birth_bar
                        d_death = d_death_bar

                # self.N_R_track.append(r_strain.population)
                # self.N_S_track.append(s_strain.population)
                # self.K_track.append(carrying_capacity)
                # self.t_track.append(time)

                # Stopping rules. As written both are active: a run ends at
                # fixation or once time passes max_time, whichever is first.
                # To keep running after fixation until max_time, comment out
                # the two 'running = False' statements in the fixation block.
                if r_strain.population == total_pop and not fixated:
                    fixated = True
                    time_fixated = time
                    self.r_fixation_count += 1
                    self.total_fixation_count += 1
                    running = False
                elif s_strain.population == total_pop and not fixated:
                    fixated = True
                    time_fixated = time
                    self.total_fixation_count += 1
                    running = False

                if time > max_time:
                    if not fixated:
                        self.coexist_prob += 1
                        time_fixated = max_time
                    running = False
                    # time_fixated is added to average_time once per run
                    # after the loop; adding it here as well would count
                    # every run that reaches max_time twice.
                    # NOTE(review): final populations are only recorded for
                    # runs that reach max_time, yet are divided by all
                    # `simulations` below -- confirm this is intended when
                    # runs stop at fixation.
                    self.final_N_R += r_strain.population
                    self.final_N_S += s_strain.population
            self.average_time += time_fixated
            self.average_x += average_x / time
        # Turn the accumulated sums into per-run averages.
        self.final_N_R = self.final_N_R / self.simulations
        self.final_N_S = self.final_N_S / self.simulations
        self.average_time = self.average_time / self.simulations
        self.coexist_prob = self.coexist_prob / self.simulations
        self.average_x = self.average_x / self.simulations