import numpy as np
import random
import csv

class IBM:
    """Individual-based model of elvers moving through a velocity-magnitude grid.

    Construction loads ``<flowrate>_<angle>_velMag_avg_3mm.asc``, builds the
    simulation domain from it, creates ``eelNumber`` agents in the
    second-to-last column, runs the simulation and then optionally prints,
    plots and writes the results.

    Args:
        flowrate, angle: used to build the input and output file names.
        eelSize: size class handed to every agent (e.g. "5" .. "10").
        PRINT, PLOT, WRITE: when truthy, call printOut / plotOut / writeOut.
    """

    # An agent whose column index drops below this value has passed.
    PASS_COL = 25
    # First column of each stud marked as wall on the first row of the input
    # grid (27 studs, 90 columns apart), and the width of each stud in cells.
    STUD_COLS = tuple(range(175, 2516, 90))
    STUD_WIDTH = 31

    def __init__(self, flowrate, angle, eelSize, PRINT, PLOT, WRITE):
        self.eelNumber = 1000
        self.flowrate = flowrate
        self.angle = angle
        self.eelSize = eelSize

        data = np.loadtxt('%s_%s_velMag_avg_3mm.asc' % (self.flowrate, self.angle))
        data2 = self._buildDomain(data)
        row2_max, col2_max = data2.shape

        # Agents spawn in the second-to-last column, on a row clear of the frame.
        self.spawn_col = col2_max - 2
        self.spawn_rows = np.arange(1, row2_max - 4)

        self.Tmax = 10000
        self.max_burst_time = 20  # Burst speed can only be maintained for 20 seconds (By definition)

        self.makeAgents()
        self.MAIN(data2)
        if PRINT:
            self.printOut()
        if PLOT:
            self.plotOut(data2)
        if WRITE:
            self.writeOut()

    @classmethod
    def _buildDomain(cls, data):
        """Return the simulation grid built from the raw grid *data*.

        The first row of *data* gets the stud columns set to 99 (modified in
        place). Three copies of ``data`` without its last row and column are
        then stacked vertically. The result is framed with 99, and every cell
        that is exactly 0 is also set to 99. 99 is the value used for
        blocked cells.
        """
        row_max, col_max = data.shape

        # Fancy indexing (not slicing) so a stud past the grid edge still
        # raises IndexError instead of being silently truncated.
        studCols = np.add.outer(np.asarray(cls.STUD_COLS),
                                np.arange(cls.STUD_WIDTH)).ravel()
        data[0, studCols] = 99

        # Copy k occupies rows k*(row_max-1) .. k*(row_max-1)+row_max-2, so the
        # three copies do not overlap; the final row is left for the frame.
        data2 = np.zeros([3 * row_max - 2, col_max])
        tile = data[:row_max - 1, :col_max - 1]
        data2[:3 * (row_max - 1), :col_max - 1] = np.tile(tile, (3, 1))

        data2[-1, :] = 99
        data2[:, 0] = 99
        data2[:, -1] = 99
        data2[0, :] = 99
        data2[data2 == 0] = 99
        return data2

    def makeAgents(self):
        """Create ``eelNumber`` agents and reset the passed/failed tallies."""
        self.eels = [eel(i, self.spawn_col, self.spawn_rows, self.eelSize)
                     for i in range(self.eelNumber)]
        # Ids of agents still being simulated; an agent's id is its index in self.eels.
        self.eelsInDom = [agent.id for agent in self.eels]
        self.passed = list()
        self.failed = list()
        self.successes = 0

    def MAIN(self, data):
        """Step every active agent until none remain or ``Tmax`` steps elapse.

        Each step, every active agent records its state and is then classed as
        failed (burst time over ``max_burst_time``), passed (column below
        ``PASS_COL``), or moved one cell through *data*.
        """
        self.T = 0

        while self.T < self.Tmax and self.eelsInDom:
            # Iterate over a snapshot: ids are removed from eelsInDom below, and
            # removing from the list being iterated would skip the next agent.
            for i in list(self.eelsInDom):
                agent = self.eels[i]
                agent.updateOutput(self.T)

                # Exhausted: burst speed held for too long.
                if agent.burstTime > self.max_burst_time:
                    self.failed.append(agent.id)
                    self.eelsInDom.remove(agent.id)
                    continue

                # Reached the exit side of the domain.
                if agent.col < self.PASS_COL:
                    self.passed.append(agent.id)
                    self.eelsInDom.remove(agent.id)
                    self.successes += 1
                    continue

                agent.createNeighbours(data)
                agent.move()

            self.T += 1

    def printOut(self):
        """Print a short summary of how many agents passed."""
        print("\n" + " ----- %s_%s_velMag_avg_3mm.asc" % (self.flowrate, self.angle))
        print(" --------------------------------------------------- ")
        print("   ------ %s out of %s successfully passed ------  " % (len(self.passed), self.eelNumber))
        print(" --------------------------------------------------- " + "\n")

    def plotOut(self, data):
        """Plot the velocity field and every agent's track coloured by step time."""
        import matplotlib.pyplot as plt
        from matplotlib.collections import LineCollection

        fig, axes = plt.subplots(1, 1, sharex=True, sharey=True)
        figManager = plt.get_current_fig_manager()
        try:
            figManager.window.showMaximized()
        except AttributeError:
            # showMaximized() is Qt-specific; other backends keep their default size.
            pass
        levels = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
        axes.contourf(data, levels, cmap='rainbow')
        axes.axis('scaled')

        def addTrack(agent):
            # Draw one agent's path; segment k is coloured by the time of move k.
            time = np.asarray(agent.timePerStep)
            if time.size == 0:
                return  # never moved: nothing to draw
            # history begins with the spawn point stored twice; keep only the
            # last len(time) + 1 points so segments and times line up.
            history = np.asarray(agent.history)[-(time.size + 1):]
            x = history[:, 1]
            y = history[:, 0]
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            norm = plt.Normalize(time.min(), time.max())
            lc = LineCollection(segments, cmap='coolwarm', norm=norm)
            lc.set_array(time)
            lc.set_linewidth(2)
            axes.add_collection(lc)

        for i in self.passed + self.eelsInDom + self.failed:
            addTrack(self.eels[i])

        plt.show()

    def writeOut(self):
        """Write each agent's output table to its own comma-separated file.

        Files go to ``IBMOutputs/<flowrate>_IBMOutputs/<angle>/``; that
        directory is not created here and must already exist.
        """
        for i in range(self.eelNumber):
            path = "IBMOutputs/%s_IBMOutputs/%s/IBM_%s_%s_agentData_%scm_%s.txt" % (
                self.flowrate, self.angle, self.flowrate, self.angle, self.eelSize, i)
            # newline='' as the csv module requires (avoids doubled \r on Windows);
            # the with-block closes the file even if writing fails.
            with open(path, "w", newline="") as outFile:
                csv.writer(outFile).writerows(self.eels[i].output)

