## CONNECTOR

############################################################################### PREAMBLE

import csv
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np

# time.clock() was removed in Python 3.8; perf_counter() is the portable
# high-resolution timer for measuring elapsed wall time.
start_time = time.perf_counter()
############################################################################### DEFINE GLOBALS

WRITE = True

############################################################################### USER INPUTS

eelNumber = 1000

########################### ENVIRONMENT DATA

flowrate = "Lrg"
angle = "20"
eelSize = "8"

loader = ("%s_%s_velPlot_avg_3mm_%scm.asc"%(flowrate,angle,eelSize))

# Classified domain grid loaded from the .asc file named above. Only cells
# equal to 10.0 are passable for the agents (see neighbour.passable); the
# value 99 is written below to seal cells off.
data = np.loadtxt(loader)
[row_max, col_max] = data.shape
Tmax = 10000

####################################################
# Seal a 31-column span starting at each of these hard-coded column indices
# on the first and last rows of the raw grid. Indexing element-by-element
# keeps an IndexError if a span ever falls outside the grid.
studs = [175,265,355,445,535,625,715,805,895,985,1075,1165,1255,1345,1435,1525,1615,1705,1795,1885,1975,2065,2155,2245,2335,2425,2515]
for k in studs:
    for i in range(31):
        data[0,k+i] = 99
        data[row_max-1,k+i] = 99

# Seal the last row and the first/last columns of the raw grid.
data[row_max-1,:] = 99
data[:,0] = 99
data[:,col_max-1] = 99

# data2 stacks three copies of `data` (without its last row and last column)
# vertically, each copy occupying row_max-1 rows. The remaining final row and
# final column are sealed with 99 below. Slice assignment replaces the
# original per-cell Python loop with an identical result.
data_tile = data[:row_max-1, :col_max-1]
data2 = np.zeros([3*row_max-2, col_max])
for copy_index in range(3):
    top = copy_index * (row_max - 1)
    data2[top:top + row_max - 1, :col_max - 1] = data_tile

[row2_max, col2_max] = data2.shape

# Seal every border of the stacked grid.
data2[row2_max-1,:] = 99
data2[:,0] = 99
data2[:,col2_max-1] = 99
data2[0,:] = 99

# Agents spawn in the second-to-last column of data2.
spawn_col = col2_max-2
spawn_row = np.arange(1,row2_max)

############################################################################### CLASS DEFINITIONS

