Andrew Gillies and I have created a simple network simulation using the MPI features of NEURON 5.9. It models the highly biologically-relevant case of a network of squid axons connected by excitatory synapses (cough). We wrote it because we needed a script to benchmark the performance of clusters of machines running MPI.
We're posting it here in case it's of use to anyone.
All the best,
David.
Code: Select all
// Load NEURON's standard run system (stdinit() is used by doinit() below).
load_file("stdgui.hoc")

// ---- network size, seeding and run length ----
ncells = 100      // total number of cells in the network
rseed = 1027      // base seed for all random streams
tstop = 1000      // end of simulation (ms)

// ---- connectivity ----
global_maxcons = 20       // number of synapses created on each cell
global_weight = 0.004     // NetCon weight of every connection
global_delaymean = 13.0   // mean of the normal delay distribution
global_delayvar = 1.4     // variance of the normal delay distribution
global_delaymin = 0.2     // delays drawn below this value are redrawn
global_threshold = 10.0   // spike-detection threshold
begintemplate hhcell
public soma, syns, gid, spiketimes, idvec
objref syns, spiketimes, idvec
create soma

// hhcell: a single-compartment cell with Hodgkin-Huxley channels.
// $1 -- global id of the cell, stored in the public field gid.
proc init() {
  gid = $1

  // one 18.8 x 18.8 compartment with the hh mechanism inserted
  create soma
  soma.nseg = 1
  soma.L = 18.8
  soma.Ra = 123.0
  soma {
    diam = 18.8
    insert hh
  }

  spiketimes = new Vector()  // spike times recorded for this cell
  idvec = new Vector()       // companion id vector passed to spike_record
  syns = new List()          // synapses placed on this cell
}
endtemplate hhcell
objref cells, pc, nc, nclist, stims, nil

// Create the ParallelContext and find out how many hosts there are and
// which one this process is. With fewer than two hosts (no PVM/MPI, or a
// single process) the run is treated as one host with id 0.
pc = new ParallelContext()
nhost = pc.nhost
if (nhost >= 2) {
  myid = pc.id
} else {
  nhost = 1
  myid = 0
}
nwork = nhost

// Containers for objects owned by this host.
cells = new List()    // hhcell instances living on this host
nclist = new List()   // NetCons delivering spikes into those cells
stims = new List()    // IClamp stimulators attached to those cells

// Round-robin ownership: gid i is assigned to host (i mod nwork).
for i = 0, ncells-1 {
  pc.set_gid2node(i, i % nwork)
}
// create the cells
// Only gids that set_gid2node (above) assigned to this host are
// instantiated here; each one gets a spike-source NetCon registered with
// the ParallelContext and its spikes recorded into the cell's vectors.
objref tmpcell
for i = 0, ncells-1 {
  if (pc.gid_exists(i)==1) {
    tmpcell = new hhcell(i)
    // spike source for this gid: watches soma v(0.5), no target (nil)
    tmpcell.soma nc = new NetCon(&v(0.5),nil)
    // Set the detection threshold on the source NetCon itself. Otherwise it
    // is only set indirectly through whichever gid_connect'ed target NetCons
    // happen to be on the same host, which depends on the host count.
    nc.threshold = global_threshold
    pc.cell(i,nc,1)
    cells.append(tmpcell)
    pc.spike_record(i,tmpcell.spiketimes,tmpcell.idvec)
  }
}
// for i = 0, cells.count()-1 {
// print "cell ",cells.object(i).gid,"on id ",myid,"has gid status ",pc.gid_exists(cells.object(i).gid)
// }
// useful parallel functions

