Source code for syconn.processing.mapper

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max-Planck-Institute for Medical Research, Heidelberg, Germany
# Authors: Sven Dorkenwald, Philipp Schubert, Jörgen Kornfeld

import copy
import gc
import socket
import time
from multiprocessing import Pool, Manager
from scipy import sparse
from shutil import copyfile
from sys import stdout
import networkx as nx
from axoness import majority_vote, predict_axoness_from_nodes
from features import calc_prop_feat_dict
from learning_rfc import write_feat2csv
from spiness import assign_neck
from ..utils.datahandler import *
from ..utils.segmentationdataset import UltrastructuralDataset
from synapticity import parse_synfeature_from_node
from ..multi_proc.multi_proc_main import start_multiprocess
from knossos_utils import chunky
from knossos_utils.knossosdataset import KnossosDataset
from sklearn.externals import joblib
from knossos_utils.skeleton import Skeleton, SkeletonNode, \
    from_skeleton_to_mergelist

try:
    from syconn.ray_casting.ray_casting_radius import ray_casting_radius
    ray_cast_avail = True
except ImportError:
    print("ray_casting_radius-module not imported")
    ray_cast_avail = False


class SkeletonMapper(object):
    """Class to handle the mapping of cell objects (mitochondria, vesicle
    clouds, synaptic clefts) to tracings. Mapping parameters are saved as
    attributes.

    Attributes
    ----------
    soma : SkeletonAnnotation
        soma tracing
    old_anno : SkeletonAnnotation
        original tracing; the estimated cell radius is saved at each node
    anno : SkeletonAnnotation
        interpolated tracing skeleton used for the hull calculation
    mitos/vc/sj : UltrastructuralDataset
        datasets in which mapped cell objects are saved
    ix : int
        mapped skeleton id
    write_obj_voxel : bool
        write object voxels to the kzip as a binary file
    """
    def __init__(self, source, dh, ix=None, soma=None, context_range=6000):
        """
        Parameters
        ----------
        source : SkeletonAnnotation or str
            initial tracing object or path to a .k.zip file
        dh : DataHandler
            DataHandler object
        ix : int
            index of the tracing
        soma : SkeletonAnnotation
            additional soma tracing, if available
        context_range : int
            range for the feature extraction of spiness and axoness
        """
        self.context_range = context_range
        self.scaling = arr(dh.scaling, dtype=np.int)
        self._mem_path = dh.mem_path
        self._nb_cpus = dh.nb_cpus
        self._cset_path = dh.cs_path
        self._myelin_ds_path = dh.myelin_ds_path
        init_anno = SkeletonAnnotation()
        init_anno.scaling = [9, 9, 20]
        init_anno.appendComment('soma')
        self.soma = init_anno
        self.mitos = None
        self.vc = None
        self.sj = None
        if type(source) is str:
            self.ix = re.findall('[^/]+$', source)[0][:-4]
            self._path = source
            self.old_anno = load_ordered_mapped_skeleton(source)[0]
            # fixed: the original indexed the returned list with a tuple
            # ([0, 4]), which raises a TypeError; load again to get
            # independent copies of the skeleton and the soma annotation
            mapped_skel = load_ordered_mapped_skeleton(source)
            self.anno, self.soma = mapped_skel[0], mapped_skel[4]
            obj_dicts = load_objpkl_from_kzip(source)
            self.mitos = obj_dicts[0]
            self.vc = obj_dicts[1]
            self.sj = obj_dicts[2]
        elif isinstance(source, SkeletonAnnotation):
            self._path = None
            self.ix = ix
            self.old_anno = source
            self.anno = copy.deepcopy(source)
            if soma is not None:
                self.soma = soma
        else:
            raise RuntimeError('Datatype not understood in __init__ '
                               'of SkeletonMapper.')
        # init mapping parameters
        self.detect_outlier = True
        self.neighbor_radius = None
        self.nb_neighbors = None
        self.nb_rays = None
        self.nb_hull_vox = None
        self.nb_voting_neighbors = None
        self.annotation_method = 'hull'
        self.kd_radius = 1200
        self.thresh = 2.2
        self.filter_size = [0, 0, 0]
        self.write_obj_voxel = False
        self.obj_min_votes = {'mitos': 235, 'vc': 191, 'sj': 346}
        self.mapping_info = {'sj': {}, 'mitos': {}, 'vc': {}}
        self._cset = None
        # stores hull and radius estimation of each ray and node
        self._hull_coords = None
        self._hull_normals = None
        self._skel_radius = None
        if hasattr(self.old_anno, 'hull_coords'):
            self._hull_coords = self.old_anno.hull_coords
            self._hull_normals = self.old_anno.hull_normals
        # init skeleton nodes
        self._property_features = None
        self.property_feat_names = None
        self.anno.interpolate_nodes()
        if len(self.soma.getNodes()) != 0:
            self.merge_soma_tracing()
        self._create_nodes()

    @property
    def cset(self):
        if self._cset is None:
            self._cset = chunky.load_dataset(self._cset_path)
        return self._cset

    @property
    def hull_coords(self):
        """Scaled hull coordinates of the skeleton membrane

        Returns
        -------
        np.array
            coordinates of each hull point
        """
        if self._hull_coords is None:
            self.hull_sampling(thresh=self.thresh, detect_outlier=True,
                               nb_rays=20, nb_neighbors=20,
                               neighbor_radius=220, max_dist_mult=1.4)
        return self._hull_coords

    @property
    def hull_normals(self):
        """Normal of each hull point, pointing outwards

        Returns
        -------
        np.array
            normal vector of each hull point pointing outwards
        """
        if self._hull_normals is None:
            node_coords = arr(self.node_com)
            skel_tree = spatial.cKDTree(node_coords)
            hull_coords = self.hull_coords
            nearest_skel_nodes = arr(skel_tree.query(hull_coords, k=1)[1])
            nearest_coords = node_coords[nearest_skel_nodes]
            dir_vecs = hull_coords - nearest_coords
            hull_normals = dir_vecs * (1 / np.linalg.norm(
                dir_vecs, axis=1))[:, None]
            if len(self.soma.getNodes()) != 0:
                soma_nodes = self.soma.getNodes()
                soma_coords_pure = arr([node.getCoordinate_scaled()
                                        for node in soma_nodes])
                soma_node_ixs = arr(skel_tree.query(soma_coords_pure, k=1)[1])
                com_soma = np.mean(soma_coords_pure, axis=0)
                dist, nn_ixs = skel_tree.query(hull_coords, k=1)
                for ii, ix in enumerate(nn_ixs):
                    if ix in soma_node_ixs:
                        hull_normals[ii] = hull_coords[ii] - com_soma
            self._hull_normals = hull_normals
        return self._hull_normals

    @property
    def skel_radius(self):
        """Radius of the membrane at each skeleton node

        Returns
        -------
        np.array
            cell radius at self.nodes
        """
        if self._skel_radius is None:
            self.hull_sampling(thresh=self.thresh, detect_outlier=True,
                               nb_rays=20, nb_neighbors=20,
                               neighbor_radius=220, max_dist_mult=1.4)
        return self._skel_radius

    def _create_nodes(self):
        """Creates a sorted node list and corresponding ID- and coordinate
        lists. Enables fast access to node information in the same ordering.
        Setter for the node_com, node_ids and nodes attributes.
        """
        coords = []
        ids = []
        self.nodes = []
        graph = su.annotation_to_nx_graph(self.anno)
        for i, node in enumerate(nx.dfs_preorder_nodes(graph)):
            coords.append(node.getCoordinate() * self.scaling)
            ids.append(node.ID)
            # contains mapped objects
            node.objects = {'vc': [], 'mitos': [], 'sj': []}
            self.nodes.append(node)
        self.node_com = arr(coords, dtype=np.int)
        self.node_ids = ids
        self.anno.nodes = set(self.nodes)
    def merge_soma_tracing(self):
        # print "Merging soma (%d nodes) with original annotation." % \
        #       (len(self.soma.getNodes()))
        self.soma.interpolate_nodes(150)
        self.anno = su.merge_annotations(self.anno, self.soma)
    def annotate_objects(self, dh, radius=1200, method='hull', thresh=2.2,
                         filter_size=(0, 0, 0), nb_neighbors=20,
                         nb_hull_vox=500, neighbor_radius=220,
                         detect_outlier=True, nb_rays=20,
                         nb_voting_neighbors=100, max_dist_mult=1.4):
        """Creates self.object with annotated objects as
        UltrastructuralDataset, where object is in {mitos, vc, sj}

        Parameters
        ----------
        dh : DataHandler
            object containing the SegmentationDataObjects mitos, vc, sj
        radius : int
            radius in nm. A single integer if the radius is the same for all
            objects; otherwise a list of three integers with the ordering
            [mitos, vc, sj].
        method : str
            either 'kd' for a fixed radius, or 'hull'/'supervoxel' if a
            membrane prediction is available.
        thresh : float
            factor which is multiplied with the maximum membrane probability.
            The resulting value is used as the threshold above which the
            membrane is assumed to be existent.
        filter_size : list of int
            one integer per object type [mitos, vc, sj]
        nb_neighbors : int
            minimum number of neighbors a single hull point needs during
            outlier detection in order to survive.
        nb_hull_vox : int
            number of object hull voxels which are used to estimate the
            spatial proximity to the skeleton (inside or outside).
        neighbor_radius : int
            radius (nm) of the ball in which to look for supporting hull
            voxels. Used during outlier detection.
        detect_outlier : bool
            use outlier detection if True.
        nb_rays : int
            number of rays sent at each skeleton node (multiplied by a factor
            of 5). Defines the angle between two rays (=360 / nb_rays) in the
            orthogonal plane.
        nb_voting_neighbors : int
            number of votes of skeleton hull voxels (membrane representation)
            for the object mapping. Used for vc and mitos during the
            geometrical position estimation of object nodes.
        max_dist_mult : float
            multiplier of the radius to estimate the maximal distance of hull
            points to the source node.
        """
        start = time.time()
        if radius == 0:
            return
        if np.isscalar(radius):
            radius = [radius] * 3
        self.kd_radius = radius
        self.filter_size = arr(filter_size)
        self.annotation_method = method
        self.thresh = thresh
        self.detect_outlier = detect_outlier
        self.neighbor_radius = neighbor_radius
        self.nb_neighbors = nb_neighbors
        self.nb_rays = nb_rays
        self.nb_hull_vox = nb_hull_vox
        self.nb_voting_neighbors = nb_voting_neighbors
        if method == 'hull' and (self._hull_coords is None):
            self.hull_sampling(thresh, nb_rays, nb_neighbors,
                               neighbor_radius, detect_outlier, max_dist_mult)
        if dh.mitos is not None:
            # initialize UltrastructuralDatasets for mapped objects
            self.mitos = UltrastructuralDataset(
                dh.mitos.type, dh.mitos._rel_path_home,
                dh.mitos._path_to_chunk_dataset_head)
            # do the mapping
            node_ids = self.annotate_object(dh.mitos, radius[0], method,
                                            "mitos")
            self.mitos._node_ids = node_id2key(dh.mitos, node_ids,
                                               filter_size[0])
            nb_obj_found = len(set([element for sublist in
                                    self.mitos._node_ids
                                    for element in sublist]))
            # print "[%s] Found %d %s using size filter %d" % \
            #       (self.ix, nb_obj_found, 'mitos', filter_size[0])
            # store annotated segmentation objects
            for i in range(len(self.nodes)):
                mito_keys = self.mitos._node_ids[i]
                for k in mito_keys:
                    self.nodes[i].objects['mitos'] = \
                        self.nodes[i].objects['mitos'] + \
                        [dh.mitos.object_dict[k]]
                    self.mitos.object_dict[k] = dh.mitos.object_dict[k]
        else:
            # print "Skipped mito-mapping."
            pass
        # same for vc
        if dh.vc is not None:
            self.vc = UltrastructuralDataset(
                dh.vc.type, dh.vc._rel_path_home,
                dh.vc._path_to_chunk_dataset_head)
            self.vc._node_ids = node_id2key(
                dh.vc, self.annotate_object(dh.vc, radius[1], method, "vc"),
                filter_size[1])
            nb_obj_found = len(set([element for sublist in self.vc._node_ids
                                    for element in sublist]))
            # print "[%s] Found %d %s using size filter %d" % \
            #       (self.ix, nb_obj_found, 'vc', filter_size[1])
            for i in range(len(self.nodes)):
                vc_keys = self.vc._node_ids[i]
                for k in vc_keys:
                    self.nodes[i].objects['vc'] = \
                        self.nodes[i].objects['vc'] + [dh.vc.object_dict[k]]
                    self.vc.object_dict[k] = dh.vc.object_dict[k]
        else:
            # print "Skipped vc-mapping."
            pass
        # and sj
        if dh.sj is not None:
            self.sj = UltrastructuralDataset(
                dh.sj.type, dh.sj._rel_path_home,
                dh.sj._path_to_chunk_dataset_head)
            self.sj._node_ids = node_id2key(
                dh.sj, self.annotate_object(dh.sj, radius[2], method, "sj"),
                filter_size[2])
            nb_obj_found = len(set([element for sublist in self.sj._node_ids
                                    for element in sublist]))
            # print "[%s] Found %d %s using size filter %d" % \
            #       (self.ix, nb_obj_found, 'sj', filter_size[2])
            for i in range(len(self.nodes)):
                sj_keys = self.sj._node_ids[i]
                for k in sj_keys:
                    self.nodes[i].objects['sj'] = \
                        self.nodes[i].objects['sj'] + [dh.sj.object_dict[k]]
                    self.sj.object_dict[k] = dh.sj.object_dict[k]
        else:
            # print "Skipped sj-mapping."
            pass
        if self._myelin_ds_path is not None:
            self.calc_myelinisation()
        # print "--- Skeleton #%s fully annotated after %0.2f seconds with" \
        #       " '%s'-criterion" % (self.ix, time.time() - start,
        #                            self.annotation_method)
    def annotate_object(self, objects, radius, method, objtype):
        """Redirects the mapping task to the desired method-function

        Parameters
        ----------
        objects : UltrastructuralDataset
        radius : int
            radius of the kd-tree in units of nm.
        method : str
            either 'hull', 'kd' or 'supervoxel'
        objtype : str
            characterises the object type

        Returns
        -------
        list
            mapped object IDs
        """
        if method == 'hull':
            node_ids = self._annotate_with_hull(objects, radius, objtype)
        elif method == 'supervoxel':
            node_ids = self._annotate_with_supervoxels(objects, radius,
                                                       objtype)
        elif method == 'gt_sampling':
            node_ids = self._annotate_with_kdtree_gt_sampling(objects, radius)
        else:
            node_ids = self._annotate_with_kdtree(objects, radius)
        return list(node_ids)
    def _annotate_with_kdtree_gt_sampling(self, data, radius):
        """Annotates objects to a node if their representative coordinate
        (data) is within radius, and samples them depending on the distance
        of each object such that the object distance distribution to its
        nearest node is nearly uniform (assuming an isotropic distribution at
        the beginning, i.e. ~ r**2).

        Parameters
        ----------
        data : UltrastructuralDataset
            dictionary of cell objects
        radius : int
            mapping radius around tracing nodes (in nm)

        Returns
        -------
        list
            annotated objects per node, i.e. a list of lists
        """
        # print "Applying kd-tree with radius %s to %d nodes and %d objects" \
        #       % (radius, len(self.node_com), len(data.rep_coords))
        coords = arr(data.rep_coords) * self.scaling
        tree = spatial.cKDTree(coords)
        # get objects within a constant radius for all nodes
        assert radius > 0, "Choose positive radius!"
        annotation_ids = tree.query_ball_point(self.node_com, radius)
        dists = []
        for k, sublist in enumerate(annotation_ids):
            node_coord = self.node_com[k]
            dist_sub = np.linalg.norm(node_coord - coords[sublist], axis=1)
            dists.append(dist_sub)
        nb_objects = len(set([element for sublist in annotation_ids
                              for element in sublist]))
        annotation_ids = arr([element for sublist in annotation_ids
                              for element in sublist])
        dists = arr([element for sublist in dists for element in sublist])
        set_of_anno_ids = list(set(annotation_ids))
        # print "Found %d objects before sampling." % nb_objects
        if nb_objects <= 400:
            return [[]] * (len(self.nodes) - 1) + [set_of_anno_ids]
        todo_list = [[list(set_of_anno_ids[i::self._nb_cpus]),
                      annotation_ids, dists] for i in xrange(self._nb_cpus)]
        pool = Pool(processes=self._nb_cpus)
        res = pool.map(helper_samllest_dist, todo_list)
        pool.close()
        pool.join()
        final_ids = arr([ix for sub_list in res for ix in sub_list[0]])
        final_dists = arr([dist for sub_list in res for dist in sub_list[1]])
        max_dist = np.max(final_dists)
        a = -0.95 * max_dist ** 2
        w_func = lambda x: a * x ** (-2) + 1
        weights = w_func(final_dists)
        normalization = np.sum(weights)
        weights /= normalization
        cum_weights = np.cumsum(weights)
        sample_ixs = []
        cnt = 0
        while len(sample_ixs) < 400:
            cnt += 1
            rand_nb = np.random.rand(1)
            # find the first entry with a higher value than the random number
            sample = np.argmax(cum_weights > rand_nb)
            if sample not in sample_ixs:
                sample_ixs.append(sample)
            if cnt > 50000:
                break
        return [[]] * (len(self.nodes) - 1) + [list(final_ids[sample_ixs])]

    def _annotate_with_kdtree(self, data, radius):
        """Annotates objects to a node if their representative coordinate
        (data) is within radius.

        Parameters
        ----------
        data : UltrastructuralDataset
            dictionary of cell objects
        radius : int
            mapping radius (in nm)

        Returns
        -------
        list of list of UltrastructuralDatasetObject
            list with annotated objects per node, i.e. a list of lists
        """
        # print "Applying kd-tree with radius %s to %d nodes and %d objects" \
        #       % (radius, len(self.node_com), len(data.rep_coords))
        coords = arr(data.rep_coords) * self.scaling
        tree = spatial.cKDTree(coords)
        annotation_ids = []
        # get objects within a constant radius for all nodes
        assert radius > 0, "Choose positive radius!"
        for coord in self.node_com:
            annotation_ids.append(tree.query_ball_point(coord, radius))
        nb_objects = len(set([element for sublist in annotation_ids
                              for element in sublist]))
        return annotation_ids

    def _annotate_with_supervoxels(self, data, radius, objtype):
        """Annotates objects to the skeleton if sufficiently many randomly
        selected object hull voxels lie within supervoxels of this skeleton.

        Parameters
        ----------
        data : UltrastructuralDataset
            dictionary of cell objects
        radius : int
            mapping radius (in nm)
        objtype : str
            cell object type (sj, vc, mito)

        Returns
        -------
        list of list of UltrastructuralDatasetObject
            list with annotated objects per node, i.e. a list of lists
        """
        nb_hull_vox = self.nb_hull_vox
        red_ids = self._annotate_with_kdtree(data, radius)
        red_ids = list(set([ix for sublist in red_ids for ix in sublist]))
        keys = arr(data.ids)[red_ids]
        curr_objects = [data.object_dict[key] for key in keys]
        pool = Pool(processes=self._nb_cpus)
        obj_voxel_coords = pool.map(helper_get_voxels, curr_objects)
        pool.close()
        pool.join()
        cset = self.cset
        obj_ids = []
        rand_voxels = []
        for i, key in enumerate(keys):
            curr_voxels = obj_voxel_coords[i]
            curr_obj_id = curr_objects[i].obj_id
            rand_ixs = np.random.randint(len(curr_voxels), size=nb_hull_vox)
            rand_voxels += curr_voxels[rand_ixs].tolist()
            obj_ids += [curr_obj_id] * nb_hull_vox
        mergelist_path = '/home/pschuber/data/gt/nml_obj/' + str(self.ix)
        mapped_obj_ids = arr(from_skeleton_to_mergelist(
            cset, self.anno, 'watershed_150_20_10_3_unique', 'labels',
            rand_voxels, obj_ids, nb_processes=self._nb_cpus,
            mergelist_path=mergelist_path))
        annotation_ids_new = []
        min_votes = self.obj_min_votes[objtype]
        for i in range(len(obj_voxel_coords)):
            ix = curr_objects[i].obj_id
            inside_votes = np.sum(
                mapped_obj_ids[mapped_obj_ids[:, 0] == ix][:, 1])
            if inside_votes >= min_votes:
                annotation_ids_new.append(red_ids[i])
            self.mapping_info[objtype][ix] = inside_votes
        return [[]] * (len(self.nodes) - 1) + [annotation_ids_new]

    def _annotate_with_hull(self, data, radius, objtype):
        """Calculates a membrane representation via ray castings. Each ray
        ends as a point after reaching a certain threshold. The resulting
        point cloud is used to determine in- and outlier coordinates of
        object hull voxels. If sufficiently many voxels are inside the cloud,
        the corresponding object is mapped to the skeleton.

        Parameters
        ----------
        data : UltrastructuralDataset
            dictionary of cell objects
        radius : int
            mapping radius (in nm)
        objtype : str
            cell object type (sj, vc, mito)

        Returns
        -------
        list of list of UltrastructuralDatasetObject
            list with annotated objects per node, i.e. a list of lists
        """
        sjtrue = (objtype == 'sj')
        max_sj_dist = 125.
        nb_voting_neighbors = self.nb_voting_neighbors
        nb_hull_vox = self.nb_hull_vox
        # print "Annotating with hull criterion. Using %d voting neighbors" \
        #       " and %d hull voxel." % (nb_voting_neighbors, nb_hull_vox)
        red_ids = self._annotate_with_kdtree(data, radius=radius)
        red_ids = list(set([id for sublist in red_ids for id in sublist]))
        points = self.hull_coords
        if len(points) == 0:
            return [[]] * len(self.nodes)
        tree = spatial.cKDTree(points)

        def check_hull_normals(obj_coord, hull_coords, dir_vecs):
            if not sjtrue:
                obj_coord = obj_coord[None, :]
                left_side = np.inner(obj_coord, dir_vecs)
                right_side = np.sum(dir_vecs * hull_coords, axis=1)
                sign = np.sign(left_side - right_side)
                return np.sum(sign) < 0
            else:
                n_hullnodes_dists, n_hullnodes_ids = tree.query(obj_coord,
                                                                k=20)
                mean_dists = np.mean(n_hullnodes_dists)
                return mean_dists < max_sj_dist

        # here annotation_ids_new contains only one node.
        keys = arr(data.ids)[red_ids]
        curr_objects = [data.object_dict[key] for key in keys]
        nb_cpus = max(cpu_count() / 2 - 2, 1)
        pool = Pool(processes=nb_cpus)
        curr_object_voxels = pool.map(helper_get_voxels, curr_objects)
        pool.close()
        pool.join()
        annotation_ids_new = []
        min_votes = self.obj_min_votes[objtype]
        # print "Mapping objects '%s' using %d min. votes while asking %s" \
        #       " obj. hull voxel and using %d skeleton hull voxel to decide" \
        #       " if in or out." % (objtype, min_votes, nb_hull_vox,
        #                           nb_voting_neighbors)
        for i in range(len(curr_object_voxels)):
            curr_obj_id = curr_objects[i].obj_id
            curr_voxels = curr_object_voxels[i]
            rand_ixs = np.random.randint(len(curr_voxels), size=nb_hull_vox)
            rand_voxels = curr_voxels[rand_ixs] * self.scaling
            _, skel_hull_ixs = tree.query(rand_voxels, k=nb_voting_neighbors)
            is_in_hull = 0
            for ii, voxel in enumerate(rand_voxels):
                vx_near_cellixs = skel_hull_ixs[ii]
                is_in_hull += check_hull_normals(
                    voxel, points[vx_near_cellixs],
                    self.hull_normals[vx_near_cellixs])
            if is_in_hull >= min_votes:
                annotation_ids_new.append(red_ids[i])
            self.mapping_info[objtype][curr_obj_id] = is_in_hull
        return [[]] * (len(self.nodes) - 1) + [annotation_ids_new]
    def hull_sampling(self, thresh=2.2, nb_rays=20, nb_neighbors=20,
                      neighbor_radius=220, detect_outlier=True,
                      max_dist_mult=1.4):
        """Calculates the hull of the tracing

        Parameters
        ----------
        thresh : float
            factor of the maximum occurring prediction value after which the
            membrane is considered active.
        nb_rays : int
            number of rays sent at each skeleton node (multiplied by a factor
            of 5). Defines the angle between two rays (=360 / nb_rays) in the
            orthogonal plane.
        nb_neighbors : int
            minimum number of neighbors a single hull point needs during
            outlier detection in order to survive.
        neighbor_radius : int
            radius of the ball in which to look for supporting hull voxels.
            Used during outlier detection.
        detect_outlier : bool
            use outlier detection if True.
        max_dist_mult : float
            multiplier of the radius to generate the maximal distance of hull
            points to the source node.

        Returns
        -------
        numpy.array
            average radius per node in (9, 9, 20) corrected units, estimated
            by rays propagated through the membrane prediction until the
            threshold is reached.
        """
        # print "Creating hull using scaling %s and threshold %0.2f with" \
        #       " outlier-detection=%s" % (self.scaling, thresh * 255.0,
        #                                  str(detect_outlier))
        mem_path = self._mem_path
        assert mem_path is not None, "Path to barrier must be given!"
        kd = KnossosDataset()
        kd.initialize_from_knossos_path(mem_path)
        used_node_ix = []
        coms = []
        mem_pos = np.array([0, 0, 0], dtype=np.int)
        mem_shape = kd.boundary
        # compute the orthogonal plane to the linearly interpolated skeleton
        # at each com
        orth_plane, skel_interp = get_orth_plane(self.node_com)
        # test and rewrite node positions of skeleton_data
        for i, com in enumerate(self.node_com):
            com = (np.array(com) - mem_pos) / self.scaling
            smaller_zero = np.any(com < 0)
            out_of_mem = np.any([com[k] > mem_shape[k] for k in range(3)])
            # fixed operator precedence: keep only nodes inside the volume
            # (the original read `not smaller_zero or out_of_mem`)
            if not (smaller_zero or out_of_mem):
                coms.append(com)
                used_node_ix.append(i)
        used_node_ix = arr(used_node_ix)
        coms = arr(coms)
        nb_nodes2proc = len(coms)
        # print "Computing radii and point cloud for %d of %d nodes." % \
        #       (nb_nodes2proc, len(self.node_com))
        # print "Total bounding box from %s to %s" % \
        #       (str(np.min(coms, axis=0)), str(np.max(coms, axis=0)))
        assert (len(orth_plane) == len(skel_interp)) and \
               (len(skel_interp) == len(coms))
        # find the necessary bounding boxes containing nodes and the index to
        # get the corresponding orth. plane and interp.
        boxes = []
        box = [coms[0]]
        node_attr = []
        ix = used_node_ix[0]
        # check whether the current node is an end node
        current_node = self.nodes[ix]
        nb_edges = len(self.anno.getNodeReverseEdges(current_node)) + \
                   len(self.anno.getNodeEdges(current_node))
        # store properties of nodes
        node_attr.append((skel_interp[ix], orth_plane[ix], ix, nb_edges < 2))
        for i in range(1, len(coms)):
            node_box_min = np.min(box + [coms[i]], axis=0)
            node_box_max = np.max(box + [coms[i]], axis=0)
            vol = np.prod(node_box_max - node_box_min)
            if vol > 0.5e7:
                boxes.append((arr(box), node_attr))
                box = []
                node_attr = []
            box.append(coms[i])
            ix = used_node_ix[i]
            current_node = self.nodes[ix]
            nb_edges = len(self.anno.getNodeReverseEdges(current_node)) + \
                       len(self.anno.getNodeEdges(current_node))
            node_attr.append((skel_interp[ix], orth_plane[ix], ix,
                              nb_edges < 2))
        boxes.append((arr(box), node_attr))
        # print "Found %d different boxes." % len(boxes)
        # print "Using %d cpus." % self._nb_cpus
        pool = Pool(processes=self._nb_cpus)
        m = Manager()
        q = m.Queue()
        result = pool.map_async(get_radii_hull,
                                [(box, q, self.scaling, mem_path, nb_rays,
                                  thresh, max_dist_mult) for box in boxes])
        outputs = result.get()
        pool.close()
        pool.join()
        # print "\nFinished radius estimation and hull representation."
        ixs = []
        radii = []
        hull_list = []
        vals = []
        for cnt, el in enumerate(outputs):
            radii += list(el[0])
            ixs += list(el[1])
            hull_list += list(el[2])
            vals += list(el[3])
        # sort to match the self.node_ids ordering
        ixs = arr(ixs)
        ixs_sorted = np.argsort(ixs)
        radii_sorted = arr(radii)[ixs_sorted]
        # check the result
        if len(ixs) != len(self.node_com):
            raise RuntimeError("Tracing nodes during hull mapping missing!")
        elif not (ixs[ixs_sorted] == np.arange(len(self.node_com))).all():
            raise RuntimeError("Original tracing node indices differ from "
                               "returned indices in membrane radius result.")
        coord_list = []
        for i, node in enumerate(self.nodes):
            node.setDataElem("radius", radii_sorted[i])
            coord_list.append(node.getCoordinate())
            node.ID = np.int(node.ID)
        big_skel_tree = spatial.cKDTree(coord_list)
        for node in self.old_anno.getNodes():
            ix_node = big_skel_tree.query(node.getCoordinate(), 1)[1]
            node.setDataElem("radius", np.max((radii_sorted[ix_node], 1.)))
        self.anno.nodes = set(self.nodes)
        try:
            hull_coords = arr([pt for sub in hull_list
                               for pt in sub]) * self.scaling
        except ValueError:
            hull_coords = np.zeros((0, 3))
        hull_coords = np.nan_to_num(hull_coords).astype(np.float32)
        if detect_outlier:
            hull_coords_ix = outlier_detection(hull_coords, nb_neighbors,
                                               neighbor_radius)
            hull_coords = hull_coords[hull_coords_ix]
        self._hull_coords = hull_coords
        self._skel_radius = radii_sorted
    def calc_myelinisation(self):
        """Calculates the myelinisation at each node and writes it to
        node.data["myelin_pred"]
        """
        assert self._myelin_ds_path is not None, "Myelin dataset not found."
        test_box = (10, 10, 5)
        true_thresh = 100.
        j0126_myelin_inside_ds = KnossosDataset()
        j0126_myelin_inside_ds.initialize_from_knossos_path(
            self._myelin_ds_path)
        for n in self.old_anno.getNodes():
            myelin_b = '0'
            test_vol = j0126_myelin_inside_ds.from_raw_cubes_to_matrix(
                test_box, n.getCoordinate(), show_progress=False)
            if np.mean(test_vol) > true_thresh:
                myelin_b = '1'
                n.data["myelin_pred"] = 1
            else:
                n.data["myelin_pred"] = 0
            node_comment = n.getComment()
            ax_ix = node_comment.find('myelin')
            if ax_ix == -1:
                n.appendComment('myelin' + myelin_b)
            else:
                help_list = list(node_comment)
                # fixed off-by-one: 'myelin' has six characters, so the
                # prediction digit sits at ax_ix + 6 (the original used + 5)
                help_list[ax_ix + 6] = myelin_b
                n.setComment("".join(help_list))
        majority_vote(self.old_anno, property='myelin', max_dist=2000)
    @property
    def property_features(self):
        """Getter of the property features; calculates axoness/spiness
        features if necessary

        Returns
        -------
        np.array, bool
            property features, whether spiness features are given
        """
        if self._property_features is None:
            self._property_features, self.property_feat_names, \
                self.spiness_given = calc_prop_feat_dict(self,
                                                         self.context_range)
        return self._property_features
    def predict_property(self, rf, prop, max_neck2endpoint_dist=3000,
                         max_head2endpoint_dist=600):
        """Predicts a property (axoness, spiness) of the tracing

        Parameters
        ----------
        rf : RandomForestClassifier
        prop : str
            property name
        max_neck2endpoint_dist : int
        max_head2endpoint_dist : int
        """
        property_feature = self.property_features[prop][:, 1:]
        if prop == 'axoness' and not self.spiness_given:
            raise RuntimeError("Spiness features were not given "
                               "during axoness prediction!")
        # print "Predicting %s using %d features." % \
        #       (prop, property_feature.shape[1])
        proba = rf.predict_proba(property_feature)
        pred = rf.predict(property_feature)
        node_ids = self.property_features[prop][:, 0]
        for k, node_id in enumerate(node_ids):
            node = self.old_anno.getNodeByID(node_id)
            if prop == 'spiness' and 'axoness_pred' in node.data.keys():
                if int(node.data['axoness_pred']) != 0:
                    continue
            node_comment = node.getComment()
            ax_ix = node_comment.find(prop)
            node_pred = int(pred[k])
            if ax_ix == -1:
                node.appendComment(prop + '%d' % node_pred)
            else:
                help_list = list(node_comment)
                help_list[ax_ix + 7] = str(node_pred)
                node.setComment("".join(help_list))
            for ii in range(proba.shape[1]):
                node.setDataElem(prop + '_proba%d' % ii, proba[k, ii])
            node.setDataElem(prop + '_pred', node_pred)
            node.setDataElem('branch_dist', property_feature[k, -1])
            node.setDataElem('end_dist', property_feature[k, -2])
        # if prop == 'axoness':
        #     majority_vote(self.old_anno, 'axoness', 25000)
        if prop == 'spiness':
            assign_neck(self.old_anno,
                        max_head2endpoint_dist=max_head2endpoint_dist,
                        max_neck2endpoint_dist=max_neck2endpoint_dist)
    def write2pkl(self, path):
        """Writes the MappedSkeleton object to a .pkl file. The path is
        extracted from dh._datapath and the MappedSkeleton ID.

        Parameters
        ----------
        path : str
            path to the pkl destination
        """
        if os.path.isfile(path):
            copyfile(path, path[:-4] + '_old.pkl')
            # print ".pkl file already existed, moved old one to %s." % \
            #       (path[:-4] + '_old.pkl')
        with open(path, 'wb') as output:
            pickle.dump(self, output, -1)
        # print "Skeleton %s saved successfully at %s." % (self.ix, path)
    def write2kzip(self, path):
        """Writes the interpolated skeleton (and annotated objects) to an nml
        at path. If the self.write_obj_voxel flag is True, a .txt file
        containing all object voxels with their ids is written into the
        k.zip

        Parameters
        ----------
        path : str
            path to the kzip destination
        """
        object_skel = Skeleton()
        obj_dict = {0: 'mitos', 1: 'vc', 2: 'sj'}
        re_process_skels = []
        # store paths to written files for the kzip compression
        files = []
        # print 'Writing kzip to %s. Writing object voxels=%s' \
        #       % (path, str(self.write_obj_voxel))
        if '.k.zip' in path:
            path = path[:-5] + 'nml'
        elif '.zip' in path:
            path = path[:-3] + 'nml'
        for k, objects in enumerate([self.mitos, self.vc, self.sj]):
            object_annotation = SkeletonAnnotation()
            object_annotation.scaling = self.scaling
            object_annotation.appendComment(obj_dict[k])
            if objects is None:
                continue
            object_voxel = []
            object_voxel_id = []
            if not np.all(arr(objects.sizes) >= self.filter_size[k]):
                print "Size filter does not work properly!"
                # fixed: the original appended the builtin `id` here
                re_process_skels.append(self.ix)
            for key in list(arr(objects.object_dict.keys())):
                obj = objects.object_dict[key]
                curr_obj_id = np.int(obj.obj_id)
                map_info = self.mapping_info[obj_dict[k]][curr_obj_id]
                node = SkeletonNode().from_scratch(
                    object_annotation, obj.rep_coord[0], obj.rep_coord[1],
                    obj.rep_coord[2],
                    radius=(obj.size / 4. / np.pi * 3) ** (1 / 3.))
                node.setPureComment(obj_dict[k] + '-' + str(curr_obj_id) +
                                    '_mi' + str(map_info))
                object_annotation.addNode(node)
                if self.write_obj_voxel:
                    try:
                        coords_to_add = list(obj.hull_voxels * self.scaling)
                    except IOError, e:
                        # print "Could not find hull vx of object %s" % \
                        #       str(key)
                        # print e
                        warnings.warn("Could not find hull voxel. "
                                      "Aborting %s." % path,
                                      DeprecationWarning)
                        continue
                    object_voxel += coords_to_add
                    object_voxel_id += [np.int(obj.obj_id)] * \
                                       len(coords_to_add)
            if self.write_obj_voxel:
                obj_hull_path = path[:-4] + '_' + obj_dict[k] + '.txt'
                obj_hull2text(arr(object_voxel_id), arr(object_voxel),
                              obj_hull_path)
                files.append(obj_hull_path)
                files.append(obj_hull_path[:-4] + '_id.txt')
            obj_pkl_path = path[:-4] + '_' + obj_dict[k] + '.pkl'
            write_obj2pkl(objects, obj_pkl_path)
            files.append(obj_pkl_path)
            object_skel.add_annotation(object_annotation)
        self.old_anno.setComment("skeleton")
        object_skel.add_annotation(self.old_anno)
        if self.soma is not None:
            object_skel.add_annotation(self.soma)
        object_skel.toNml(path)
        files.append(path)
        if self._hull_coords is not None:
            hull_path = path[:-3] + 'xyz'
            hull2text(self.hull_coords, self.hull_normals, hull_path)
            files.append(hull_path)
        kzip_path = path[:-3] + "k.zip"
        try:
            for prop, prop_feat in self.property_features.iteritems():
                feat_path = path[:-4] + '_%s_feat.csv' % prop
                write_feat2csv(feat_path, prop_feat,
                               self.property_feat_names[prop])
                files.append(feat_path)
        except IOError:
            pass
        for path_to_file in files:
            write_data2kzip(kzip_path, path_to_file)
        # print "Mapped skeleton %s saved successfully at %s." % \
        #       (self.ix, kzip_path)
    def get_plot_obj(self):
        """Extracts coordinates from the annotated SegmentationObjects

        Returns
        -------
        np.array
            object voxels for each object type
        """
        assert self.annotation_method is not None, "Objects not initialized!"
        voxel_list = []
        for objects in [self.mitos, self.vc, self.sj]:
            voxel_list1 = []
            for key in objects.object_dict.keys():
                voxels = objects.object_dict[key].voxels
                if np.ndim(voxels) == 1:
                    voxels = voxels[None, :]
                voxel_list1.append(voxels)
            voxel_list.append(voxel_list1)
        mito = arr([element for sublist in voxel_list[0]
                    for element in sublist],
                   dtype=np.uint32) * (self.scaling / 10)
        vc = arr([element for sublist in voxel_list[1]
                  for element in sublist],
                 dtype=np.uint32) * (self.scaling / 10)
        sj = arr([element for sublist in voxel_list[2]
                  for element in sublist],
                 dtype=np.uint32) * (self.scaling / 10)
        return mito, vc, sj
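
# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): a minimal, hypothetical
# example of the plane-side voting used in SkeletonMapper._annotate_with_hull.
# A coordinate p counts as "inside" if, for most nearby hull points h with
# outward normal n, sign(<p, n> - <h, n>) is negative, i.e. p lies behind the
# local hull plane. All names below are illustrative only.
def _demo_hull_side_test():
    import numpy as np
    from scipy import spatial
    rng = np.random.RandomState(0)
    # hull points of a unit sphere around the origin; normals point outwards
    hull = rng.randn(2000, 3)
    hull /= np.linalg.norm(hull, axis=1)[:, None]
    normals = hull.copy()
    tree = spatial.cKDTree(hull)
    for p, expected in [(np.array([0.2, 0., 0.]), True),
                        (np.array([1.5, 0., 0.]), False)]:
        # only the nearest hull points vote, as in the module
        _, nn = tree.query(p, k=100)
        h, n = hull[nn], normals[nn]
        votes = np.sign(np.inner(p[None, :], n) - np.sum(n * h, axis=1))
        assert (np.sum(votes) < 0) == expected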
def node_id2key(segdataobject, node_ids, filter_size):
    """Maps the list indices in node_ids to keys of the SegmentationObjects.
    Filters out objects smaller than filter_size.

    Parameters
    ----------
    segdataobject : UltrastructuralDataset
        dataset of the object type currently processed
    node_ids : list of list
        annotated object ids for each node
    filter_size : int
        minimum number of voxels of an object

    Returns
    -------
    list
        object keys
    """
    for node in node_ids:
        for obj in node:
            if segdataobject.sizes[obj] < filter_size:
                node[node.index(obj)] = -1
            else:
                key = segdataobject.ids[obj]
                node[node.index(obj)] = key
    node_ids = [filter(lambda a: a != -1, node) for node in node_ids]
    return node_ids
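
# Editorial sketch (not part of the original module): node_id2key in action
# with a hypothetical stand-in dataset. Indices into the dataset are turned
# into object keys, and objects below the size filter are dropped.
def _demo_node_id2key():
    class _FakeDataset(object):
        sizes = [10, 500, 80]          # voxel counts per object
        ids = ['obj_a', 'obj_b', 'obj_c']
    node_ids = [[0, 1], [2]]           # per-node indices into the dataset
    result = node_id2key(_FakeDataset(), node_ids, filter_size=50)
    assert result == [['obj_b'], ['obj_c']]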
def outlier_detection(point_list, min_num_neigh, radius):
    """Finds hull outliers using a point density criterion

    Parameters
    ----------
    point_list : list
        list of coordinates
    min_num_neigh : int
        minimum number of neighbors, s.t. the hull point survives.
    radius : int
        radius in nm in which to look for neighbors

    Returns
    -------
    numpy.array
        boolean mask of the points to keep
    """
    if len(point_list) == 0:
        return np.ones((len(point_list), )).astype(np.bool)
    # print "Starting outlier detection."
    if np.array(point_list).ndim != 2:
        points = np.array([point for sublist in point_list
                           for point in sublist])
    else:
        points = np.array(point_list)
    tree = spatial.cKDTree(points)
    nb_points = float(len(points))
    # print "Old #points:\t%d" % nb_points
    new_points = np.ones((len(points), )).astype(np.bool)
    for ii, coord in enumerate(points):
        neighbors = tree.query_ball_point(coord, radius)
        num_neighbors = len(neighbors)
        new_points[ii] = num_neighbors >= min_num_neigh
    # print "Found %d outlier." % np.sum(~new_points)
    return np.array(new_points)
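
# Editorial sketch (not part of the original module): outlier_detection on a
# dense cluster plus one isolated point; only the isolated point is dropped.
def _demo_outlier_detection():
    import numpy as np
    rng = np.random.RandomState(0)
    cluster = rng.rand(200, 3) * 50           # dense cloud, ~50 nm across
    stray = np.array([[1e4, 1e4, 1e4]])       # far-away point
    points = np.concatenate([cluster, stray], axis=0)
    mask = outlier_detection(points, min_num_neigh=5, radius=100)
    assert mask[:-1].all() and not mask[-1]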
def get_radii_hull(args):
    """Wrapper function for the point cloud extraction from the membrane
    prediction. Gets a bounding box with nodes, loads the membrane prediction
    for these and then calculates the radius and hull at each skeleton node.
    """
    if not ray_cast_avail:
        raise RuntimeError("ray_casting_radius-module needed for this")
    # node attribute contains skel_interpolation, orthogonal plane and a
    # bool whether the node is an end node
    box, node_attr = args[0]
    q = args[1]
    scaling = args[2]
    mem_path = args[3]
    nb_rays = args[4]
    thresh_factor = args[5]
    max_dist_mult = args[6]
    kd = KnossosDataset()
    kd.initialize_from_knossos_path(mem_path)
    mem_shape = kd.boundary
    ray_buffer = arr([2000, 2000, 2000]) / scaling
    prop_offset = np.max([np.min(box, axis=0) - ray_buffer, [0, 0, 0]],
                         axis=0).astype(np.int)
    prop_size = np.min([np.max(box, axis=0) + ray_buffer, mem_shape],
                       axis=0) - prop_offset
    assert np.prod(prop_size) < 10e9, "Bounding box too big!"
    mem = kd.from_raw_cubes_to_matrix(prop_size.astype(np.int32),
                                      prop_offset.astype(np.int32),
                                      show_progress=False)
    # thresholding the membrane
    mem[mem <= 0.4 * mem.max()] = 0
    mem = mem.astype(np.uint8)
    threshold = mem.max() * thresh_factor
    # iterate over every node
    avg_radius_list = []
    all_points = []
    ids = []
    val_list = []
    todo_list = zip(list(box), [nb_rays] * len(box), list(node_attr))
    for el in todo_list:
        radius, ix, membrane_points, vals = ray_casting_radius(
            el[0], el[1], el[2][0], el[2][1], el[2][2], scaling, threshold,
            prop_offset, mem, el[2][3], max_dist_mult)
        all_points.append(arr(membrane_points, dtype=np.float32))
        avg_radius_list.append(radius)
        ids.append(ix)
        val_list.append(vals)
    q.put(ids)
    del mem
    return avg_radius_list, ids, all_points, val_list
def read_pair_cs(pair_path):
    """Helper function to collect pairwise contact site information.
    Extracts the axoness prediction.

    Parameters
    ----------
    pair_path : str
        path to the pairwise contact site kzip

    Returns
    -------
    SkeletonAnnotation
        annotation object without contact site hull voxels
    """
    pairwise_anno = su.loadj0126NML(pair_path)[0]
    predict_axoness_from_nodes(pairwise_anno)
    new_anno = SkeletonAnnotation()
    new_anno.setComment(pairwise_anno.getComment())
    for node in list(pairwise_anno.getNodes()):
        n_comment = node.getComment()
        if '_hull' in n_comment:
            continue
        new_anno.addNode(node)
    return new_anno
def prepare_syns_btw_annos(pairwise_paths, dest_path, max_hull_dist=60,
                           concom_dist=300):
    """Checks pairwise for contact sites between the annotation objects found
    at the paths in pairwise_paths. Adds sj, vc and the nearest skeleton
    nodes to found contact sites. Writes 'contact_sites.nml', containing the
    contact sites of all nml's, to the nml path.

    Parameters
    ----------
    pairwise_paths : list of str
        list of pairwise paths to nml's
    dest_path : str
        path to the directory where the result of the synapse mapping is
        stored
    max_hull_dist : float
        maximum distance between skeletons in nm
    concom_dist : float
        maximum distance of connected components (nm)
    """
    sname = socket.gethostname()
    if sname[:6] in ['soma01', 'soma02', 'soma03', 'soma04', 'soma05']:
        nb_cpus = np.min((2, cpu_count() - 1))
    else:
        nb_cpus = np.max([np.min((16, cpu_count() - 1)), 1])
    params = [(a, b, max_hull_dist, concom_dist, dest_path)
              for a, b in pairwise_paths]
    _ = start_multiprocess(syn_btw_anno_pair, params, nb_cpus=nb_cpus)
def similarity_check(skel_a, skel_b):
    """Returns True if the absolute number of identical nodes exceeds a
    certain threshold.

    Parameters
    ----------
    skel_a : SkeletonAnnotation
        skeleton a
    skel_b : SkeletonAnnotation
        skeleton b

    Returns
    -------
    bool
        whether skel_a and skel_b are similar
    """
    a_coords = arr([node.getCoordinate() for node in skel_a.getNodes()]) * \
               skel_a.scaling
    a_coords_sample = a_coords[np.random.randint(0, len(a_coords), 100)]
    b_coords = arr([node.getCoordinate() for node in skel_b.getNodes()]) * \
               skel_b.scaling
    b_tree = spatial.cKDTree(b_coords)
    a_near = b_tree.query_ball_point(a_coords_sample, 1)
    nb_equal = len([id for sublist in a_near for id in sublist])
    similar = nb_equal > 10
    return similar
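
# Editorial sketch (not part of the original module): the coordinate-overlap
# idea behind similarity_check, reduced to plain arrays. Sampled coordinates
# of skeleton a are matched against a kd-tree of skeleton b; many near-exact
# hits indicate the same tracing.
def _demo_coord_overlap():
    import numpy as np
    from scipy import spatial
    a = np.random.RandomState(0).rand(300, 3) * 1000
    b = np.vstack([a[:150], a[:150] + 5000.])  # half the nodes coincide
    tree = spatial.cKDTree(b)
    hits = tree.query_ball_point(a[:100], r=1)
    nb_equal = sum(len(h) for h in hits)
    assert nb_equal > 10                       # -> "similar"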
def similarity_check_star(params):
    """Helper function"""
    skel1 = load_ordered_mapped_skeleton(params[0])[0]
    skel2 = load_ordered_mapped_skeleton(params[1])[0]
    similar = similarity_check(skel1, skel2)
    return similar, params
def syn_btw_anno_pair(params):
    """Gets synapse information between two mapped annotation objects.
    Details are written to a pairwise nml (containing all contact sites
    between the pair) and to an nml for each contact site.

    Parameters
    ----------
    params : list
        [path_a, path_b, max_hull_dist, concom_dist, dest_path]
        path_a : str
            path to mapped annotation object
        path_b : str
            path to mapped annotation object
        max_hull_dist : float
            maximum distance between skeletons (nm)
        concom_dist : float
            maximum distance of connected components (nm)
        dest_path : str
            path to the directory where results are stored
    """
    path_a, path_b, max_hull_dist, concom_dist, dest_path = params
    vx_overlap_dist = 80
    max_vc_dist = 80
    max_sj_dist = 40
    min_cs_area = 0.05 * 1e6
    a = load_anno_list([path_a], load_mitos=False)[0]
    sj_dict = load_objpkl_from_kzip(path_a)[2].object_dict
    b = load_anno_list([path_b], load_mitos=False)[0]
    id2skel = lambda x: str(a[0].filename) if np.int(x) == 0 else \
        str(b[0].filename)
    sj_dict.update(load_objpkl_from_kzip(path_b)[2].object_dict)
    scaling = a[0].scaling
    match = re.search(r'iter_0_(\d+)', a[0].filename)
    if match:
        a[0].filename = match.group(1)
    match = re.search(r'iter_0_(\d+)', b[0].filename)
    if match:
        b[0].filename = match.group(1)
    annotation_name = 'skel_' + a[0].filename + '_' + b[0].filename
    # do a similarity check and skip the combination if positive
    if a[0].filename == b[0].filename:
        # print "\n Skipping nearly identical skeletons: %s and %s, " \
        #       "because of identical ID.\n" % (a[0].filename, b[0].filename)
        return None
    if similarity_check(a[0], b[0]):
        # print "\n Skipping nearly identical skeletons: %s and %s, " \
        #       "because of similarity check.\n" % (a[0].filename,
        #                                           b[0].filename)
        return None
    csites, csite_ids = cs_btw_annos(a[0], b[0], max_hull_dist, concom_dist)
    if len(csites) == 0:
        return None
    # save information about the pairwise csites in one nml
    pairwise_anno = SkeletonAnnotation()
    pairwise_anno.appendComment(annotation_name)
    pairwise_anno.scaling = scaling
    # get sj objects with hull voxels, if available
    sj_nodes = list(a[3].getNodes()) + list(b[3].getNodes())
    sj_ids = []
    for node in sj_nodes:
        global_sj_id = np.int(re.findall('sj-(\d+)', node.getComment())[0])
        sj_ids.append(global_sj_id)
    sj_id_to_ix = {}
    for i, entry in enumerate(sj_ids):
        sj_id_to_ix[entry] = i
    if len(a[0].sj_hull_coords) != 0 or len(b[0].sj_hull_coords) != 0:
        sj_hull_voxel = np.concatenate((a[0].sj_hull_coords,
                                        b[0].sj_hull_coords), axis=0)
        sj_ids = np.concatenate((a[0].sj_hull_ids, b[0].sj_hull_ids), axis=0)
        sj_tree = spatial.cKDTree(sj_hull_voxel)
    else:
        sj_tree = None
    # get vc objects with hull voxels, if available
    vc_nodes = list(a[2].getNodes()) + list(b[2].getNodes())
    vc_ids = []
    for node in vc_nodes:
        global_vc_id = np.int(re.findall('vc-(\d+)', node.getComment())[0])
        vc_ids.append(global_vc_id)
    vc_id_to_ix = {}
    for i, entry in enumerate(vc_ids):
        vc_id_to_ix[entry] = i
    if len(a[0].vc_hull_coords) != 0 or len(b[0].vc_hull_coords) != 0:
        vc_hull_voxel = np.concatenate((a[0].vc_hull_coords,
                                        b[0].vc_hull_coords), axis=0)
        vc_ids = np.concatenate((a[0].vc_hull_ids, b[0].vc_hull_ids), axis=0)
        vc_tree = spatial.cKDTree(vc_hull_voxel)
    else:
        vc_tree = None
    # iterate over all contact sites between the skeletons; calculate the
    # skeleton kd-trees in advance
    a_skel_node_list = [node for node in a[0].getNodes()]
    a_skel_node_coords = arr([node.getCoordinate() for node in
                              a_skel_node_list]) * arr(scaling)
    a_skel_node_tree = spatial.cKDTree(a_skel_node_coords)
    b_skel_node_list = [node for node in b[0].getNodes()]
    b_skel_node_coords = arr([node.getCoordinate() for node in
                              b_skel_node_list]) * arr(scaling)
    b_skel_node_tree = spatial.cKDTree(b_skel_node_coords)
    for i, csite in enumerate(csites):
        vc_bool = False
        sj_bool = False
        # save information about one contact site in an extra nml, depending
        # on occurring vc and sj (four different categories)
        contact_site_name = annotation_name + '_cs%d' % (i + 1)
        contact_site_anno = SkeletonAnnotation()
        contact_site_anno.scaling = scaling
        curr_csite_ids = arr(csite_ids[i])
        csite_name = 'cs' + str(i + 1) + '_'
        csite_tree = spatial.cKDTree(csite)
        # get the hull area
        csb_area = 0
        csa_area = 0
        csb_points = arr(csite)[curr_csite_ids]
        csa_points = arr(csite)[~curr_csite_ids]
        try:
            if np.sum(curr_csite_ids) > 3:
                csb_area = convex_hull_area(csb_points)
        except Exception, e:
            # print e
            # print "Could not calculate a_area!!!!"
            pass
        try:
            if np.sum(~curr_csite_ids) > 3:
                csa_area = convex_hull_area(csa_points)
        except Exception, e:
            # print e
            # print "Could not calculate b_area!!!!"
            pass
        for j, coord in enumerate(csite):
            coord_id = curr_csite_ids[j]
            node = SkeletonNode().from_scratch(
                contact_site_anno, coord[0] / scaling[0],
                coord[1] / scaling[1], coord[2] / scaling[2])
            node.setPureComment(csite_name + id2skel(coord_id) + '_hull')
            pairwise_anno.addNode(node)
        mean_cs_area = np.mean((csb_area, csa_area))
        if mean_cs_area < min_cs_area:
            # print "Skipping cs because of area:", mean_cs_area
            continue
        # get the hull distance
        csa_tree = spatial.cKDTree(csa_points)
        dist, ixs = csa_tree.query(csb_points, 1)
        cs_dist = np.min(dist)
        # check vc and sj
        if sj_tree is not None:
            near_sj_ixs = sj_tree.query_ball_point(csite, max_sj_dist)
            near_sj_ids = list(set([sj_ids[id] for sublist in
                                    near_sj_ixs.tolist() for id in sublist]))
        else:
            near_sj_ids = []
        overlap = 0
        abs_ol = 0
        overlap_cs = 0
        overlap_area = 0
        overlap_coords = np.array([])
        for sj_id in near_sj_ids:
            sj_ix = sj_id_to_ix[sj_id]
            node = copy.copy(sj_nodes[sj_ix])
            curr_sj_voxel = np.array(sj_dict[sj_id].voxels) * scaling
            overlap_new, overlap_cs_new, overlap_area_new, \
                center_coord_new, overlap_coords_new = calc_overlap(
                    csite, curr_sj_voxel, vx_overlap_dist)
            abs_ol_new = overlap_new * len(curr_sj_voxel)
            old_comment = node.getComment()
            node.setPureComment(
                csite_name + 'relol%0.3f_absol%d' %
                (overlap_new, abs_ol_new) + old_comment)
            contact_site_anno.addNode(node)
            pairwise_anno.addNode(node)
            if overlap_new > overlap:
                overlap = overlap_new
                abs_ol = abs_ol_new
                overlap_cs = overlap_cs_new
                overlap_area = overlap_area_new
                overlap_coords = overlap_coords_new
        if vc_tree is not None:
            near_vc_ixs = vc_tree.query_ball_point(csite, max_vc_dist)
            near_vc_ids = list(set([vc_ids[ix] for sublist in
                                    near_vc_ixs.tolist() for ix in sublist]))
        else:
            near_vc_ids = []
        for vc_id in near_vc_ids:
            vc_ix = vc_id_to_ix[vc_id]
            node = copy.copy(vc_nodes[vc_ix])
            dist, nearest_ix = csite_tree.query(node.getCoordinate(), 1)
            nearest_id = curr_csite_ids[nearest_ix]
            old_comment = node.getComment()
            node.setPureComment(csite_name + id2skel(nearest_id) + '_' +
                                old_comment)
            contact_site_anno.addNode(node)
            pairwise_anno.addNode(node)
        # get the center node (representative cs coordinate)
        cs_center = np.sum(csite, axis=0) / float(len(csite))
        cs_center_ix = csite_tree.query(cs_center)[1]
        cs_center = csite[cs_center_ix]
        node = SkeletonNode().from_scratch(
            contact_site_anno, cs_center[0] / scaling[0],
            cs_center[1] / scaling[1], cs_center[2] / scaling[2])
        comment = csite_name + 'area%0.2f_dist%0.4f_center' % (mean_cs_area,
                                                               cs_dist)
        node.data['adj_skel1'] = a[0].filename
        node.data['adj_skel2'] = b[0].filename
        if len(near_vc_ids) > 0:
            vc_bool = True
            comment += '_vc'
            pairwise_anno.setComment(annotation_name + '_syn_candidate')
            contact_site_name += '_vc'
        if len(near_sj_ids) > 0:
            sj_bool = True
            comment += '_sj_relol%0.3f_absol%d_csrelol%0.3f_areaol%0.3f' % \
                       (overlap, abs_ol, overlap_cs, overlap_area)
            contact_site_name += '_sj'
            np.save(dest_path + '/overlap_vx/' + contact_site_name +
                    'ol_vx.npy', overlap_coords)
        node.data['syn_feat'] = np.array([cs_dist, mean_cs_area,
                                          overlap_area, overlap, abs_ol,
                                          overlap_cs])
        node.data['cs_dist'] = cs_dist
        node.data['mean_cs_area'] = mean_cs_area
        node.data['overlap_area'] = overlap_area
        node.data['overlap'] = overlap
        node.data['abs_ol'] = abs_ol
        node.data['overlap_cs'] = overlap_cs
        node.data['cs_name'] = contact_site_name
        node.setPureComment(comment)
        contact_site_anno.addNode(node)
        pairwise_anno.addNode(node)
        # get the closest skeleton nodes
        dist, a_nearest_sn_ixs = a_skel_node_tree.query(cs_center, 2)
        a_source_node = a_skel_node_list[a_nearest_sn_ixs[0]]
        a_nn = max_nodes_in_path(a[0], a_source_node, 100)
        # get the nearest node to the source node of skeleton a and the
        # average radius
        a_source_node_nn = a_skel_node_list[a_nearest_sn_ixs[1]]
        mean_radius = np.mean([a_source_node.data['radius'],
                               a_source_node_nn.data['radius']])
        for j, node in enumerate(a_nn):
            if j == 0:
                comment = csite_name + a[0].filename + '_skelnode' + \
                          '_area %0.2f' % csa_area
                node.data['head_diameter'] = mean_radius * 2
                node.data['skel_id'] = int(a[0].filename)
            else:
                comment = csite_name + a[0].filename + '_skelnode%d' % j
            curr_node = copy.copy(node)
            curr_node.appendComment(comment)
            contact_site_anno.addNode(curr_node)
            pairwise_anno.addNode(curr_node)
        for j, node in enumerate(a_nn):
            try:
                target_node = list(a[0].getNodeEdges(node))[0]
                contact_site_anno.addEdge(node, target_node)
                pairwise_anno.addEdge(node, target_node)
            except (KeyError, IndexError):
                pass
        dist, b_nearest_sn_ixs = b_skel_node_tree.query(cs_center, 2)
        b_source_node = b_skel_node_list[b_nearest_sn_ixs[0]]
        b_nn = max_nodes_in_path(b[0], b_source_node, 100)
        # get the nearest node to the source node of skeleton b and the
        # average radius
        b_source_node_nn = b_skel_node_list[b_nearest_sn_ixs[1]]
        mean_radius = np.mean([b_source_node.data['radius'],
                               b_source_node_nn.data['radius']])
        for j, node in enumerate(b_nn):
            if j == 0:
                comment = csite_name + b[0].filename + '_skelnode' + \
                          '_area %0.2f' % csb_area
                node.data['head_diameter'] = mean_radius * 2
                node.data['skel_id'] = int(b[0].filename)
            else:
                comment = csite_name + b[0].filename + '_skelnode%d' % j
            curr_node = copy.copy(node)
            curr_node.appendComment(comment)
            contact_site_anno.addNode(curr_node)
            pairwise_anno.addNode(curr_node)
        for j, node in enumerate(b_nn):
            try:
                target_node = list(b[0].getNodeEdges(node))[0]
                contact_site_anno.addEdge(node, target_node)
                pairwise_anno.addEdge(node, target_node)
            except (KeyError, IndexError):
                pass
        contact_site_anno.setComment(contact_site_name)
        dummy_skel = Skeleton()
        dummy_skel.add_annotation(contact_site_anno)
        cs_destpath = dest_path
        if vc_bool and sj_bool:
            cs_destpath += 'cs_vc_sj/'
        elif vc_bool and not sj_bool:
            cs_destpath += 'cs_vc/'
        elif not vc_bool and sj_bool:
            cs_destpath += 'cs_sj/'
        elif not vc_bool and not sj_bool:
            cs_destpath += 'cs/'
        dummy_skel.toNml(cs_destpath + contact_site_name + '.nml')
    if len(pairwise_anno.getNodes()) == 0:
        # print "Did not find any node in the annotation object."
        return None
    pairwise_anno.appendComment('%dcs' % len(csites))
    dummy_skel = Skeleton()
    dummy_skel.add_annotation(pairwise_anno)
    dummy_skel.toNml(dest_path + 'pairwise/' + annotation_name + '.nml')
    del dummy_skel
    gc.collect()
    return 0
def max_nodes_in_path(anno, source_node, max_number):
    """Finds a specified number of nodes along the skeleton from a source
    node (BFS).

    Parameters
    ----------
    anno : SkeletonAnnotation
        tracing on which to search
    source_node : SkeletonNode
        starting node
    max_number : int
        maximum number of nodes

    Returns
    -------
    list of SkeletonNode
        tracing nodes up to a certain distance from the source node
    """
    skel_graph = su.annotation_to_nx_graph(anno)
    reachable_nodes = [source_node]
    for edge in nx.bfs_edges(skel_graph, source_node):
        next_node = edge[1]
        reachable_nodes.append(next_node)
        if len(reachable_nodes) >= max_number:
            break
    return reachable_nodes
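
# Editorial sketch (not part of the original module): the BFS truncation idea
# behind max_nodes_in_path, shown on a plain networkx graph.
def _demo_bfs_truncation():
    import networkx as nx
    g = nx.path_graph(10)           # 0 - 1 - 2 - ... - 9
    reachable = [0]
    for _, nxt in nx.bfs_edges(g, 0):
        reachable.append(nxt)
        if len(reachable) >= 4:
            break
    assert reachable == [0, 1, 2, 3]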
def feature_valid_syns(cs_dir, only_sj=True, only_syn=True,
                       all_contacts=False):
    """Returns the features of valid synapses predicted by the synapse rfc

    Parameters
    ----------
    cs_dir : str
        path to the computed contact sites.
    only_sj : bool
        return the features of all contact sites with a mapped sj.
    only_syn : bool
        return features only if a synapse was predicted
    all_contacts : bool
        use all contact sites for the feature extraction

    Returns
    -------
    np.array (n x f), np.array (n x 1), np.array (n x 1)
        features, array of contact site IDs, boolean array of the synapse
        prediction
    """
    clf_path = cs_dir + '/../models/rf_synapses/rfc_syn.pkl'
    cs_fpaths = []
    if only_sj:
        search_folder = ['cs_sj/', 'cs_vc_sj/']
    elif all_contacts:
        search_folder = ['cs_sj/', 'cs_vc_sj/', 'cs/', 'cs_vc/']
    else:
        search_folder = ['cs/', 'cs_vc/']
    sample_list_len = []
    for k, ending in enumerate(search_folder):
        curr_dir = cs_dir + ending
        curr_fpaths = get_filepaths_from_dir(curr_dir, ending='nml')
        cs_fpaths += curr_fpaths
        sample_list_len.append(len(curr_fpaths))
    if len(cs_fpaths) == 0:
        return np.zeros(0, ), np.zeros(0, ), np.zeros(0, ).astype(np.bool)
    nb_cpus = cpu_count()
    pool = Pool(processes=nb_cpus)
    m = Manager()
    q = m.Queue()
    params = [(sample, q) for sample in cs_fpaths]
    result = pool.map_async(readout_cs_info, params)
    res = result.get()
    pool.close()
    pool.join()
    res = arr(res)
    # keep only entries whose feature readout succeeded (feat is an array)
    non_instances = arr([isinstance(el, np.ndarray) for el in res[:, 0]])
    cs_infos = res[non_instances]
    features = arr([el.astype(np.float) for el in cs_infos[:, 0]],
                   dtype=np.float)
    if not only_sj or not only_syn or all_contacts:
        syn_pred = np.ones((len(features), ))
    else:
        rfc_syn = joblib.load(clf_path)
        syn_pred = rfc_syn.predict(features)
    axoness_info = cs_infos[:, 1]
    return features, axoness_info, syn_pred.astype(np.bool)
def readout_cs_info(args):
    """Helper function of feature_valid_syns

    Parameters
    ----------
    args : tuple
        path to file and queue

    Returns
    -------
    np.array, str
        synapse features, contact site ID
    """
    cspath, q = args
    feat = None
    if q is not None:
        q.put(1)
    cs = read_pair_cs(cspath)
    for node in cs.getNodes():
        if 'center' in node.getComment():
            feat = parse_synfeature_from_node(node)
            break
    return feat, cs.getComment()
def calc_syn_dict(features, axoness_info, get_all=False):
    """Creates a dictionary of synapses. Keys are the ids of the pre-synaptic
    cells; values are dictionaries of the corresponding synapses with the
    post-synaptic cell ids.

    Parameters
    ----------
    features : np.array
        synapse features
    axoness_info : np.array
        strings containing the axoness information of the cells
    get_all : bool
        collect all contact sites

    Returns
    -------
    np.array, np.array, dict, np.array, np.array, dict
        synapse features, axoness information, connectivity,
        post-synaptic cell ids, synapse predictions, axoness
    """
    total_size = float(len(axoness_info))
    if total_size == 0:
        print "No synapse dict to create."
        return np.zeros(0, ), np.zeros(0, ), {}, np.zeros(0, ), \
               np.zeros(0, ), {}
    ax_ax_cnt = 0
    den_den_cnt = 0
    all_post_ids = []
    pre_dict = {}
    val_syn_ixs = []
    valid_syn_array = np.ones_like(features)
    axoness_dict = {}
    for k, ax_info in enumerate(axoness_info):
        stdout.write("\r%0.2f" % (k / total_size))
        stdout.flush()
        cell1, cell2 = re.findall('(\d+)axoness(\-?\d+)', ax_info)
        cs_nb = re.findall('cs(\d+)', ax_info)[0]
        cell_ids = arr([cell1[0], cell2[0]], dtype=np.int)
        cell_axoness = arr([cell1[1], cell2[1]], dtype=np.int)
        axoness_entry = {str(cell1[0]): cell1[1], str(cell2[0]): cell2[1]}
        axoness_dict[cs_nb + '_' + cell1[0] + '_' + cell2[0]] = axoness_entry
        if cell_axoness[0] == cell_axoness[1]:
            if cell_axoness[0] == 1:
                ax_ax_cnt += 1
            else:
                den_den_cnt += 1
            valid_syn_array[k] = 0
            if not get_all:
                continue
        val_syn_ixs.append(k)
        pre_ix = np.argmax(cell_axoness)
        pre_id = cell_ids[pre_ix]
        if pre_ix == 0:
            post_ix = 1
        else:
            post_ix = 0
        post_id = cell_ids[post_ix]
        all_post_ids += [post_id]
        syn_dict = {}
        syn_dict['post_id'] = post_id
        syn_dict['post_axoness'] = cell_axoness[post_ix]
        syn_dict['cs_area'] = features[k, 1]
        syn_dict['sj_size_abs'] = features[k, 2]
        syn_dict['sj_size_rel'] = features[k, 3]
        if pre_id in pre_dict.keys():
            syns = pre_dict[pre_id]
            if post_id in syns.keys():
                syns[post_id]['cs_area'] += features[k, 1]
                syns[post_id]['sj_size_abs'] += features[k, 2]
            else:
                syns[post_id] = syn_dict
        else:
            syns = {}
            syns[post_id] = syn_dict
            pre_dict[pre_id] = syns
    return features[val_syn_ixs], axoness_info[val_syn_ixs], pre_dict, \
           all_post_ids, valid_syn_array, axoness_dict
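
# Editorial sketch (not part of the original module): the pre/post assignment
# rule used in calc_syn_dict; the partner with the higher axoness value
# (axon = 1) is taken as pre-synaptic. The ids below are made up.
def _demo_pre_post_assignment():
    import numpy as np
    cell_ids = np.array([401, 299])
    cell_axoness = np.array([0, 1])      # 401: dendrite, 299: axon
    pre_ix = np.argmax(cell_axoness)
    post_ix = 1 - pre_ix
    assert (cell_ids[pre_ix], cell_ids[post_ix]) == (299, 401)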
def cs_btw_annos(anno_a, anno_b, max_hull_dist, concom_dist):
    """Computes the contact sites between two annotation objects and returns
    the hull points of both skeletons near each contact site.

    Parameters
    ----------
    anno_a : SkeletonAnnotation
        annotation object A
    anno_b : SkeletonAnnotation
        annotation object B
    max_hull_dist : int
        maximum distance between the skeletons in nm
    concom_dist : int
        maximum distance of connected components (nm)

    Returns
    -------
    list
        list of hull coordinates for each contact site
    """
    hull_a = anno_a.hull_coords
    hull_b = anno_b.hull_coords
    if len(hull_a) == 0 or len(hull_b) == 0:
        # print "One skeleton hull is empty! Skipping pair."
        return [], []
    tree_a = spatial.cKDTree(hull_a)
    tree_b = spatial.cKDTree(hull_b)
    contact_ids = tree_a.query_ball_tree(tree_b, max_hull_dist)
    num_neighbours = arr([len(sublist) for sublist in contact_ids])
    contact_coords_a = hull_a[num_neighbours > 0]
    contact_ids_b = set([id for sublist in contact_ids for id in sublist])
    contact_coords_b = hull_b[list(contact_ids_b)]
    if contact_coords_a.ndim == 1:
        contact_coords_a = contact_coords_a[None, :]
    if contact_coords_b.ndim == 1:
        # fixed: the original reshaped contact_coords_a here
        contact_coords_b = contact_coords_b[None, :]
    contact_coords = np.concatenate((contact_coords_a, contact_coords_b),
                                    axis=0)
    if contact_coords.shape[0] >= 0.95 * (len(hull_a) + len(hull_b)):
        print "Found too many contact_coords (proportion of total hull " \
              "voxels: %0.3f); assuming a comparison between similar " \
              "skeletons %s and %s." \
              % (contact_coords.shape[0] / float(len(hull_a) + len(hull_b)),
                 anno_a.filename, anno_b.filename)
        return [], []
    if contact_coords.shape[0] == 0:
        return [], []
    pdists = spatial.distance.pdist(contact_coords)
    pdists[pdists > concom_dist] = 0
    pdists = sparse.csr_matrix(spatial.distance.squareform(pdists))
    nb_cc, labels = sparse.csgraph.connected_components(pdists)
    cs_list = []
    for label in set(labels):
        curr_label_ixs = labels == label
        cs_list.append(contact_coords[curr_label_ixs])
    # extract annotation ids
    tree_a_b = spatial.cKDTree(np.concatenate((hull_a, hull_b), axis=0))
    contact_site_coord_ids = []
    min_id_b = len(hull_a)
    for cs in cs_list:
        # map each contact site coordinate to its source annotation
        ids_temp = tree_a_b.query(cs, 1)[1]
        in_b = arr(ids_temp >= min_id_b, dtype=np.bool)
        contact_site_coord_ids.append(in_b)
    return cs_list, contact_site_coord_ids
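
# Editorial sketch (not part of the original module): the distance-based
# connected-components step used in cs_btw_annos, on two separated clusters.
# Pairwise distances above the cutoff are zeroed, and the remaining entries
# form the adjacency of a graph whose components are the contact sites.
def _demo_concom_clustering():
    import numpy as np
    from scipy import spatial
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components
    pts = np.array([[0., 0., 0.], [10., 0., 0.],      # cluster 1
                    [500., 0., 0.], [510., 0., 0.]])  # cluster 2
    pd = spatial.distance.pdist(pts)
    pd[pd > 50] = 0                                   # cut long edges
    graph = csr_matrix(spatial.distance.squareform(pd))
    nb_cc, labels = connected_components(graph)
    assert nb_cc == 2 and labels[0] == labels[1] and labels[2] == labels[3]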
def translate_dense_tracings():
    fpaths = get_filepaths_from_dir(
        '/lustre/pschuber/dense_vol_tracings/source/')
    for p in fpaths:
        s = load_ordered_mapped_skeleton(p)[0]
        for n in s.getNodes():
            n.setCoordinate(n.getCoordinate() - np.array([3540, 4843, 2418]))
        file_name = os.path.basename(p)
        dummy_skel = Skeleton()
        dummy_skel.add_annotation(s)
        dummy_skel.to_kzip("/lustre/pschuber/SyConnDenseCube/tracings/" +
                           file_name)