# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max-Planck-Institute for Medical Research, Heidelberg, Germany
# Authors: Sven Dorkenwald, Philipp Schubert, Joergen Kornfeld
import copy
import gc
import numpy as np
import re
import time
import os
from multiprocessing import Pool, Manager, cpu_count
from numpy import array as arr
from scipy import spatial
from sys import stdout
import cPickle as pickle
from scipy.sparse.csgraph import csgraph_from_dense
from sklearn.externals import joblib
from knossos_utils.knossosdataset import KnossosDataset
from multi_proc.multi_proc_main import start_multiprocess
from knossos_utils import skeleton_utils as su
from processing.axoness import predict_axoness_from_nodes
from processing.mapper import feature_valid_syns, calc_syn_dict
from processing.synapticity import parse_synfeature_from_node
from utils.datahandler import get_filepaths_from_dir, \
load_ordered_mapped_skeleton, get_paths_of_skelID
from utils.datahandler import write_obj2pkl, load_pkl2obj
from knossos_utils.skeleton import Skeleton, SkeletonAnnotation, SkeletonNode


def write_summaries(wd):
    """Write information about contact sites and synapses, axoness and
    connectivity.

    Parameters
    ----------
    wd : str
        Path to working directory of SyConn
    """
cs_dir = wd + '/contactsites/'
cs_nodes, cs_feats = collect_contact_sites(cs_dir, only_sj=True)
write_cs_summary(cs_nodes, cs_feats, cs_dir)
cs_nodes, cs_feats = collect_contact_sites(cs_dir, only_sj=False)
features, axoness_info, syn_pred = feature_valid_syns(cs_dir, only_sj=True,
only_syn=False)
features, axoness_info, pre_dict, all_post_ids, valid_syn_array, ax_dict =\
calc_syn_dict(features[syn_pred], axoness_info[syn_pred], get_all=True)
write_obj2pkl(pre_dict, cs_dir + 'pre_dict.pkl')
write_obj2pkl(ax_dict, cs_dir + 'axoness_dict.pkl')
write_cs_summary(cs_nodes, cs_feats, cs_dir, supp='_all', syn_only=False)
features, axoness_info, _ = feature_valid_syns(cs_dir, only_sj=False,
only_syn=False,
all_contacts=True)
features, axoness_info, pre_dict, all_post_ids, valid_syn_array, ax_dict =\
calc_syn_dict(features, axoness_info, get_all=True)
write_obj2pkl(pre_dict, cs_dir + 'pre_dict_all.pkl')
write_obj2pkl(ax_dict, cs_dir + 'axoness_dict_all.pkl')
gc.collect()
write_property_dict(cs_dir)
conn_dict_wrapper(wd, all=False)
conn_dict_wrapper(wd, all=True)
print "Found %d contact sites containing %d" \
" synapses." % (len(cs_feats), np.sum(syn_pred))


def write_cs_summary(cs_nodes, cs_feats, cs_dir, supp='', syn_only=True):
    """Writes contact site summary of all contact sites without sampling.

    Parameters
    ----------
    cs_nodes : list of SkeletonNodes
        contact site nodes
    cs_feats : np.array
        synapse features of contact sites
    cs_dir : str
        path to contact site directory
    supp : str
        supplement for resulting file
    syn_only : bool
        if True, only synapses are saved
    """
dummy_skel = Skeleton()
dummy_anno = SkeletonAnnotation()
dummy_anno.setComment('CS Summary')
if len(cs_nodes) == 0:
write_obj2pkl({}, cs_dir + 'cs_dict%s.pkl' % supp)
dummy_skel.add_annotation(dummy_anno)
fname = cs_dir + 'cs_summary%s.k.zip' % supp
dummy_skel.to_kzip(fname)
return
clf_path = cs_dir + '/../models/rf_synapses/rfc_syn.pkl'
# print "\nUsing %s for synapse prediction." % clf_path
rfc_syn = joblib.load(clf_path)
probas = rfc_syn.predict_proba(cs_feats)
preds = rfc_syn.predict(cs_feats)
cnt = 0
cs_dict = {}
for syn_nodes in cs_nodes:
for node in syn_nodes:
if 'center' in node.getComment():
proba = probas[cnt]
pred = preds[cnt]
node.data['syn_proba'] = proba[1]
node.data['syn_pred'] = pred
feat = node.data['syn_feat']
assert np.array_equal(feat, cs_feats[cnt]), "Features unequal!"
cs_name = node.data['cs_name']
for dummy_node in syn_nodes:
dummy_anno.addNode(dummy_node)
cnt += 1
if syn_only and 'sj' not in node.getComment():
continue
if 'sj' in node.getComment():
overlap_vx = np.load(cs_dir + '/overlap_vx/' +
cs_name + 'ol_vx.npy')
else:
overlap_vx = np.zeros((0, ))
cs = {}
cs['overlap_vx'] = overlap_vx
cs['syn_pred'] = pred
cs['syn_proba'] = proba[1]
cs['center_coord'] = node.getCoordinate()
cs['syn_feats'] = feat
cs['cs_dist'] = float(node.data['cs_dist'])
cs['mean_cs_area'] = float(node.data['mean_cs_area'])
cs['overlap_area'] = float(node.data['overlap_area'])
cs['overlap'] = float(node.data['overlap'])
cs['abs_ol'] = float(node.data['abs_ol'])
cs['overlap_cs'] = float(node.data['overlap_cs'])
cs['cs_name'] = node.data['cs_name']
cs_dict[cs_name] = cs
dummy_skel.add_annotation(dummy_anno)
fname = cs_dir + 'cs_summary%s.k.zip' % supp
dummy_skel.to_kzip(fname)
write_obj2pkl(cs_dict, cs_dir + 'cs_dict%s.pkl' % supp)
# print "Saved CS summary at %s." % fname


def calc_cs_node(args):
    """Helper function. Calculates three nodes for a given contact site
    annotation: the contact site center and the two nearest skeleton nodes.
    The skeleton nodes contain the properties axoness, radius and skel_ID;
    the center node contains all of the above.

    Parameters
    ----------
    args : tuple
        Filepath to the AnnotationObject containing one contact site and a
        queue used to count the number of processed contact sites

    Returns
    -------
    list of SkeletonNodes
        Three nodes: nodes of adjacent skeletons and center node of CS
    """
cs_path, q = args
anno = su.loadj0126NML(cs_path)[0]
ids, axoness = predict_axoness_from_nodes(anno)
syn_nodes = []
center_node = None
for node in anno.getNodes():
n_comment = node.getComment()
if 'center' in n_comment:
try:
feat = parse_synfeature_from_node(node)
except KeyError:
                feat = re.findall('[\d.e+-]+', node.data['syn_feat'])
feat = arr([float(f) for f in feat])
center_node = copy.copy(node)
center_node.data['syn_feat'] = feat
cs_ix = re.findall('cs(\d+)', anno.getComment())[0]
add_comment = ''
if 'vc' in anno.getComment():
add_comment += '_vc'
if 'sj' in anno.getComment():
add_comment += '_sj'
center_node.setComment('cs%s%s_skel_%s_%s' % (cs_ix, add_comment,
str(ids[0]), str(ids[1])))
center_node_comment = center_node.getComment()
center_node.data['vc_size'] = None
center_node.data['sj_size'] = None
for node in anno.getNodes():
n_comment = node.getComment()
if 'skelnode_area' in n_comment:
skel_node = copy.copy(node)
try:
skel_id = int(re.findall('(\d+)_skelnode',
skel_node.getComment())[0])
except IndexError:
skel_id = int(re.findall('syn(\d+)', skel_node.getComment())[0])
assert skel_id in ids, 'Wrong skeleton ID during axoness parsing.'
skel_node.data['skelID'] = skel_id
skel_node.setComment(center_node_comment+'_skelnode'+str(skel_id))
syn_nodes.append(skel_node)
if '_vc-' in n_comment:
center_node.data['vc_size'] = node.data['radius']
# if '_sj-' in n_comment:
# center_node.data['sj_size'] = node.data['radius']
center_node.setComment(center_node_comment+'_center')
syn_nodes.append(center_node)
    assert len(syn_nodes) == 3, 'Number of synapse nodes differs from three.'
q.put(1)
return syn_nodes
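
# ``calc_cs_node`` is designed as a worker for a multiprocessing pool; a
# minimal driving sketch (directory and pool size are assumptions):
#   m = Manager()
#   q = m.Queue()
#   cs_paths = get_filepaths_from_dir(cs_dir + 'cs_sj/', ending='nml')
#   pool = Pool(cpu_count())
#   cs_nodes = pool.map(calc_cs_node, [(p, q) for p in cs_paths])
#   pool.close()
#   pool.join()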


def get_spine_summary(syn_nodes, cs_dir):
    """Write out a dictionary of spine dictionaries, keyed by contact site
    name. Each spine dictionary contains the three values 'head_diameter',
    'branch_dist' and 'head_proba'.

    Parameters
    ----------
    syn_nodes : list of SkeletonNodes
        nodes of contact sites (two skeleton nodes and one center node each)
    cs_dir : str
        path to contact site data
    """
spine_summary = {}
for nodes in syn_nodes:
skel1_node, skel2_node, center_node = nodes
spiness1 = int(skel1_node.data['spiness_pred'])
spiness2 = int(skel2_node.data['spiness_pred'])
cs_name = center_node.getComment()
if spiness1 == 1:
spine_node = skel1_node
spine = get_spine_dict(spine_node)
spine_summary[cs_name] = spine
if spiness2 == 1:
spine_node = skel2_node
spine = get_spine_dict(spine_node)
spine_summary[cs_name] = spine
write_obj2pkl(spine_summary, cs_dir+'/spine_summary.pkl')


def get_spine_dict(spine_node, max_dist=100):
    """Gather spine information from a spine end node.

    Parameters
    ----------
    spine_node : SkeletonNode
        Node containing spine information
    max_dist : float
        Maximum distance from spine head to endpoint in nm

    Returns
    -------
    dict or None
        Spine properties 'head_diameter', 'branch_dist' and 'head_proba';
        None if the end node is farther than max_dist from the endpoint
    """
spine = {}
endpoint_dist = float(spine_node.data['end_dist'])
if endpoint_dist > max_dist:
return None
spine['head_diameter'] = float(spine_node.data['head_diameter'])
spine['branch_dist'] = float(spine_node.data['branch_dist'])
spine['head_proba'] = float(spine_node.data['spiness_proba1'])
return spine
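
# Example (values illustrative): for a spine end node with
# ``end_dist`` <= ``max_dist`` this returns
#   {'head_diameter': 0.5, 'branch_dist': 1250.0, 'head_proba': 0.91}
# and None otherwise.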


def update_axoness_dict(cs_dir, syn_only=True):
    """Update axoness dictionary with current data pulled from the
    'neurons/' folder.

    Parameters
    ----------
    cs_dir : str
        path to contact site data
    syn_only : bool
        if True, only synapses are used

    Returns
    -------
    dict
        updated axoness dictionary
    """
if syn_only:
dict_path = cs_dir + 'cs_dict.pkl'
else:
dict_path = cs_dir + 'cs_dict_all.pkl'
cs_dict = load_pkl2obj(dict_path)
cs_keys = cs_dict.keys()
params = []
for k in cs_keys:
if syn_only and not cs_dict[k]['syn_pred']:
continue
        # new_names = convert_to_standard_cs_name(k)
        # cs_dict keys look like 'skel_<id1>_<id2>_cs<nb>...'; the axoness
        # dict is keyed by '<cs_nb>_<id1>_<id2>'
        skels = re.findall('[\d]+', k)
        ax_key = skels[2] + '_' + skels[0] + '_' + skels[1]
new_names = [k]
param = None
for var in new_names:
try:
center_coord = cs_dict[var]['center_coord']
param = (center_coord, ax_key)
break
except KeyError:
continue
if param is None:
print "Didnt find CS of key %s" % k
else:
params.append(param)
new_vals = start_multiprocess(update_single_cs_axoness, params, debug=False)
new_ax_dict = {}
for k, val in new_vals:
new_ax_dict[k] = val
return new_ax_dict


def update_single_cs_axoness(params):
    """Get axoness of the tracings adjacent to a contact site by finding
    the tracing node of each partner nearest to its representative
    coordinate.

    Parameters
    ----------
    params : tuple (numpy.array, str)
        Representative coordinate and contact site key

    Returns
    -------
    str, dict
        contact site key, dictionary of skeleton axoness
    """
center_coord = params[0]
key = params[1]
cs_nb, skel1, skel2 = re.findall('(\d+)_(\d+)_(\d+)', key)[0]
skel1_path, skel2_path = get_paths_of_skelID([skel1, skel2],
traced_skel_dir='/lustre/pschuber/mapped_soma_tracings/nml_obj/')
skel1_anno = load_ordered_mapped_skeleton(skel1_path)[0]
skel2_anno = load_ordered_mapped_skeleton(skel2_path)[0]
a_nodes = [node for node in skel1_anno.getNodes()]
b_nodes = [node for node in skel2_anno.getNodes()]
a_coords = arr([node.getCoordinate() for node in a_nodes])\
* skel1_anno.scaling
b_coords = arr([node.getCoordinate() for node in b_nodes])\
* skel2_anno.scaling
a_tree = spatial.cKDTree(a_coords)
b_tree = spatial.cKDTree(b_coords)
center_coord *= np.array(skel1_anno.scaling)
_, a_near = a_tree.query(center_coord, 1)
_, b_near = b_tree.query(center_coord, 1)
axoness_entry = {skel1: a_nodes[a_near].data["axoness_pred"],
skel2: b_nodes[b_near].data["axoness_pred"]}
return key, axoness_entry
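
# Example: for key '2_123_456' (contact site 2 between skeletons 123 and
# 456) the returned entry maps each skeleton ID to the axoness prediction
# of its tracing node nearest to the contact site, e.g.
#   ('2_123_456', {'123': '1', '456': '0'})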


def convert_to_standard_cs_name(name):
    """Convert to all possible contact site key combinations.

    Parameters
    ----------
    name : str
        Contact site key of the form '<cs_nb>_<skel1>_<skel2>'

    Returns
    -------
    list of str
        possible key variations
    """
cs_nb, skel1, skel2 = re.findall('(\d+)_(\d+)_(\d+)', name)[0]
new_name1 = 'skel_%s_%s_cs%s_sj' % (skel1, skel2, cs_nb)
new_name2 = 'skel_%s_%s_cs%s_vc_sj' % (skel1, skel2, cs_nb)
new_name3 = 'skel_%s_%s_cs%s_sj' % (skel2, skel1, cs_nb)
new_name4 = 'skel_%s_%s_cs%s_vc_sj' % (skel2, skel1, cs_nb)
new_name5 = 'skel_%s_%s_cs%s' % (skel2, skel1, cs_nb)
new_name6 = 'skel_%s_%s_cs%s' % (skel1, skel2, cs_nb)
return [new_name1, new_name2, new_name3, new_name4, new_name5, new_name6]
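
# Example:
#   convert_to_standard_cs_name('5_12_34')
#   -> ['skel_12_34_cs5_sj', 'skel_12_34_cs5_vc_sj', 'skel_34_12_cs5_sj',
#       'skel_34_12_cs5_vc_sj', 'skel_34_12_cs5', 'skel_12_34_cs5']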


def update_property_dict(cs_dir):
    """Update property (axoness and spiness) dictionary.

    Parameters
    ----------
    cs_dir : str
        path to contact site data

    Returns
    -------
    dict
        updated property dictionary
    """
ax_dict = load_pkl2obj(cs_dir + 'axoness_dict.pkl')
cs_dict = load_pkl2obj(cs_dir + 'cs_dict.pkl')
ax_keys = ax_dict.keys()
params = []
for k in ax_keys:
new_names = convert_to_standard_cs_name(k)
param = None
for var in new_names:
try:
center_coord = cs_dict[var]['center_coord']
param = (center_coord, k, cs_dir)
break
except KeyError:
continue
if param is None:
print "Didnt find CS of key %s" % k
else:
params.append(param)
new_vals = start_multiprocess(update_single_cs_properties, params,
debug=False)
new_property_dict = {}
for k, val in new_vals:
new_property_dict[k] = val
return new_property_dict


def update_single_cs_properties(params):
    """Helper function of update_property_dict.
    """
center_coord = params[0]
key = params[1]
cs_dir = params[2]
cs_nb, skel1, skel2 = re.findall('(\d+)_(\d+)_(\d+)', key)[0]
skel1_path, skel2_path = get_paths_of_skelID([skel1, skel2],
traced_skel_dir=cs_dir + '/../neurons/')
skel1_anno = load_ordered_mapped_skeleton(skel1_path)[0]
skel2_anno = load_ordered_mapped_skeleton(skel2_path)[0]
a_nodes = [node for node in skel1_anno.getNodes()]
b_nodes = [node for node in skel2_anno.getNodes()]
a_coords = arr([node.getCoordinate() for node in a_nodes])\
* skel1_anno.scaling
b_coords = arr([node.getCoordinate() for node in b_nodes])\
* skel2_anno.scaling
a_tree = spatial.cKDTree(a_coords)
b_tree = spatial.cKDTree(b_coords)
_, a_near = a_tree.query(center_coord, 1)
_, b_near = b_tree.query(center_coord, 1)
property_entry = {skel1: {'spiness': int(a_nodes[a_near].data["spiness_pred"]),
'axoness': int(a_nodes[a_near].data["axoness_pred"])},
skel2: {'spiness': int(b_nodes[b_near].data["spiness_pred"]),
'axoness': int(b_nodes[b_near].data["axoness_pred"])}}
return key, property_entry


def write_property_dict(cs_dir):
    """Write property dictionary to working directory.

    Parameters
    ----------
    cs_dir : str
        path to contact site data
    """
new_property_dict = update_property_dict(cs_dir)
write_obj2pkl(new_property_dict, cs_dir + 'property_dict.pkl')


def write_axoness_dicts(cs_dir):
    """Write axoness dictionaries to working directory.

    Parameters
    ----------
    cs_dir : str
        path to contact site data
    """
new_ax_all = update_axoness_dict(cs_dir, syn_only=False)
write_obj2pkl(new_ax_all, cs_dir + '/axoness_dict_all.pkl')
new_ax = update_axoness_dict(cs_dir, syn_only=True)
write_obj2pkl(new_ax, cs_dir + '/axoness_dict.pkl')


def get_number_cs_details(cs_path):
    """Prints the number of synapse candidates and contact sites.

    Parameters
    ----------
    cs_path : str
        path to contact site data
    """
sj_samples = [path for path in
get_filepaths_from_dir(cs_path+'cs_sj/', ending='nml')]
sj_vc_samples = [path for path in
get_filepaths_from_dir(cs_path+'cs_vc_sj/', ending='nml')]
vc_samples = [path for path in
get_filepaths_from_dir(cs_path+'cs_vc/', ending='nml')]
cs_samples = [path for path in
get_filepaths_from_dir(cs_path+'cs/', ending='nml')]
cs_only = len(vc_samples) + len(cs_samples)
syn_only = len(sj_samples)+len(sj_vc_samples)
print "Found %d syn-candidates and %d contact sites." % (syn_only,
cs_only+syn_only)


def conn_dict_wrapper(wd, all=False):
    """Wrapper function to write connectivity dictionary to working
    directory.

    Parameters
    ----------
    wd : str
        Path to working directory of SyConn
    all : bool
        if True, use all contact sites instead of predicted synapses only
    """
if all:
suffix = "_all"
else:
suffix = ""
synapse_matrix(wd, suffix=suffix)
syn_type_majority_vote(wd, suffix=suffix)
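
# Illustrative usage (hypothetical path): write both the synapse-only and
# the all-contact-sites connectivity files:
#   conn_dict_wrapper('/path/to/syconn_wd/', all=False)  # -> *.pkl
#   conn_dict_wrapper('/path/to/syconn_wd/', all=True)   # -> *_all.pkl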


def synapse_matrix(wd, type_threshold=0.225, suffix="",
                   exclude_dendrodendro=True):
    """Computes the synapse connectivity matrix and per-pair connectivity
    dictionaries and writes them to the 'contactsites/' folder.

    Parameters
    ----------
    wd : str
        Path to working directory of SyConn
    type_threshold : float
        fraction of symmetric voxels above which a synapse is classified
        as symmetric
    suffix : str
        suffix of the source and result files
    exclude_dendrodendro : bool
        if True, count only synapses with an axonal presynaptic partner
    """
save_folder = wd + '/contactsites/'
if not os.path.exists(save_folder):
os.makedirs(save_folder)
with open(save_folder + "/cs_dict%s.pkl" % suffix, "r") as f:
cs_dict = pickle.load(f)
with open(save_folder + "/axoness_dict%s.pkl" % suffix, "r") as f:
ax_dict = pickle.load(f)
with open(wd + '/neurons/celltype_pred_dict.pkl', "r") as f:
cell_type_dict = pickle.load(f)
kd_asym = KnossosDataset()
kd_asym.initialize_from_knossos_path(wd + '/knossosdatasets/asymmetric/')
kd_sym = KnossosDataset()
kd_sym.initialize_from_knossos_path(wd + '/knossosdatasets/symmetric/')
ids = []
cs_keys = cs_dict.keys()
for cs_key in cs_keys:
skels = np.array(re.findall('[\d]+', cs_key)[:2], dtype=np.int)
ids.append(skels)
ids = np.unique(ids)
id_mapper = {}
for ii in range(len(ids)):
id_mapper[ids[ii]] = ii
axo_axo_dict = {}
size_dict = {}
coord_dict = {}
syn_type_dict = {}
partner_cell_type_dict = {}
partner_ax_dict = {}
cs_key_dict = {}
phil_dict = {}
    for first_key in ids:
        for second_key in ids:
            pair_key = str(first_key) + "_" + str(second_key)
            size_dict[pair_key] = []
            coord_dict[pair_key] = []
            syn_type_dict[pair_key] = []
            partner_cell_type_dict[pair_key] = []
            partner_ax_dict[pair_key] = []
            cs_key_dict[pair_key] = []
            phil_dict[pair_key] = {"total_size_area": 0,
                                   "sizes_area": [],
                                   "syn_types_prob": [],
                                   "syn_types_pred": [],
                                   "coords": [],
                                   "partner_axoness": [],
                                   "partner_cell_type": [],
                                   "cs_area": [],
                                   "total_cs_area": 0,
                                   "syn_pred": []}
fails = []
syn_matrix = np.zeros([len(ids), len(ids)], dtype=np.int)
    for cs_key in cs_keys:
        skels = re.findall('[\d]+', cs_key)
        # the axoness dict is keyed by '<cs_nb>_<id1>_<id2>'
        ax_key = skels[2] + '_' + skels[0] + '_' + skels[1]
        ax = ax_dict[ax_key]
        this_keys = ax.keys()
        if cs_dict[cs_key]['syn_pred']:
            # overlap voxels are stored in nm; convert to voxel coordinates
            # (dataset scaling of 9 x 9 x 20 nm)
            overlap_vx = cs_dict[cs_key]['overlap_vx'] / np.array([9, 9, 20])
            sym_values = kd_sym.from_raw_cubes_to_list(overlap_vx)
            asym_values = kd_asym.from_raw_cubes_to_list(overlap_vx)
            # fraction of symmetric voxels among all classified voxels
            this_type = float(np.sum(sym_values)) / (np.sum(sym_values) +
                                                     np.sum(asym_values))
            overlap = cs_dict[cs_key]['overlap_area'] / 2
else:
this_type = 0
overlap = 0
        # treat each partner in turn as the presynaptic site; skip pairs
        # without an axonal presynapse unless dendro-dendritic contacts are
        # explicitly included
        for ii in range(2):
            if int(ax[this_keys[ii]]) == 1 or not exclude_dendrodendro:
                if int(ax[this_keys[(ii+1)%2]]) == 1:
if this_keys[ii] in axo_axo_dict:
axo_axo_dict[this_keys[ii]].append(this_keys[(ii+1)%2])
else:
axo_axo_dict[this_keys[ii]] = [this_keys[(ii+1)%2]]
syn_matrix[id_mapper[int(this_keys[ii])], id_mapper[int(this_keys[(ii+1)%2])]] += 1
dict_key = this_keys[ii] + "_" + this_keys[(ii+1)%2]
cs_key_dict[dict_key].append(cs_key)
coord_dict[dict_key].append(cs_dict[cs_key]['center_coord'])
syn_type_dict[dict_key].append(this_type)
size_dict[dict_key].append(overlap)
partner_ax_dict[dict_key].append(int(ax[this_keys[(ii+1)%2]]))
partner_cell_type_dict[dict_key].append(cell_type_dict[int(this_keys[(ii+1)%2])])
phil_dict[dict_key]["total_size_area"] += overlap
phil_dict[dict_key]["sizes_area"].append(overlap)
phil_dict[dict_key]["cs_area"].append(cs_dict[cs_key]['mean_cs_area']/2.e6)
phil_dict[dict_key]["total_cs_area"] += cs_dict[cs_key]['mean_cs_area']/2.e6
phil_dict[dict_key]["syn_types_prob"].append(this_type)
phil_dict[dict_key]["syn_types_pred"].append(int(this_type > type_threshold))
phil_dict[dict_key]["coords"].append(cs_dict[cs_key]['center_coord'])
phil_dict[dict_key]['partner_axoness'].append(int(ax[this_keys[(ii+1)%2]]))
phil_dict[dict_key]['partner_cell_type'].append(cell_type_dict[int(this_keys[(ii+1)%2])])
phil_dict[dict_key]['syn_pred'].append(cs_dict[cs_key]['syn_pred'])
if not exclude_dendrodendro:
suffix = "_no_exclusion" + suffix
np.save(save_folder+"/syn_matrix%s.npy" % suffix, syn_matrix)
with open(save_folder + "/id_mapper%s.pkl" % suffix, "w") as f:
pickle.dump(id_mapper, f)
with open(save_folder + "/size_dict%s.pkl" % suffix, "w") as f:
pickle.dump(size_dict, f)
with open(save_folder + "/syn_type_dict%s.pkl" % suffix, "w") as f:
pickle.dump(syn_type_dict, f)
with open(save_folder + "/coord_dict%s.pkl" % suffix, "w") as f:
pickle.dump(coord_dict, f)
with open(save_folder + "/cs_key_dict%s.pkl" % suffix, "w") as f:
pickle.dump(cs_key_dict, f)
with open(save_folder + "/partner_ax_dict%s.pkl" % suffix, "w") as f:
pickle.dump(partner_ax_dict, f)
with open(save_folder + "/partner_cell_type_dict%s.pkl" % suffix, "w") as f:
pickle.dump(partner_cell_type_dict, f)
with open(save_folder + "/connectivity_dict%s.pkl" % suffix, "w") as f:
pickle.dump(phil_dict, f)
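
# Sketch of reading the results back; pair keys are plain '<id1>_<id2>'
# strings and matrix indices come from ``id_mapper`` (IDs illustrative):
#   syn_matrix = np.load(save_folder + "/syn_matrix.npy")
#   with open(save_folder + "/id_mapper.pkl", "r") as f:
#       id_mapper = pickle.load(f)
#   n_syn = syn_matrix[id_mapper[123], id_mapper[456]]  # synapses 123 -> 456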


def syn_type_majority_vote(wd, suffix=""):
    """Size-weighted majority vote of synapse type per presynaptic cell.

    Parameters
    ----------
    wd : str
        Path to working directory of SyConn
    suffix : str
        suffix of the source and result files

    Returns
    -------
    int, int
        number of changed synapse type predictions, total synapse count
    """
save_folder = wd + '/contactsites/'
with open(save_folder + "/connectivity_dict%s.pkl" % suffix, "r") as f:
phil_dict = pickle.load(f)
with open(save_folder + "/cs_dict%s.pkl" % suffix, "r") as f:
cs_dict = pickle.load(f)
maj_type_dict = {}
ids = []
cs_keys = cs_dict.keys()
for cs_key in cs_keys:
skels = np.array(re.findall('[\d]+', cs_key)[:2], dtype=np.int)
ids.append(skels)
ids = np.unique(ids)
id_mapper = {}
for ii in range(len(ids)):
id_mapper[ids[ii]] = ii
types = []
sizes = []
for first_key in ids:
types.append([])
sizes.append([])
for second_key in ids:
key = str(first_key) + "_" + str(second_key)
types[-1].append(phil_dict[key]["syn_types_pred"])
sizes[-1].append(phil_dict[key]["sizes_area"])
phil_dict[key]["syn_types_pred_maj"] = None
maj_type_dict[key] = []
maj_types = []
for ii in range(len(ids)):
this_types = []
this_sizes = []
for jj in range(len(ids)):
this_types += types[ii][jj]
this_sizes += sizes[ii][jj]
if len(this_types) > 0:
# maj_types.append(np.argmax(np.bincount(this_types)))
sum0 = np.sum(np.array(this_sizes)[np.array(this_types)==0])
sum1 = np.sum(np.array(this_sizes)[np.array(this_types)==1])
maj_types.append(int(sum0 < sum1))
else:
maj_types.append(-1)
change_count = 0
syn_count = 0
for ii, first_key in enumerate(ids):
for second_key in ids:
key = str(first_key) + "_" + str(second_key)
if maj_types[ii] != -1:
phil_dict[key]["syn_types_pred_maj"] = [maj_types[ii]]*len(phil_dict[key]["syn_types_pred"])
maj_type_dict[key] = [maj_types[ii]]*len(phil_dict[key]["syn_types_pred"])
change_count += np.abs(np.sum(phil_dict[key]["syn_types_pred"])-np.sum(maj_type_dict[key]))
syn_count += len(maj_type_dict[key])
with open(save_folder + "/maj_syn_types%s.pkl" % suffix, "w") as f:
pickle.dump(maj_type_dict, f)
with open(save_folder + "/connectivity_dict%s.pkl" % suffix, "w") as f:
pickle.dump(phil_dict, f)
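
# Worked example of the size-weighted vote: if one presynaptic cell has
# synapses with predicted types [0, 1, 1] and overlap sizes [0.6, 0.1, 0.2],
# then sum0 = 0.6 > sum1 = 0.3, so all three synapses are relabeled as
# type 0 and change_count increases by |2 - 0| = 2.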


def create_synapse_skeleton(wd):
    """Build a skeleton containing all tracings plus one annotation per
    contact site with synapse prediction, probability and pre-/postsynaptic
    node IDs, and write it to 'contactsites/synapses_skel.k.zip'.

    Parameters
    ----------
    wd : str
        Path to working directory of SyConn
    """
cs_dir = wd + '/contactsites/'
cs_nodes, cs_feats = collect_contact_sites(cs_dir, only_sj=False)
clf_path = wd + '/models/rf_synapses/rfc_syn.pkl'
# print "\nUsing %s for synapse prediction." % clf_path
rfc_syn = joblib.load(clf_path)
syn_probas = rfc_syn.predict_proba(cs_feats)
synapse_skels = Skeleton()
for p in get_filepaths_from_dir(wd+'/tracings/'):
tracing = load_ordered_mapped_skeleton(p)[0]
synapse_skels.add_annotation(tracing)
synapse_skels.scaling = tracing.scaling
new_node_lst = [n for n in synapse_skels.getNodes()]
tracing_node_tree = spatial.cKDTree(np.array([n.getCoordinate() for n
in new_node_lst]))
for ii, syn_nodes in enumerate(cs_nodes):
skel_nodes = []
center_node = None
        assert len(syn_nodes) == 3, "Wrong number of synapse nodes: %d " \
                                    "instead of 3." % len(syn_nodes)
for n in syn_nodes:
if 'center' in n.getComment():
center_node = n
else:
skel_nodes.append(n)
new_syn_anno = SkeletonAnnotation()
new_syn_anno.scaling = synapse_skels.scaling
for k, v in center_node.data.iteritems():
if k in ["inVp", "node", "id", "inMag", "radius", "time", "x",
"y", "z", "edge", "comment", "content", "target"]:
continue
else:
new_syn_anno.data[k] = v
syn_pred = np.argmax(syn_probas[ii])
new_syn_anno.data['synapticCleft'] = bool(syn_pred)
        new_syn_anno.data['syn_proba'] = float(syn_probas[ii][1])
new_syn_anno.addNode(SkeletonNode().from_scratch(new_syn_anno,
center_node.getCoordinate()[0],
center_node.getCoordinate()[1],
center_node.getCoordinate()[2],
radius=center_node.data["radius"]))
new_syn_anno.data["axo-axonic"] = False
for n in skel_nodes:
n_in_anno = new_node_lst[tracing_node_tree.query_ball_point(
n.getCoordinate(), r=1)[0]]
            if int(n.data['axoness_pred']):
                # the first axonal partner becomes the presynapse; a second
                # axonal partner marks the synapse as axo-axonic
                if 'preSynapse' in new_syn_anno.data.keys():
new_syn_anno.data['postSynapse'] = n_in_anno.getID()
else:
new_syn_anno.data['preSynapse'] = n_in_anno.getID()
new_syn_anno.data["axo-axonic"] = True
else:
new_syn_anno.data['postSynapse'] = n_in_anno.getID()
synapse_skels.add_annotation(new_syn_anno)
synapse_skels.to_kzip(wd + "/contactsites/synapses_skel.k.zip")
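
# Illustrative usage (hypothetical path; expects mapped tracings under
# ``wd + '/tracings/'`` and the synapse RFC under ``wd + '/models/'``):
#   create_synapse_skeleton('/path/to/syconn_wd/')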