class eel(object):
    """A single eel agent walking through the module-level ``data2`` grid.

    Each step the eel builds its 8 neighbouring cells, keeps the passable ones
    (as decided by ``neighbour.passable``), and moves to a random one. In
    normal mode it prefers the lower column ("UP", col - 1), then the same
    column ("SIDE"), then the higher column ("DOWN", col + 1). If its column
    has barely changed over a recent window it becomes ``stuck`` and reverses
    that preference for ``stuckItMax`` moves.
    """

    # (row offset, col offset, preference group) for neighbour ids 0..7.
    # The order matches the ids used by the original hand-written sequence,
    # so ``self.neighbours[n]`` is always the neighbour whose id is ``n``.
    _NEIGHBOUR_OFFSETS = (
        (-1, -1, 'UP'), (0, -1, 'UP'), (1, -1, 'UP'),
        (-1, 0, 'SIDE'), (1, 0, 'SIDE'),
        (-1, 1, 'DOWN'), (0, 1, 'DOWN'), (1, 1, 'DOWN'),
    )

    def __init__(self, UniqueID):

        self.id = UniqueID          # ID number

        # Spawn at a random row of the spawn column (module-level globals).
        self.row = random.randint(1,row2_max-4)
        self.col = spawn_col
        self.location = np.array([self.row, self.col])
        # History starts as a 2-D array holding the spawn location twice;
        # checkStuck's window length counts these initial rows.
        self.history = np.vstack([self.location, self.location])

        # Fallback ("stuck") variables
        self.stuck = False
        self.stuckIt = 0            # moves made since entering stuck mode
        self.stuckItMax = 30        # stuck moves before reverting to normal mode
        self.pastColsNumber = 20    # history rows inspected by checkStuck
        self.threshold = 2          # min column spread that counts as progress

        # Approach metric variables (maintained by the main loop)
        self.approaching = False
        self.attempting = False

        # Output data: a header row, then one row per updateOutput() call.
        self.output = np.array(['ID', 'Step No.', 'Row (y)', 'Col (x)'])
        self.dataLine = np.array([0,0,0,0])

    def createNeighbours(self):
        """Rebuild the 8 neighbours of the current cell and bucket the passable ones.

        Sets ``self.neighbours`` (indexed by neighbour id 0..7) and the id lists
        ``passableNeighboursUP`` / ``SIDE`` / ``DOWN``.
        """

        self.neighbours = list()
        self.passableNeighboursUP = list()
        self.passableNeighboursSIDE = list()
        self.passableNeighboursDOWN = list()

        buckets = {
            'UP': self.passableNeighboursUP,
            'SIDE': self.passableNeighboursSIDE,
            'DOWN': self.passableNeighboursDOWN,
        }

        for nid, (dRow, dCol, group) in enumerate(self._NEIGHBOUR_OFFSETS):
            nb = neighbour(nid, self.row + dRow, self.col + dCol)
            self.neighbours.append(nb)
            # neighbour.passable() replaces itself on the instance with a
            # literal True/False, which is what is tested here.
            nb.passable()
            if nb.passable is True:
                buckets[group].append(nb.id)

    def move(self):
        """Step to a random passable neighbour following the preference order.

        Normal mode tries UP, SIDE, DOWN in that order. Stuck mode advances the
        stuck counter and tries DOWN, SIDE, UP. After ``stuckItMax`` stuck moves
        the eel reverts to normal mode. If no neighbour is passable the eel stays
        where it is; the original code raised IndexError from random.choice([]).
        Requires createNeighbours() to have been called for the current cell.
        """

        if self.stuck:
            self.stuckIt += 1
            preference = (self.passableNeighboursDOWN,
                          self.passableNeighboursSIDE,
                          self.passableNeighboursUP)
        else:
            preference = (self.passableNeighboursUP,
                          self.passableNeighboursSIDE,
                          self.passableNeighboursDOWN)

        # First non-empty group in preference order, or None if boxed in.
        candidates = next((group for group in preference if group), None)
        if candidates is not None:
            n = random.choice(candidates)
            self.row = self.neighbours[n].row
            self.col = self.neighbours[n].col

        self.location = np.array([self.row, self.col])
        self.history = np.vstack([self.history, self.location])

        if self.stuckIt == self.stuckItMax:
            self.stuck = False
            self.stuckIt = 0

    def checkStuck(self):
        """Enter stuck mode when the column has stalled over the recent window.

        Once the history is longer than ``pastColsNumber`` rows, the columns of
        the last ``pastColsNumber`` rows are compared. If their spread is below
        ``threshold`` the eel is marked stuck and its counter is reset.
        """

        if len(self.history) > self.pastColsNumber:
            # Explicit start index (not a negative slice) so the window is
            # exactly the last pastColsNumber rows, matching the original loop.
            start = len(self.history) - self.pastColsNumber
            recentCols = self.history[start:, 1]
            if recentCols.max() - recentCols.min() < self.threshold:
                self.stuck = True
                self.stuckIt = 0

    def updateOutput(self, T):
        """Append the row ``[id, T, row, col]`` to ``self.output``."""

        self.dataLine[0] = self.id
        self.dataLine[1] = T
        self.dataLine[2] = self.row
        self.dataLine[3] = self.col

        self.output = np.vstack([self.output,self.dataLine])

class neighbour(object):
    """One candidate cell around an eel, identified by an id in 0..7.

    ``passable()`` looks the cell up in the module-level ``data2`` grid and
    stores the answer by *rebinding* ``self.passable`` on the instance to a
    literal ``True``/``False``. eel.createNeighbours relies on this: it calls
    ``passable()`` once, then reads ``.passable`` as a bool. A second call on
    the same instance would fail, because the method is then shadowed.
    """

    def __init__(self,UniqueID,row,col):
        self.id = UniqueID
        self.row, self.col = row, col
        self.location = np.array([row, col])

    def passable(self):
        """Shadow this method with True if the cell value is 10.0, else False."""
        cellValue = data2[self.row, self.col]
        self.passable = True if cellValue == 10.0 else False