// Limit the maximum integration step to 5 ms through the ParallelContext.
// The value pc.set_maxstep() returns is kept only in a local and not used.
proc set_maxstep() { local mindelay
  mindelay = pc.set_maxstep(5)
}
// Initialize this process's part of the model with the standard run
// system's stdinit().
proc doinit() {
  stdinit()
}
// Initialize on every host: in a multi-host run the workers are first
// asked (via pc.context) to run doinit(); this process then runs it too.
proc pinit() {
  if (nwork <= 1) {
    doinit()
  } else {
    pc.context("doinit()\n")
    doinit()
  }
}
// Integrate this host's part of the model up to time $1 via pc.psolve.
proc psolve() { local tend
  tend = $1
  pc.psolve(tend)
}
// Advance all hosts to time $1: in a multi-host run the workers are told
// (via pc.context) to call psolve($1); this process then does the same.
proc pcontinue() { local tend
  tend = $1
  if (nwork > 1) pc.context("psolve", tend)
  psolve(tend)
}
// Complete parallel run: initialize every host, then integrate to tstop.
proc prun() {
  pinit()
  pcontinue(tstop)
}
// Post one bulletin-board message per cell on this host. Each message is
// packed as (gid, spike-time Vector) and posted under the key "spiketimes".
// The loop index is declared local so that calling this proc (directly or
// through pc.context) does not overwrite the global i.
proc post_spiketimes() { local i
  if (cells.count()>0) {
    for i = 0, cells.count()-1 {
      pc.pack(cells.object(i).gid, cells.object(i).spiketimes)
      pc.post("spiketimes")
    }
  }
}
// connect the cells together
// create a synapse on each cell that can communicate across cpus
objref stim,syn,randomcon,randomdel
// first setup stimulator cell(s):
// every cell on this host gets a 1 ms, 0.1 amplitude IClamp pulse at
// 100 + uniform(0,20). The jitter generator is seeded per gid rather than
// per host, so each cell's stimulus time is the same however the cells are
// spread over hosts. The seed is offset by ncells so that it never equals
// the per-gid seeds of the form rseed+gid.
if (cells.count()>0) {
  for i = 0, cells.count()-1 {
    randomdel = new Random(rseed + ncells + cells.object(i).gid)
    cells.object(i).soma stim = new IClamp(0.5)
    stim.del = 100 + randomdel.uniform(0,20)
    stim.dur = 1
    stim.amp = 0.1
    stims.append(stim)
  }
}
objref syn,randomcon,randomdel
// now setup random synapses
// Rules: every cell on this host receives exactly global_maxcons ExpSyn
// synapses. Each is driven by a randomly chosen presynaptic gid other than
// the cell's own; the same source may be chosen more than once.
//
// Both random generators are seeded from the postsynaptic gid, so a cell's
// wiring does not depend on how the cells are distributed over hosts.
// The delay generator gets its own seed (offset by 2*ncells). The two Random
// objects are otherwise constructed identically, so with equal seeds they
// would draw from the same underlying sequence and correlate each delay
// with the choice of source.
//
// With fewer than two cells no source is admissible (self-connections are
// rejected), so the wiring is skipped instead of looping forever.
if (cells.count()>0 && ncells>1) {
  // for each cell on this worker...
  for i = 0, cells.count()-1 {
    randomcon = new Random(rseed+cells.object(i).gid)
    wc = randomcon.uniform(0,ncells)
    randomdel = new Random(rseed+2*ncells+cells.object(i).gid)
    dv = randomdel.normal(global_delaymean,global_delayvar)
    // keep adding synapses until this cell has global_maxcons of them
    while (cells.object(i).syns.count()<global_maxcons) {
      // choose a cell to receive from
      wc = int(randomcon.repick())
      // make sure we don't connect to ourselves
      if (wc != cells.object(i).gid) {
        //print "connecting cell ",wc," with cell ",cells.object(i).gid
        cells.object(i).soma syn = new ExpSyn(0)
        cells.object(i).syns.append(syn)
        // connection from gid=wc
        nc = pc.gid_connect(wc,syn)
        nc.weight = global_weight
        // redraw until the delay is at least global_delaymin
        dv = randomdel.repick()
        while (dv<global_delaymin) {
          dv = randomdel.repick()
        }
        nc.delay = dv
        nc.threshold = global_threshold
        nclist.append(nc)
      }
    }
  }
}
// kick off all the workers...
// Per the ParallelContext documentation, worker processes remain inside
// runworker() servicing pc.context() requests until pc.done(); only the
// master continues past this call. Every step below therefore runs the
// action on the workers (pc.context) and then locally on the master.
pc.runworker()
pc.context("set_maxstep()\n")
set_maxstep()
// initialize all hosts and integrate to tstop
prun()
// every host posts one (gid, spike times) message per cell it owns
pc.context("post_spiketimes()\n")
post_spiketimes()
// Master: take one "spiketimes" message per cell (ncells in total), print
// each gid with its spike count, and accumulate the total spike count.
objref times
sum = 0
gid = 0
for i=0,ncells-1 {
times = new Vector()
pc.take("spiketimes")
pc.unpack(&gid,times)
print gid,times.size()
sum = sum+times.size()
}
// mean spikes per cell divided by run length in seconds (tstop/1000)
print "mean cell firing ", ((sum/ncells)/(tstop/1000)), "Hz"
// release the workers, then exit
pc.done()
quit()