import json
from neuron import h
from cell import Cell
from matplotlib import pyplot as plt

# NEURON's standard run library; it supplies h.continuerun and h.graphList,
# which are used further down in this script.
h.load_file('stdrun.hoc')

NUM_CELLS = 200
# Graph color index for red; not referenced by the code visible in this script.
# see https://www.neuron.yale.edu/neuron/static/py_doc/visualization/graph.html?highlight=graph%20color#Graph.color
RED = 2

# One Cell per index 0..NUM_CELLS-1 (presumably stored as the cell's _gid,
# which the export code reads -- confirm against cell.Cell).
my_cells = list(map(Cell, range(NUM_CELLS)))

# External drive: a single NetStim event at t = 4 ms ...
ns = h.NetStim()
ns.start = 4  # ms
ns.number = 1

# ... delivered onto cell 0's synapse after a 1 ms delay.
nc = h.NetCon(ns, my_cells[0].syn)
nc.weight[0] = 0.01
nc.delay = 1  # ms

# Unidirectional ring: each cell drives the next one in my_cells, and the last
# cell wraps around to drive cell 0.
my_cells_shifted = my_cells[1:] + [my_cells[0]]


def _ring_connection(source, target):
    """Return a NetCon from source's mid-soma voltage onto target's synapse."""
    link = h.NetCon(source.soma(0.5)._ref_v, target.syn, sec=source.soma)
    link.delay = 2  # ms
    link.weight[0] = 0.01
    return link


# Keep every NetCon referenced in ncs so Python does not discard them.
ncs = [
    _ring_connection(pre, post)
    for pre, post in zip(my_cells, my_cells_shifted)
]


h.tstop = 50  # ms

# Morphology view of all sections; show(0) selects the diameter display mode
# (see NEURON Shape.show).
ps = h.PlotShape()
ps.show(0)

# Somatic membrane potential of every cell on one graph, cycling through
# Graph color indices 1..9.
g = h.Graph()
for i, cell in enumerate(my_cells):
    g.color((i % 9) + 1)
    g.addvar('Cell[{}]'.format(i), cell.soma(0.5)._ref_v)
# Derive the x-range from h.tstop rather than repeating the literal 50, so the
# plotted window always matches the simulated duration.
g.size(0, h.tstop, -80, 50)
# Register with stdrun so the graph is plotted against t and updated during
# the run (graphList[0]).
h.graphList[0].append(g)

h.finitialize(-65)  # initial membrane potential, mV
h.continuerun(h.tstop)

# save data: spike times per cell, keyed by the cell's gid. JSON object keys
# must be strings, so the integer gids appear as "0", "1", ... in the file.
spike_times = {cell._gid: list(cell.spike_times) for cell in my_cells}

data = json.dumps(spike_times, indent=4)

# Explicit encoding so the output does not depend on the platform locale.
with open('many-cell.json', 'w', encoding='utf-8') as f:
    f.write(data)

# Two raster views of the recorded spikes, one row per gid. Cells whose
# spike_times is falsy (no recorded spikes) are left out of both.
firing_cells = [cell for cell in my_cells if cell.spike_times]

# Figure 1: one dot per spike.
plt.figure()
for cell in firing_cells:
    times = cell.spike_times
    plt.scatter(times, [cell._gid] * len(times))
plt.axis('off')

# Figure 2: one vertical tick per spike, spanning gid +/- 0.4.
plt.figure()
for cell in firing_cells:
    gid = cell._gid
    plt.vlines(cell.spike_times, gid - 0.4, gid + 0.4)
plt.axis('off')
plt.show()
