I am trying to create a basic ring network that runs in parallel, as in Hines & Carnevale (2007) (see http://www.neuron.yale.edu/neuron/nrnpubs), and to implement it in Python with the help of Thomas McTavish's tutorial at https://nn.med.yale.edu:8000. To become familiar with NEURON's parallel environment, I wanted to create a simple ring with as many neurons as parallel processes. For that, I first created a Cell class with a synapse on the soma, and methods to get the time and voltage vectors and to create a NetCon.
Code (morphology.py):
"""
morphology.py
this file contains a class to create a very basic cell type to work
in parallel
"""
from mpi4py import MPI # always load mpi4py before neuron
from neuron import h
import numpy as np
class DummyCell(object):
    """A single-compartment Hodgkin-Huxley cell used for testing.

    The cell consists of one soma section with the ``hh`` mechanism, an
    ``ExpSyn`` synapse at the middle of the soma, and two vectors that
    record time and somatic membrane potential during a run.

    Parameters
    ----------
    gid : int, optional
        Global identifier of the cell. When omitted, the MPI rank of the
        calling process is used, giving one cell per process with gids
        0..nhost-1.
    """

    def __init__(self, gid=None):
        self.soma = h.Section(name='soma', cell=self)
        self.soma.L = self.soma.diam = 30.
        self.soma.insert('hh')

        # Exponential synapse in the middle of the soma. Passing the
        # segment alone places the point process; no sec=/name= needed.
        self.syn = h.ExpSyn(self.soma(0.5))
        self.syn.tau = 2.

        # A ParallelContext may not exist yet when the cell is built, so
        # the default gid is taken straight from the MPI rank (process),
        # not from a thread id.
        if gid is None:
            gid = MPI.COMM_WORLD.Get_rank()
        self.gid = gid

        # time and somatic voltage recordings
        self._time = h.Vector()
        self._time.record(h._ref_t)
        self._voltage = h.Vector()
        self._voltage.record(self.soma(0.5)._ref_v)

        # NetCons whose source is this cell (filled by connect2target)
        self.netcon = []

    def connect2target(self, target, threshold=10.0, delay=3.1, weight=0.04):
        """Create a NetCon from this cell's somatic voltage to *target*.

        Parameters
        ----------
        target : point process or None
            Synapse that receives the events. ``None`` creates a NetCon
            without a target (e.g. a spike source for
            ``ParallelContext.cell``).
        threshold : float
            Somatic voltage (mV) at which an event is generated.
        delay : float
            Connection delay (ms).
        weight : float
            Value written to ``netcon.weight[0]``.

        Returns
        -------
        The new NetCon, which is also appended to ``self.netcon``.
        """
        source = self.soma(0.5)._ref_v
        # sec= is required when the source is a range-variable pointer
        netcon = h.NetCon(source, target, sec=self.soma)
        netcon.threshold = threshold
        netcon.delay = delay
        netcon.weight[0] = weight
        self.netcon.append(netcon)
        return netcon

    def get_vectorlist(self, time=True, voltage=True):
        """Return the recorded traces as NumPy arrays.

        Parameters
        ----------
        time : bool
            Include the recorded time vector.
        voltage : bool
            Include the recorded somatic voltage vector.

        Returns
        -------
        list of numpy.ndarray
            ``[t, v]`` with both flags set (the default), in that order;
            entries whose flag is false are omitted.
        """
        vectorlist = []
        if time:
            vectorlist.append(np.array(self._time))
        if voltage:
            vectorlist.append(np.array(self._voltage))
        return vectorlist
Code (ring.py):
"""
ring.py
A class defining a network of N cells, where N is the number of
process currently called in a parallel environment. Cell N will be
connected to cell with gid =0.
"""
# this will initialize MPI
from morphology import DummyCell as Cell
from neuron import h
class Ring(object):
    """A ring network with one cell per parallel process.

    Each host creates one cell (gid taken from its MPI rank by default).
    The cell with gid i receives input from gid i-1, and gid 0 receives
    input from the last gid, which closes the ring. The cell with gid 0
    is also driven by a NetStim that emits one event at t = 20 ms.
    """

    def __init__(self):
        """Create the local cell, register its gid, and wire the ring."""
        # the single cell owned by this host
        self.cell = Cell()
        # NetCons that target the cell on this host
        self.netcon = []
        # external stimulator, created only on the host owning gid 0
        self.spk_generator = None
        self.nc_generator = None
        if self.cell.gid == 0:
            self.spk_generator = h.NetStim()
            self.spk_generator.number = 1
            self.spk_generator.start = 20.
            self.nc_generator = h.NetCon(self.spk_generator, self.cell.syn)
            self.nc_generator.delay = 1.
            self.nc_generator.weight[0] = 0.04
            self.nc_generator.threshold = 10.

        self.pc = h.ParallelContext()
        # declare that this gid lives on this host
        self.pc.set_gid2node(self.cell.gid, int(self.pc.id()))
        # register a target-less NetCon (spike detector) as the spike
        # source for this gid
        nc = self.cell.connect2target(None)
        self.pc.cell(self.cell.gid, nc)

        # ** Connection **
        # gid_connect must be called on the host that owns the *target*.
        # Every host owns exactly one cell, so the target here is always
        # self.cell and the source is the previous gid in the ring
        # ((0 - 1) % ncells == ncells - 1 closes the ring).
        ncells = int(self.pc.nhost())
        srcgid = (self.cell.gid - 1) % ncells
        netcon = self.pc.gid_connect(srcgid, self.cell.syn)
        netcon.weight[0] = 0.04
        netcon.delay = 1.
        self.netcon.append(netcon)

    def get_cell(self, cell_gid=0):
        """Return the cell with gid *cell_gid* if this host owns it, else None.

        The local Python cell object is returned directly, so the caller
        always gets the full cell (e.g. with ``get_vectorlist``), whatever
        ``pc.gid2cell`` returns for Python-defined cells.
        """
        if cell_gid == self.cell.gid and self.pc.gid_exists(cell_gid):
            return self.cell
        return None
Code (test_simulate.py):
""" test_simulate.py
execute with mpiexec -np 5 python test_simulate.py
""""
from ring import Ring
h.load_file('stdrun.hoc')
pc = h.ParallelContext()
myring = Ring()
# ** Simulation
pc.set_maxstep(10)
h.stdinit()
h.dt = 0.025
pc.psolve(120)
import numpy as np
# read voltage
cell = myring.get_cell(0)
if cell is not None:
t,v = cell.get_vectorlist()
import matplotlib.pyplot as plt
plt.plot(t,v)
plt.show()
pc.runworker()
pc.done()
Thanks in advance!