from neuron import h
from PyNeuronToolbox import morphology
from matplotlib import pyplot

# Load NEURON's standard run library; it provides h.stdinit(), which the
# script calls below before pc.psolve().
h.load_file('stdrun.hoc')

class Pyramidal:
    """One pyramidal cell reconstructed from an SWC morphology file.

    On construction the cell:

    - loads its morphology;
    - sets an odd ``nseg`` on every section;
    - inserts ``hh`` in the soma and ``pas`` everywhere;
    - registers a somatic spike detector with NEURON's ``ParallelContext``
      under ``gid``, so other cells can connect to it with
      ``pc.gid_connect(gid, target)``.

    Detected spike times are recorded into ``self.spike_times``.

    Args:
        gid: global identifier for this cell within the ParallelContext.
        morphology_file: path of the SWC reconstruction to load. Defaults to
            ``'c91662.swc'``, the file this class originally hard-coded.
    """

    def __init__(self, gid, *, morphology_file='c91662.swc'):
        self._gid = gid
        self._morphology_file = morphology_file
        self._setup_morphology()
        self._discretize()
        self._add_channels()
        self._register_netcon()

    def _register_netcon(self):
        """Make the soma the spike source for ``gid`` and record its spikes."""
        # Spike detector on the mid-soma voltage. Target None means this
        # NetCon only reports events. Its threshold is left at NEURON's NetCon
        # default (not set here).
        self.nc = h.NetCon(self.soma[0](0.5)._ref_v, None, sec=self.soma[0])
        pc = h.ParallelContext()
        # Assign this gid to the current rank, then bind the gid to the detector.
        pc.set_gid2node(self._gid, int(pc.id()))
        pc.cell(self._gid, self.nc)
        self.spike_times = h.Vector()
        self.nc.record(self.spike_times)

    def _setup_morphology(self):
        """Load the SWC file into this cell's section lists."""
        # morphology.load is given cell=self and fills these lists.
        # NOTE(review): self.all, used by _discretize/_add_channels, is not
        # assigned anywhere in this class; presumably morphology.load creates
        # it. Confirm against PyNeuronToolbox.
        self.soma, self.axon = [], []
        self.dend, self.apic = [], []
        morphology.load(self._morphology_file, fileformat='swc', cell=self)

    def __repr__(self):
        return 'p[%d]' % self._gid

    def _discretize(self, max_seg_length=20):
        """Give every section an nseg whose segments are shorter than max_seg_length.

        nseg is always odd. That puts a segment centre exactly at x = 0.5,
        the location used for the spike detector and the synapses.
        """
        for sec in self.all:
            sec.nseg = 1 + 2 * int(sec.L / max_seg_length)

    def _add_channels(self):
        """Insert Hodgkin-Huxley channels in the soma and a passive leak in every section."""
        for sec in self.soma:
            sec.insert('hh')
        for sec in self.all:
            sec.insert('pas')
            for seg in sec:
                seg.pas.g = 0.001

class Network:
    """A ring of ``num`` Pyramidal cells in which each cell excites the next.

    Cell ``i`` (gid ``i``) drives an ExpSyn placed at the middle of
    ``cells[(i + 1) % num].dend[0]``, so one spike can travel around the ring.

    Args:
        num: number of cells. Gids ``0 .. num - 1`` are used.
        delay: NetCon delay for every ring connection. Default 1, as originally
            hard-coded.
        weight: NetCon weight for every ring connection. Default 1, as
            originally hard-coded.
    """

    def __init__(self, num, *, delay=1, weight=1):
        self.cells = [Pyramidal(i) for i in range(num)]
        # Set up an excitatory ExpSyn (reversal potential e = 0) at the middle
        # of each cell's first dendrite section.
        self.syns = []
        for cell in self.cells:
            syn = h.ExpSyn(cell.dend[0](0.5))
            syn.e = 0
            self.syns.append(syn)
        # Connect cell i to cell (i + 1) % num. Source gid i is the
        # Pyramidal(i) created above.
        pc = h.ParallelContext()
        self.ncs = []
        for i in range(num):
            nc = pc.gid_connect(i, self.syns[(i + 1) % num])
            nc.delay = delay
            nc.weight[0] = weight
            self.ncs.append(nc)

# Build the 20-cell ring network.
n = Network(20)

# Kick off activity in cell 0: a single NetStim event at t = 3 ms, delivered
# to its synapse through a NetCon with delay 1 and weight 1.
stim = h.NetStim()
stim.start, stim.number = 3, 1
ncstim = h.NetCon(stim, n.syns[0])
ncstim.weight[0] = 1
ncstim.delay = 1

# Record simulation time and each cell's mid-soma membrane potential.
t = h.Vector().record(h._ref_t)
v = [h.Vector().record(cell.soma[0](0.5)._ref_v) for cell in n.cells]

# Initialize to -69 mV and run for 100 ms through the ParallelContext.
# set_maxstep must be called before psolve.
pc = h.ParallelContext()
pc.set_maxstep(10)
h.v_init = -69
h.stdinit()
pc.psolve(100)

# Figure 1: somatic membrane potential of every cell over time.
for myv in v:
    pyplot.plot(t, myv)
pyplot.xlabel('t (ms)')
pyplot.ylabel('v (mV)')
pyplot.show()

# Figure 2: spike raster. Cell i's spikes are drawn as vertical ticks spanning
# i + 0.5 .. i + 1.5, so row k (1-based) belongs to cell index k - 1.
# Axis labels are added to match the first figure.
for i, cell in enumerate(n.cells):
    pyplot.vlines(cell.spike_times, i + 0.5, i + 1.5)
pyplot.xlabel('t (ms)')
pyplot.ylabel('cell number')
pyplot.show()

import json

# Save each cell's spike times, keyed by cell index. JSON object keys are
# always strings, so the index is converted explicitly instead of relying on
# json's implicit int -> str coercion. The written file content is unchanged.
spike_data = {str(i): list(cell.spike_times) for i, cell in enumerate(n.cells)}
with open('output.json', 'w', encoding='utf-8') as f:
    json.dump(spike_data, f, indent=4)