############################################################################### EEL CLASS

class eel(object):
    """A single elver agent that moves one grid cell per step."""

    # (row offset, col offset, preference group) for the eight surrounding
    # cells, listed in the order their neighbour ids 0-7 are assigned.
    # 'UP' cells are in column col-1, 'SIDE' in the same column, 'DOWN' in col+1.
    _NEIGHBOUR_OFFSETS = (
        (-1, -1, 'UP'), (0, -1, 'UP'), (1, -1, 'UP'),
        (-1, 0, 'SIDE'), (1, 0, 'SIDE'),
        (-1, 1, 'DOWN'), (0, 1, 'DOWN'), (1, 1, 'DOWN'),
    )

    # Lognormal parameters for burst speed, derived from SWIMIT data.
    # Keys are size classes in cm.
    _MEAN_SPEED = {
        "5": 0.356484405860182,
        "6": 0.408353594633246,
        "7": 0.45724900054965,
        "8": 0.503598627206196,
        "9": 0.547722492620581,
        "10": 0.589868994306822,
    }
    _SPEED_STD_DEV = 0.310646159486924
    _MIN_SPEED = 0.034305

    def __init__(self, UniqueID, spawn_col, spawn_rows, eelSize):
        self.id = UniqueID  # ID number
        self.row = random.choice(spawn_rows)
        self.spawn_col = spawn_col
        self.col = spawn_col
        self.location = np.array([self.row, self.col])
        # The spawn location is stored twice, which makes history 2-D (N x 2,
        # rows of [row, col]) from the start; each move appends one more row.
        self.history = np.vstack([self.location, self.location])

        # Exhaustion state
        self.eelSize = eelSize
        self.burst = self.determineSpeed()
        self.burstTime = 0          # sum of all step times so far
        self.timePerStep = list()   # time taken by each move, in order
        self.timeReq = 0            # time taken by the most recent move

        # Output table: header row, then one row per updateOutput() call.
        self.output = np.array(['ID', 'Burst Speed', 'Step No.', 'Row (y)', 'Col (x)', 'Step Time', 'Total Time', 'xDistance'])
        self.dataLine = np.array([0, 0.0, 0, 0, 0, 0.0, 0.0, 0.0])

    def createNeighbours(self, data):
        """Build the eight neighbouring cells and sort the passable ones.

        A neighbour is passable when its value in *data* is below this agent's
        burst speed. Passable neighbour ids go into passableNeighboursUP,
        passableNeighboursSIDE or passableNeighboursDOWN.
        """
        self.neighbours = list()
        self.passableNeighboursUP = list()
        self.passableNeighboursSIDE = list()
        self.passableNeighboursDOWN = list()
        groups = {
            'UP': self.passableNeighboursUP,
            'SIDE': self.passableNeighboursSIDE,
            'DOWN': self.passableNeighboursDOWN,
        }

        for nid, (dRow, dCol, group) in enumerate(self._NEIGHBOUR_OFFSETS):
            cell = neighbour(nid, self.row + dRow, self.col + dCol, data)
            # passable() stores its bool result on the attribute of the same name.
            cell.passable(self.burst, data)
            self.neighbours.append(cell)
            if cell.passable:
                groups[group].append(cell.id)

    def move(self):
        """Move to a random passable neighbour and record the time taken.

        UP neighbours are preferred, then SIDE, then DOWN. If none is
        passable, the agent stays where it is and nothing is recorded.
        """
        for candidates in (self.passableNeighboursUP,
                           self.passableNeighboursSIDE,
                           self.passableNeighboursDOWN):
            if candidates:
                n = random.choice(candidates)
                break
        else:
            # Previously this fell through with `n` unbound (UnboundLocalError).
            return

        target = self.neighbours[n]
        # Speed relative to the local velocity; positive because passable
        # means target.vel < self.burst.
        speed = self.burst - target.vel
        if self.row != target.row and self.col != target.col:
            distance = 0.00070712  # diagonal move
        else:
            distance = 0.0005

        self.timeReq = distance / speed
        self.timePerStep.append(self.timeReq)
        self.burstTime += self.timeReq

        self.row = target.row
        self.col = target.col
        self.location = np.array([self.row, self.col])
        self.history = np.vstack([self.history, self.location])

    def updateOutput(self, T):
        """Append this agent's state at step *T* as a new row of self.output.

        Columns: id, burst speed, T, row, col, last step time, cumulative
        burst time, column distance from spawn scaled by 0.5/1000 (the same
        per-cell distance move() uses). An infinite step time is written
        as 1000 in both time columns.
        Because the header row is text, numpy keeps self.output as a string
        array. NOTE: np.vstack copies the whole table on every call.
        """
        self.dataLine[0] = self.id
        self.dataLine[1] = self.burst
        self.dataLine[2] = T
        self.dataLine[3] = self.row
        self.dataLine[4] = self.col

        if np.isinf(self.timeReq):
            self.dataLine[5] = 1000
            self.dataLine[6] = 1000
        else:
            self.dataLine[5] = self.timeReq
            self.dataLine[6] = self.burstTime

        self.dataLine[7] = abs(self.col - self.spawn_col) * (0.5 / 1000)

        self.output = np.vstack([self.output, self.dataLine])

    def determineSpeed(self):
        """Draw a burst speed for this elver (lognormal, SWIMIT-derived).

        A lognormal value with median ``mean`` (mu = log(mean)) is reflected
        about ``mean``. The draw is repeated until the result is at least
        ``_MIN_SPEED``.

        Raises:
            ValueError: if ``str(self.eelSize)`` is not a known size class.
        """
        mean = self._MEAN_SPEED.get(str(self.eelSize))
        if mean is None:
            raise ValueError("unsupported eelSize %r; expected one of: %s"
                             % (self.eelSize, ", ".join(self._MEAN_SPEED)))

        spd2 = -99
        while spd2 < self._MIN_SPEED:
            spd = np.random.lognormal(np.log(mean), self._SPEED_STD_DEV)
            spd2 = 2 * mean - spd

        return spd2
        
class neighbour(object):
    """A candidate cell next to an eel, holding the grid value at that cell."""

    def __init__(self, UniqueID, row, col, data):
        # Identifier assigned by the caller.
        self.id = UniqueID
        self.row, self.col = row, col
        self.location = np.array([row, col])
        self.vel = data[row, col]

    def passable(self, vel, data):
        """Record whether this cell's value in *data* is strictly below *vel*.

        The bool result is stored on ``self.passable``. That instance
        attribute replaces this method, so it can only be called once per
        neighbour object.
        """
        self.passable = bool(data[self.row, self.col] < vel)
            