############################################################################### MAKE AGENTS

# One eel per id 0..eelNumber-1, created in id order (each draws its spawn row
# from `random` in eel.__init__, so creation order fixes the random stream).
eels = [eel(uid) for uid in range(eelNumber)]
# Ids of eels still inside the domain; the main loop moves them to `passed`.
eelsInDom = [agent.id for agent in eels]

passed = []

# Event counters updated by the main loop.
approaches = 0
attempts = 0
successes = 0

############################################################################### MAIN LOOP
T = 0

while T < Tmax and len(passed) != eelNumber:

    # Iterate over a snapshot of the ids. Eels that pass are removed from
    # eelsInDom below, and removing items from the list being iterated would
    # make Python silently skip the next eel for this time step.
    for i in list(eelsInDom):
        agent = eels[i]

        agent.updateOutput(T)

        # Approach: counted when an eel is seen at column 2625 and then at
        # column 2624 on the following step.
        if agent.col == 2624 and agent.approaching:
            approaches += 1
        agent.approaching = agent.col == 2625

        # Attempt: counted when an eel is seen at column 2575 and then at
        # column 2574 on the following step.
        if agent.col == 2574 and agent.attempting:
            attempts += 1
        agent.attempting = agent.col == 2575

        # Passed: the eel has reached a column below 25; retire it.
        if agent.col < 25:
            passed.append(agent.id)
            eelsInDom.remove(agent.id)
            successes += 1
            continue

        # Move the eel, then check whether it has stalled.
        agent.createNeighbours()
        agent.move()
        if not agent.stuck:
            agent.checkStuck()

    T += 1

############################################################################### WRITE FILE

if WRITE:
    # Create the output directory up front so a long simulation does not end
    # in FileNotFoundError at write time.
    os.makedirs("outputs/CA", exist_ok=True)
    for i in range(eelNumber):
        outPath = "outputs/CA/CA_%s_%s_agentData_%scm_%s.txt" %(flowrate,angle,eelSize,i)
        # The csv module requires newline=''; without it every row is
        # followed by a blank line on Windows. `with` guarantees the close.
        with open(outPath, "w", newline="") as outFile:
            csv.writer(outFile).writerows(eels[i].output)


############################################################################### PLOT

figManager = plt.get_current_fig_manager()
# Maximising is best-effort: showMaximized() exists only on Qt backends, and
# other backends (Tk, Agg, ...) raise AttributeError on .window or the method.
try:
    figManager.window.showMaximized()
except AttributeError:
    pass
levels = [10, 30, 60, 99]
plt.contourf(data2, levels, colors=('c', 'r', 'k'))
plt.axis('scaled')
# Trajectories are drawn BEFORE plt.show(); previously show() ran first, so
# these lines were added to a figure that was never displayed.
# Yellow: eels that passed. Magenta: eels still in the domain at the end.
for i in passed:
    plt.plot(eels[i].history[:,1], eels[i].history[:,0], color='y')
for i in eelsInDom:
    plt.plot(eels[i].history[:,1], eels[i].history[:,0], color='m')


# The summary is printed before the blocking plt.show(), so the reported time
# covers the run rather than how long the plot window stays open.
print("\n" + " --------------------------------------------------- ")
print("  ------- %s out of %s successfully passed -------  " % (len(passed),eelNumber))
print(" --------------------------------------------------- " + "\n")
print(" ----- %s agents used " % eelNumber)
print(" ----- %s approaches " % approaches)
print(" ----- %s attempts " % attempts)
print(" ----- %s successes " % successes)
# time.clock() was removed in Python 3.8; perf_counter() replaces it.
print("     ----- The code took %s seconds to execute -----     " % (round(time.perf_counter() - start_time)))

plt.show()
