| | 101 | def _add_prefix(synapse_model, prefix, port_map): |
| | 102 | """ |
| | 103 | Add a prefix to all variables in `synapse_model`, except for variables with |
| | 104 | receive ports and specified in `port_map`. |
| | 105 | """ |
| | 106 | synapse_model.__cache__ = {} |
| | 107 | exclude = [] |
| | 108 | new_port_map = [] |
| | 109 | for name1, name2 in port_map: |
| | 110 | if synapse_model.ports_map[name2].mode == 'recv': |
| | 111 | exclude.append(name2) |
| | 112 | new_port_map.append((name1, name2)) |
| | 113 | else: |
| | 114 | new_port_map.append((name1, prefix + '_' + name2)) |
| | 115 | synapse_model.add_prefix(prefix + '_', exclude=exclude) |
| | 116 | return new_port_map |
| | 117 | |
| | 118 | |
| | 119 | class _build_nineml_celltype(type): |
| | 120 | """ |
| | 121 | Metaclass for building NineMLCellType subclasses |
| | 122 | """ |
| | 123 | def __new__(cls, name, bases, dct): |
| | 124 | # join the neuron and synapse components into a single component |
| | 125 | combined_model = dct["neuron_model"] |
| | 126 | for label in dct["synapse_models"].keys(): |
| | 127 | port_map = dct["port_map"][label] |
| | 128 | port_map = _add_prefix(dct["synapse_models"][label], label, port_map) |
| | 129 | dct["weight_variables"][label] = label + "_" + dct["weight_variables"][label] |
| | 130 | combined_model = join(combined_model, |
| | 131 | dct["synapse_models"][label], |
| | 132 | port_map, |
| | 133 | name=name) |
| | 134 | dct["combined_model"] = combined_model |
| | 135 | # set class attributes required for a PyNN cell type class |
| | 136 | dct["default_parameters"] = dict((name, 1.0) |
| | 137 | for name in combined_model.parameters) |
| | 138 | dct["default_initial_values"] = dict((name, 0.0) |
| | 139 | for name in combined_model.state_variables) |
| | 140 | dct["synapse_types"] = dct["synapse_models"].keys() #really need an ordered dict |
| | 141 | dct["injectable"] = True # need to determine this. How?? |
| | 142 | dct["recordable"] = [port.name for port in combined_model.analog_ports] + ['spikes', 'regime'] |
| | 143 | dct["standard_receptor_type"] = (dct["synapse_types"] == ('excitatory', 'inhibitory')) |
| | 144 | dct["conductance_based"] = True # how to determine this?? |
| | 145 | dct["model_name"] = name |
| | 146 | logger.debug("Creating class '%s' with bases %s and dictionary %s" % (name, bases, dct)) |
| | 147 | # generate and compile NMODL code, then load the mechanism into NEUORN |
| | 148 | _compile_nmodl(combined_model, dct["weight_variables"]) # weight variables should really be stored within combined_model |
| | 149 | return type.__new__(cls, name, bases, dct) |
| | 150 | |
| | 151 | |
| | 152 | |
| | 162 | |
| | 163 | |
| | 164 | def join(c1, c2, port_map=[], name=None): |
| | 165 | """Create a NineML component by joining the two given components.""" |
| | 166 | logger.debug("Joining components %s and %s with port map %s" % (c1, c2, port_map)) |
| | 167 | logger.debug("New component will have name '%s'" % name) |
| | 168 | # combine bindings from c1 and c2 |
| | 169 | bindings = {} |
| | 170 | for b in chain(c1.bindings, c2.bindings): |
| | 171 | bindings[b.name] = b |
| | 172 | # combine ports (some will later be removed) |
| | 173 | all_ports = c1.ports_map.copy() |
| | 174 | all_ports.update(c2.ports_map) |
| | 175 | # event ports do not be passed to the constructor, as they are attached to transitions |
| | 176 | for port_name, port in all_ports.items(): |
| | 177 | if isinstance(port, nineml.EventPort): |
| | 178 | all_ports.pop(port_name) |
| | 179 | # connect ports. |
| | 180 | # currently, when ports are connected they disappear. It might be better to |
| | 181 | # explicitly keep the ports in the new component but mark them as connected |
| | 182 | for name1, name2 in port_map: |
| | 183 | assert name1 in c1.ports_map, "%s is not in %s" % (name1, c1.ports_map.keys()) |
| | 184 | assert name2 in c2.ports_map, "%s is not in %s" % (name2, c2.ports_map.keys()) |
| | 185 | |
| | 186 | port1 = c1.ports_map[name1] |
| | 187 | port2 = c2.ports_map[name2] |
| | 188 | assert port1.mode != port2.mode |
| | 189 | if port1.mode == 'send': |
| | 190 | send_port = port1 |
| | 191 | recv_port = port2 |
| | 192 | send_port_name = name1 |
| | 193 | recv_port_name = name2 |
| | 194 | else: |
| | 195 | send_port = port2 |
| | 196 | recv_port = port1 |
| | 197 | send_port_name = name2 |
| | 198 | recv_port_name = name1 |
| | 199 | # when connecting ports in which the send port has an expression, need |
| | 200 | # to create a binding for this expression in the new component |
| | 201 | if send_port.expr: |
| | 202 | func_args = c1.non_parameter_symbols.union(c2.non_parameter_symbols).intersection(send_port.expr.names) |
| | 203 | lhs = "%s(%s)" % (send_port_name, ",".join(func_args)) |
| | 204 | send_binding = nineml.Binding(lhs, send_port.expr.rhs) |
| | 205 | bindings[send_binding.name] = send_binding |
| | 206 | for eq in chain(c1.equations, c2.equations): |
| | 207 | if send_port_name in eq.names: |
| | 208 | eq.rhs = eq.rhs_name_transform({send_port_name: lhs}) |
| | 209 | if recv_port.mode == 'reduce': |
| | 210 | # need to retain reduce ports as they can be connected to in a future join |
| | 211 | if recv_port_name in bindings: |
| | 212 | # this reduce port has already been connected to, so combine using its reduce_op |
| | 213 | reduce_binding = bindings[recv_port_name] |
| | 214 | func_args = func_args.union(reduce_binding.args) |
| | 215 | lhs = "%s(%s)" % (recv_port_name, ",".join(func_args)) |
| | 216 | rhs = recv_port.reduce_op.join([reduce_binding.rhs, send_binding.lhs]) |
| | 217 | else: |
| | 218 | # this is the first time this reduce port has been connected to |
| | 219 | lhs = "%s(%s)" % (recv_port_name, ",".join(func_args)) |
| | 220 | rhs = send_binding.lhs |
| | 221 | bindings[recv_port_name] = nineml.Binding(lhs, rhs) |
| | 222 | recv_port.connected = True |
| | 223 | else: |
| | 224 | all_ports.pop(name1) |
| | 225 | else: |
| | 226 | if recv_port.mode == 'reduce': |
| | 227 | raise NotImplementedError |
| | 228 | else: |
| | 229 | all_ports.pop(name1) |
| | 230 | |
| | 231 | if name1 != name2: |
| | 232 | #c2.substitute(name2, name1) # need to implement this. Currently this all only works if name1 == name2 |
| | 233 | # probably needs to happen sooner in the function |
| | 234 | all_ports.pop(name2) |
| | 235 | |
| | 236 | # where parameters have become bindings due to connecting ports, replace |
| | 237 | # bare names with function calls in the equations |
| | 238 | for bname, binding in bindings.items(): |
| | 239 | for eq in chain(c1.equations, c2.equations): |
| | 240 | if bname in eq.names: |
| | 241 | print "#### replacing %s by %s" % (bname, binding.lhs) |
| | 242 | pattern = re.compile(r'%s(\([\w\, ]*\))?' % bname) |
| | 243 | m = pattern.search(eq.rhs) |
| | 244 | if m: |
| | 245 | eq.rhs = pattern.sub(binding.lhs, eq.rhs) |
| | 246 | else: |
| | 247 | eq.rhs = eq.rhs_name_transform({bname: binding.lhs}) |
| | 248 | |
| | 249 | # create new regimes from all possible combinations of the regimes from the |
| | 250 | # two components |
| | 251 | regime_map = {} |
| | 252 | for r1 in c1.regimes: |
| | 253 | regime_map[r1.name] = {} |
| | 254 | for r2 in c2.regimes: |
| | 255 | if r1.name == r2.name: |
| | 256 | new_name = r1.name |
| | 257 | else: |
| | 258 | new_name = "%s_AND_%s" % (r1.name, r2.name) |
| | 259 | kwargs = {'name': new_name} |
| | 260 | new_regime = nineml.Regime(*r1.nodes.union(r2.nodes), **kwargs) |
| | 261 | regime_map[r1.name][r2.name] = new_regime |
| | 262 | # create transitions between all the new regimes |
| | 263 | transitions = [] |
| | 264 | for r1 in c1.regimes: |
| | 265 | for r2 in c2.regimes: |
| | 266 | for t in r1.transitions: |
| | 267 | new_transition = nineml.Transition(*t.nodes, |
| | 268 | from_=regime_map[r1.name][r2.name], |
| | 269 | to=regime_map[t.to.name][r2.name], |
| | 270 | condition=t.condition) |
| | 271 | transitions.append(new_transition) |
| | 272 | for t in r2.transitions: |
| | 273 | new_transition = nineml.Transition(*t.nodes, |
| | 274 | from_=regime_map[r1.name][r2.name], |
| | 275 | to=regime_map[r1.name][t.to.name], |
| | 276 | condition=t.condition) |
| | 277 | transitions.append(new_transition) |
| | 278 | |
| | 279 | regimes = [] |
| | 280 | for d in regime_map.values(): |
| | 281 | regimes.extend(d.values()) |
| | 282 | name = name or "%s__%s" % (c1.name, c2.name) |
| | 283 | return nineml.Component(name, |
| | 284 | regimes=regimes, |
| | 285 | transitions=transitions, |
| | 286 | ports=all_ports.values(), |
| | 287 | bindings=bindings.values()) |