Source code for syconn.conmatrix

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max-Planck-Institute for Medical Research, Heidelberg, Germany
# Authors: Sven Dorkenwald, Philipp Schubert, Joergen Kornfeld

import matplotlib
matplotlib.use('Agg')
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import re
import os
from matplotlib import gridspec
from matplotlib import pyplot as pp
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numpy import array as arr

from syconn.contactsites import conn_dict_wrapper
from syconn.processing.cell_types import load_celltype_probas, \
    get_id_dict_from_skel_ids
from syconn.processing.learning_rfc import cell_classification
from syconn.utils.datahandler import load_pkl2obj


def type_sorted_wiring(wd, confidence_lvl=0.3, binary=False, max_syn_size=0.4,
                       syn_only=True, big_entries=True):
    """Calculate wiring of consensus skeletons sorted by predicted cell type
    classification and axoness prediction and plot it to ``wd/figures/``.

    Parameters
    ----------
    wd : str
        working directory
    confidence_lvl : float
        minimum probability of cell type prediction to keep cell
    binary : bool
        if True existence of synapse is weighted by 1, else by its
        (cumulated) area
    max_syn_size : float
        maximum cumulated synapse size shown in plot
    syn_only : bool
        take only contact sites with synapse classification result of 1 into
        account
    big_entries : bool
        artificially increase pixel size from 1 to 3 for better visualization

    Returns
    -------
    np.array
        cumulated wiring of shape (4, 4, 3), indexed by (cell type of the
        second skeleton id / dendrite, cell type of the first skeleton id /
        axon, channel). Channel 1 collects pairs whose
        ``cell_classification`` result is 0, channel 2 all others.
    """
    if not os.path.exists(wd + "/figures/"):
        os.makedirs(wd + "/figures/")
    supp = ""
    skeleton_ids, cell_type_probas = load_celltype_probas(wd)
    cell_type_pred_dict = load_pkl2obj(wd + '/neurons/celltype_pred_dict.pkl')
    # remove all skeletons under confidence level
    bool_arr = np.zeros(len(skeleton_ids), dtype=bool)
    for k, probas in enumerate(cell_type_probas):
        if np.max(probas) > confidence_lvl:
            bool_arr[k] = True
    # a set gives O(1) membership tests in the per-pair loops below
    skeleton_ids = set(skeleton_ids[bool_arr].tolist())
    if not os.path.isfile(wd + '/contactsites/connectivity_dict.pkl'):
        conn_dict_wrapper(wd, all=False)
        conn_dict_wrapper(wd, all=True)
    if syn_only:
        syn_props = load_pkl2obj(wd + '/contactsites/connectivity_dict.pkl')
        area_key = 'sizes_area'
        total_area_key = 'total_size_area'
        syn_pred_key = 'syn_types_pred_maj'
    else:
        syn_props = load_pkl2obj(wd + '/contactsites/connectivity_dict_all.pkl')
        area_key = 'cs_area'
        total_area_key = 'total_cs_area'
        syn_pred_key = 'syn_types_pred'

    def parse_pair(pair_name):
        # Pair names contain '<id1>_<id2>'; id1 is used as the axon
        # (pre-synaptic) side, id2 as the dendrite (post-synaptic) side.
        skel_id1, skel_id2 = re.findall(r'(\d+)_(\d+)', pair_name)[0]
        return int(skel_id1), int(skel_id2)

    def sort_by_type(ids):
        # Stable-sort ids by predicted cell type; returns (ids, predictions).
        ids = np.array(list(ids))
        pred = np.array([cell_type_pred_dict[skel_id] for skel_id in ids])
        order = np.argsort(pred, kind='mergesort')
        return ids[order], pred[order]

    def channel_entry(value, synapse_type):
        # Type-0 synapses go to channel 1, all other types to channel 2.
        entry = np.zeros(3)
        entry[1 if synapse_type == 0 else 2] = value
        return entry

    def normalize_channels(matrix, max_vals):
        # Scale channels 1 and 2 in place; all-zero channels are skipped to
        # avoid 0/0 -> NaN.
        for channel, channel_max in zip((1, 2), max_vals):
            if channel_max > 0:
                matrix[:, :, channel] /= channel_max

    # collect the cells participating in at least one (kept) pair
    dendrite_ids = set()
    axon_ids = set()
    axon_axon_ids = set()
    for pair_name, pair in syn_props.items():
        skel_id1, skel_id2 = parse_pair(pair_name)
        if skel_id1 not in skeleton_ids or skel_id2 not in skeleton_ids:
            continue
        axon_ids.add(skel_id1)
        dendrite_ids.add(skel_id2)
        if np.any(np.array(pair['partner_axoness']) == 1):
            axon_axon_ids.add(skel_id1)
            axon_axon_ids.add(skel_id2)

    # sort dendrites, axons using their type prediction; the resulting order
    # defines the matrix positions via get_id_dict_from_skel_ids
    dendrite_ids, dendrite_pred = sort_by_type(dendrite_ids)
    axon_ids, axon_pred = sort_by_type(axon_ids)
    _, ax_ax_pred = sort_by_type(axon_axon_ids)
    den_id_dict, _ = get_id_dict_from_skel_ids(dendrite_ids)
    ax_id_dict, _ = get_id_dict_from_skel_ids(axon_ids)

    # build reduced matrices
    wiring = np.zeros((len(dendrite_ids), len(axon_ids), 3), dtype=float)
    wiring_multiple_syns = np.zeros_like(wiring)
    wiring_axoness = np.zeros_like(wiring)
    cum_wiring = np.zeros((4, 4, 3))
    cum_wiring_axon = np.zeros((4, 4, 3))
    for pair_name, pair in syn_props.items():
        if pair[total_area_key] == 0:
            continue
        skel_id1, skel_id2 = parse_pair(pair_name)
        if skel_id1 not in skeleton_ids or skel_id2 not in skeleton_ids:
            continue
        synapse_type = cell_classification(pair[syn_pred_key])
        dendrite_pos = den_id_dict[skel_id2]
        axon_pos = ax_id_dict[skel_id1]
        cum_den_pos = cell_type_pred_dict[skel_id2]
        cum_ax_pos = cell_type_pred_dict[skel_id1]
        indiv_syn_axoness = np.array(pair['partner_axoness']) == 1
        if np.any(indiv_syn_axoness):
            # Split off the synapses whose partner axoness is 1; the
            # remaining sizes are treated as regular wiring below. Note that
            # this updates the loaded ``pair`` dict in place.
            indiv_syn_sizes = np.array(pair[area_key])
            pair[area_key] = indiv_syn_sizes[~indiv_syn_axoness]
            pair[total_area_key] = np.sum(pair[area_key])
            y_axon_axon = np.sum(indiv_syn_sizes[indiv_syn_axoness])
            y_axon_axon_display = np.min((y_axon_axon, max_syn_size))
            if binary:
                y_axon_axon = 1.
                y_axon_axon_display = 1.
            cum_wiring_axon[cum_den_pos, cum_ax_pos] += \
                channel_entry(y_axon_axon, synapse_type)
            wiring_axoness[dendrite_pos, axon_pos] = \
                channel_entry(y_axon_axon_display, synapse_type)
        if pair[total_area_key] == 0:
            continue
        y = pair[total_area_key]
        y_display = np.min((y, max_syn_size))
        if len(pair[area_key]) > 1:
            wiring_multiple_syns[dendrite_pos, axon_pos] = \
                channel_entry(y_display, synapse_type)
        if binary:
            y = 1.
            y_display = 1.
        cum_wiring[cum_den_pos, cum_ax_pos] += channel_entry(y, synapse_type)
        wiring[dendrite_pos, axon_pos] = channel_entry(y_display, synapse_type)

    max_val = [np.max(wiring[..., 1]), np.max(wiring[..., 2])]
    max_val_axon_axon = [np.max(wiring_axoness[..., 1]),
                         np.max(wiring_axoness[..., 2])]
    ax_borders = class_ranges(axon_pred)[1:-1]
    den_borders = class_ranges(dendrite_pred)[1:-1]
    maj_vote = get_cell_majority_synsign(cum_wiring)
    # normalize each channel
    if not binary:
        normalize_channels(wiring, max_val)
        normalize_channels(wiring_axoness, max_val_axon_axon)
        normalize_channels(wiring_multiple_syns, max_val)
    if not syn_only:
        supp += '_CS'
        plot_wiring_cs(wiring, den_borders, ax_borders, confidence_lvl,
                       binary, wd, add_fname=supp)
        plot_wiring_cs(wiring_axoness, den_borders, ax_borders,
                       confidence_lvl, binary, wd,
                       add_fname=supp + '_axon_axon')
        plot_wiring_cum_cs(cum_wiring, class_ranges(dendrite_pred),
                           class_ranges(axon_pred), confidence_lvl, binary,
                           wd, add_fname=supp)
        plot_wiring_cum_cs(cum_wiring_axon, class_ranges(ax_ax_pred),
                           class_ranges(ax_ax_pred), confidence_lvl, binary,
                           wd, add_fname=supp + '_axon_axon')
        plot_wiring_cs(wiring_multiple_syns, den_borders, ax_borders,
                       confidence_lvl, binary, wd,
                       add_fname=supp + '_multiple_syns')
    else:
        plot_wiring(wiring, den_borders, ax_borders, max_val, confidence_lvl,
                    binary, wd, add_fname=supp, big_entries=big_entries,
                    maj_vote=maj_vote)
        plot_wiring_cum(cum_wiring, class_ranges(dendrite_pred),
                        class_ranges(axon_pred), confidence_lvl, max_val,
                        binary, wd, add_fname=supp, maj_vote=maj_vote)
        plot_wiring(wiring_multiple_syns, den_borders, ax_borders, max_val,
                    confidence_lvl, binary, wd,
                    add_fname=supp + '_multiple_syns',
                    big_entries=big_entries, maj_vote=maj_vote)
    return cum_wiring
def get_cell_majority_synsign(avg_wiring):
    """Calculate the majority synaptic sign of every pre-synaptic cell type.

    Parameters
    ----------
    avg_wiring : np.array
        cumulated wiring of shape (n_post_types, n_pre_types, >=3); channels
        1 and 2 hold the contributions of the two synapse types

    Returns
    -------
    np.array of float
        one entry per pre-synaptic type (i.e. per column of ``avg_wiring``):
        1.0 if the summed channel 2 exceeds the summed channel 1, else 0.0
    """
    # collapse the post-synaptic axis -> (n_pre_types, channels)
    cum_cols = np.sum(avg_wiring, axis=0)
    return (cum_cols[:, 2] > cum_cols[:, 1]).astype(float)
def get_cum_pos(den_ranges, ax_ranges, den_pos, ax_pos):
    """Find the sector of the average matrix a synapse belongs to.

    Each ranges sequence defines half-open intervals
    ``[ranges[i - 1], ranges[i])``; sector ``i - 1`` is returned for the
    interval containing the position, and 0 if none contains it.

    Returns
    -------
    tuple of int
        (dendrite sector, axon sector)
    """
    def sector_of(ranges, pos):
        # Scanning backwards and returning the first hit is equivalent to a
        # forward scan where the last matching interval wins.
        for idx in reversed(range(1, len(ranges))):
            if ranges[idx - 1] <= pos < ranges[idx]:
                return idx - 1
        return 0

    return sector_of(den_ranges, den_pos), sector_of(ax_ranges, ax_pos)
def plot_wiring(wiring, den_borders, ax_borders, max_val, confidence_lvl,
                binary, wd, big_entries=True, add_fname='', maj_vote=()):
    """Plot type sorted connectivity matrix and save to figures folder in
    working directory

    Parameters
    ----------
    wiring : np.array
        array of shape (#dendrites, #axons, channels); channels 1 and 2 hold
        the wiring of the two synapse types
    den_borders : iterable of int
        cell type borders on post synaptic side (one entry per border)
    ax_borders : iterable of int
        cell type borders on pre synaptic side (one entry per border)
    max_val : float
        maximum cumulated contact area shown in plot (not used here)
    confidence_lvl : float
        minimum probability of cell type prediction to keep cell; used in the
        output file name
    binary : bool
        if True existence of synapse is weighted by 1, else by its area;
        selects the output file name suffix
    wd : str
        working directory; image is written to ``wd/figures/``
    big_entries : bool
        artificially increase pixel size from 1 to 3 for better visualization
    add_fname : str
        supplement of image file
    maj_vote : sequence
        majority synapse sign per pre-synaptic cell type, indexed by sector
    """
    n_channels = wiring.shape[2]
    # Insert an empty separator row/column at every class border; offsetting
    # by k accounts for the separators inserted in earlier iterations.
    for k, b in enumerate(den_borders):
        b += k
        wiring = np.concatenate((wiring[:b, :],
                                 np.zeros((1, wiring.shape[1], n_channels)),
                                 wiring[b:, :]), axis=0)
    for k, b in enumerate(ax_borders):
        b += k
        wiring = np.concatenate((wiring[:, :b],
                                 np.zeros((wiring.shape[0], 1, n_channels)),
                                 wiring[:, b:]), axis=1)

    def padded_ranges(borders, size):
        # Sector boundaries in the padded matrix: [0, b0, b1, ..., size]
        # shifted by the number of separators inserted before each boundary.
        return (np.concatenate(([0], borders, [size])) +
                np.arange(len(borders) + 2))

    intensity_plot = np.zeros((wiring.shape[0], wiring.shape[1]))
    # rows are dendrites -> dendrite borders, columns are axons -> axon
    # borders (rows were previously mapped with the axon borders)
    den_ranges = padded_ranges(den_borders, wiring.shape[0])
    ax_ranges = padded_ranges(ax_borders, wiring.shape[1])
    # negate into a new array so the caller's matrix is never modified
    wiring = -wiring
    for i in range(wiring.shape[0]):
        for j in range(wiring.shape[1]):
            # channel 1 takes precedence over channel 2
            value = wiring[i, j, 1] if wiring[i, j, 1] != 0 else wiring[i, j, 2]
            if value == 0:
                continue
            den_pos, ax_pos = get_cum_pos(den_ranges, ax_ranges, i, j)
            syn_sign = maj_vote[ax_pos]
            intensity_plot[i, j] = (-1)**syn_sign * value
            if not big_entries:
                continue
            # NOTE(review): the enlarged 3x3 entries use the opposite sign
            # ((-1)**(syn_sign + 1)) of the single-pixel entry above and also
            # overwrite the centre pixel; confirm which convention is intended.
            for add_i in (-1, 0, 1):
                for add_j in (-1, 0, 1):
                    ni, nj = i + add_i, j + add_j
                    if not (0 <= ni < wiring.shape[0] and
                            0 <= nj < wiring.shape[1]):
                        continue
                    den_pos_n, ax_pos_n = get_cum_pos(den_ranges, ax_ranges,
                                                      ni, nj)
                    # never paint across a cell type separator
                    if den_pos_n != den_pos or ax_pos_n != ax_pos:
                        continue
                    intensity_plot[ni, nj] = (-1)**(syn_sign + 1) * value
    if not big_entries:
        np.save(wd + '/figures/connectivity_matrix.npy', intensity_plot)
    tmp_max_val = np.max(np.abs(intensity_plot))
    matplotlib.rcParams.update({'font.size': 14})
    fig = pp.figure()
    gs = gridspec.GridSpec(1, 2, width_ratios=[20, 1])
    gs.update(wspace=0.05, hspace=0.08)
    ax = pp.subplot(gs[0, 0], frameon=False)
    cax = ax.matshow(-intensity_plot.transpose(1, 0), cmap=diverge_map(),
                     extent=[0, wiring.shape[0], wiring.shape[1], 0],
                     interpolation="none", vmin=-tmp_max_val,
                     vmax=tmp_max_val)
    ax.set_xlabel('Post', fontsize=18)
    ax.set_ylabel('Pre', fontsize=18)
    ax.set_xlim(0, wiring.shape[0])
    ax.set_ylim(0, wiring.shape[1])
    plt.grid(False)
    plt.axis('off')
    for k, b in enumerate(den_borders):
        plt.axvline(b + k + 0.5, color='k', lw=0.5, snap=True,
                    antialiased=True)
    for k, b in enumerate(ax_borders):
        plt.axhline(b + k + 0.5, color='k', lw=0.5, snap=True,
                    antialiased=True)
    cbar_ax = pp.subplot(gs[0, 1])
    cbar_ax.yaxis.set_ticks_position('none')
    fig.colorbar(cax, cax=cbar_ax, ticks=[])
    suffix = '_binary' if binary else ''
    fig.savefig(wd + '/figures/type_wiring%s_conflvl%d_be%s%s.png' %
                (add_fname, int(confidence_lvl * 10), str(big_entries),
                 suffix), dpi=600)
    plt.close(fig)
def plot_wiring_cum(wiring, den_borders, ax_borders, confidence_lvl, max_val,
                    binary, wd, add_fname='', maj_vote=()):
    """Plot wiring diagram on celltype-to-celltype level, e.g. connectivity
    between EA and MSN

    Parameters
    ----------
    wiring : np.array
        cumulated wiring of shape (4, 4, 3) indexed by (post-synaptic type,
        pre-synaptic type, channel)
    den_borders : np.array
        class ranges (5 entries, see class_ranges) of post-synaptic cells
    ax_borders : np.array
        class ranges (5 entries, see class_ranges) of pre-synaptic cells
    confidence_lvl : float
        used in the output file name
    max_val : sequence
        not used here
    binary : bool
        selects the output file name suffix
    wd : str
        working directory; outputs are written to ``wd/figures/``
    add_fname : str
        supplement of image file name
    maj_vote : sequence
        majority synapse sign per pre-synaptic cell type (4 entries)
    """
    # plot intensities, averaged per sector
    nb_cells_per_sector = np.zeros((4, 4))
    intensity_plot = np.zeros((4, 4))
    for i in range(4):
        for j in range(4):
            diff_den = den_borders[i + 1] - den_borders[i]
            diff_ax = ax_borders[j + 1] - ax_borders[j]
            nb_cells_per_sector[i, j] = diff_den * diff_ax
            if nb_cells_per_sector[i, j] != 0:
                # average contribution per (post, pre) cell pair of sector
                sector_intensity = (np.sum(wiring[i, j]) /
                                    nb_cells_per_sector[i, j])
            else:
                sector_intensity = 0
            syn_sign = maj_vote[j]
            if wiring[i, j, 1] > wiring[i, j, 2]:
                intensity_plot[i, j] = (-1)**(syn_sign + 1) * sector_intensity
            else:
                # sectors not dominated by channel 1 are capped at 0.1
                intensity_plot[i, j] = ((-1)**(syn_sign + 1) *
                                        np.min((sector_intensity, 0.1)))
    np.save(wd + '/figures/cumulated_connectivity_matrix.npy', intensity_plot)
    ind = np.arange(4)
    intensity_plot = intensity_plot.transpose(1, 0)[::-1]
    channel_sum = np.sum(wiring.transpose(1, 0, 2)[::-1], axis=2)
    row_sum = np.sum(channel_sum, axis=1)
    col_sum = np.sum(channel_sum, axis=0)
    # scale negative and positive entries separately to [-1, 0) and (0, 1]
    max_val_tmp = np.array([np.max(intensity_plot),
                            np.abs(np.min(intensity_plot))])
    intensity_plot[intensity_plot < 0] /= max_val_tmp[1]
    intensity_plot[intensity_plot > 0] /= max_val_tmp[0]
    matplotlib.rcParams.update({'font.size': 14})
    fig = pp.figure()
    gs = gridspec.GridSpec(2, 3, width_ratios=[10, 1, 0.5],
                           height_ratios=[1, 10])
    gs.update(wspace=0.05, hspace=0.08)
    ax = pp.subplot(gs[1, 0], frameon=False)
    tmp_max_val = np.max(np.abs(intensity_plot))
    ax.matshow(intensity_plot, cmap=diverge_map(), extent=[0, 4, 0, 4],
               vmin=-tmp_max_val, vmax=tmp_max_val)
    ax.grid(color='k', linestyle='-')
    cbar_ax = pp.subplot(gs[1, 2])
    cbar_ax.yaxis.set_ticks_position('none')
    # row-sum bar plot to the right of the matrix
    axr = pp.subplot(gs[1, 1], sharey=ax, yticks=[], xticks=[],
                     frameon=False,
                     xlim=(np.min(row_sum), np.max(row_sum)), ylim=(0, 4))
    axr.tick_params(axis='x', which='major', right=False, top=False,
                    left=False, pad=10, bottom=False, labelsize=12,
                    direction='out', length=4, width=1)
    for spine in ('top', 'right', 'left', 'bottom'):
        axr.spines[spine].set_visible(False)
    axr.get_xaxis().tick_bottom()
    axr.get_yaxis().tick_left()
    axr.barh(ind, row_sum[::-1], 1, color='0.6', linewidth=0)
    # column-sum bar plot above the matrix
    axt = pp.subplot(gs[0, 0], sharex=ax, xticks=[], yticks=[],
                     frameon=False, xlim=(0, 4),
                     ylim=(np.min(col_sum), np.max(col_sum)))
    axt.tick_params(axis='y', which='major', right=False, bottom=False,
                    top=False, left=False, pad=10, labelsize=12,
                    direction='out', length=4, width=1)
    # previously these were (wrongly) applied to axr a second time
    for spine in ('top', 'right', 'left', 'bottom'):
        axt.spines[spine].set_visible(False)
    axt.get_xaxis().tick_bottom()
    axt.get_yaxis().tick_left()
    axt.bar(ind, col_sum, 1, color='0.6', linewidth=0)
    suffix = '_binary' if binary else ''
    fig.savefig(wd + '/figures/type_wiring_cum%s_conflvl%d%s.png' %
                (add_fname, int(confidence_lvl * 10), suffix), dpi=600)
    plt.close(fig)
def type_sorted_wiring_cs(wd, confidence_lvl=0.3, binary=False,
                          max_syn_size=0.4):
    """Same as type_sorted_wiring but for all contact sites (synapse
    classification 0 and 1)

    Parameters
    ----------
    wd : str
        working directory; figures are written to ``wd/figures/``
    confidence_lvl : float
        minimum probability of cell type prediction to keep cell
    binary : bool
        if True existence of a contact site is weighted by 1, else by its
        total area
    max_syn_size : float
        maximum cumulated contact area shown in the matrix plot
    """
    # the plotting helpers save into wd/figures/, make sure it exists
    if not os.path.exists(wd + "/figures/"):
        os.makedirs(wd + "/figures/")
    skeleton_ids, cell_type_probas = load_celltype_probas(wd)
    cell_type_pred_dict = load_pkl2obj(wd + '/neurons/celltype_pred_dict.pkl')
    # remove all skeletons under confidence level
    bool_arr = np.zeros(len(skeleton_ids), dtype=bool)
    for k, probas in enumerate(cell_type_probas):
        if np.max(probas) > confidence_lvl:
            bool_arr[k] = True
    # a set gives O(1) membership tests in the per-pair loops below
    kept_ids = set(skeleton_ids[bool_arr].tolist())
    # create matrix
    syn_props = load_pkl2obj(wd + '/contactsites/connectivity_dict_all.pkl')
    total_area_key = 'total_cs_area'

    def parse_pair(pair_name):
        # '<id1>_<id2>': id1 is used as axon side, id2 as dendrite side
        skel_id1, skel_id2 = re.findall(r'(\d+)_(\d+)', pair_name)[0]
        return int(skel_id1), int(skel_id2)

    def sort_by_type(ids):
        # stable-sort ids by predicted cell type; returns (ids, predictions)
        ids = np.array(list(ids))
        pred = np.array([cell_type_pred_dict[skel_id] for skel_id in ids])
        order = np.argsort(pred, kind='mergesort')
        return ids[order], pred[order]

    dendrite_ids = set()
    axon_ids = set()
    for pair_name in syn_props:
        skel_id1, skel_id2 = parse_pair(pair_name)
        if skel_id1 not in kept_ids or skel_id2 not in kept_ids:
            continue
        axon_ids.add(skel_id1)
        dendrite_ids.add(skel_id2)
    # sort dendrites, axons using their type prediction; the order defines
    # the matrix positions via get_id_dict_from_skel_ids
    dendrite_ids, dendrite_pred = sort_by_type(dendrite_ids)
    axon_ids, axon_pred = sort_by_type(axon_ids)
    den_id_dict, _ = get_id_dict_from_skel_ids(dendrite_ids)
    ax_id_dict, _ = get_id_dict_from_skel_ids(axon_ids)
    wiring = np.zeros((len(dendrite_ids), len(axon_ids), 1), dtype=float)
    cum_wiring = np.zeros((4, 4))
    for pair_name, pair in syn_props.items():
        if pair[total_area_key] == 0:
            continue
        skel_id1, skel_id2 = parse_pair(pair_name)
        if skel_id1 not in kept_ids or skel_id2 not in kept_ids:
            continue
        dendrite_pos = den_id_dict[skel_id2]
        axon_pos = ax_id_dict[skel_id1]
        cum_den_pos = cell_type_pred_dict[skel_id2]
        cum_ax_pos = cell_type_pred_dict[skel_id1]
        y = pair[total_area_key]
        if binary:
            y = 1.
        wiring[dendrite_pos, axon_pos] = np.min((y, max_syn_size))
        cum_wiring[cum_den_pos, cum_ax_pos] += y
    ax_borders = class_ranges(axon_pred)[1:-1]
    den_borders = class_ranges(dendrite_pred)[1:-1]
    supp = '_CS'
    # wd is a required positional argument of both plot functions
    plot_wiring_cs(wiring, den_borders, ax_borders, confidence_lvl, binary,
                   wd, add_fname=supp)
    plot_wiring_cum_cs(cum_wiring, class_ranges(dendrite_pred),
                       class_ranges(axon_pred), confidence_lvl, binary, wd,
                       add_fname=supp)
def plot_wiring_cs(wiring, den_borders, ax_borders, confidence_lvl, binary,
                   wd, add_fname='_CS'):
    """Same as plot_wiring, but using all contact sites

    Parameters
    ----------
    wiring : np.array
        3D array (dendrites x axons x channels); the maximum over channels is
        displayed in gray scale
    den_borders : iterable of int
        cell type borders along the dendrite (post-synaptic) axis
    ax_borders : iterable of int
        cell type borders along the axon (pre-synaptic) axis
    confidence_lvl : float
        used in the output file name
    binary : bool
        selects colorbar labels and output file name suffix
    wd : str
        working directory; image is written to ``wd/figures/``
    add_fname : str
        supplement of image file name
    """
    fig = plt.figure()
    ax = plt.gca()
    max_val = np.max(wiring)
    # separators must match the channel count (1 for the contact site
    # wiring, 3 for the synapse-type wiring of type_sorted_wiring)
    n_channels = wiring.shape[2]
    # insert a separator (value 1) at every border; offset by k for the
    # separators inserted in earlier iterations
    for k, b in enumerate(den_borders):
        b += k
        wiring = np.concatenate((wiring[:b, :],
                                 np.ones((1, wiring.shape[1], n_channels)),
                                 wiring[b:, :]), axis=0)
    for k, b in enumerate(ax_borders):
        b += k
        wiring = np.concatenate((wiring[:, :b],
                                 np.ones((wiring.shape[0], 1, n_channels)),
                                 wiring[:, b:]), axis=1)
    ax.matshow(np.max(wiring.transpose(1, 0, 2), axis=2), interpolation="none",
               extent=[0, wiring.shape[0], wiring.shape[1], 0], cmap='gray')
    ax.set_xlabel('Post', fontsize=18)
    ax.set_ylabel('Pre', fontsize=18)
    ax.set_xlim(0, wiring.shape[0])
    ax.set_ylim(0, wiring.shape[1])
    plt.grid(False)
    plt.axis('off')
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    # A hidden image on a throw-away figure provides a fixed [0, 1] gray
    # mappable for the colorbar drawn into ``cax`` of ``fig``.
    dummy_fig = plt.figure()
    plt.imshow(np.array([[0, 1]]), cmap='gray')
    plt.gca().set_visible(False)
    cb = plt.colorbar(cax=cax, ticks=[0, 1])
    if not binary:
        cb.ax.set_yticklabels(['0', "%0.3g+" % max_val], rotation=90)
        cb.set_label(u'Area of Synaptic Junctions [µm$^2$]')
    else:
        cb.ax.set_yticklabels(['0', '1'], rotation=90)
        cb.set_label(u'Synaptic Junction')
    # close both figures; previously only the dummy one was closed and the
    # main figure leaked
    plt.close(dummy_fig)
    suffix = '_binary' if binary else ''
    fig.savefig(wd + '/figures/type_wiring%s_conflvl%d%s.png' %
                (add_fname, int(confidence_lvl * 10), suffix), dpi=600)
    plt.close(fig)
def plot_wiring_cum_cs(wiring, den_borders, ax_borders, confidence_lvl,
                       binary, wd, add_fname=''):
    """Same as plot wiring, but using all contact sites

    Parameters
    ----------
    wiring : np.array
        cumulated wiring of shape (4, 4) or (4, 4, channels) indexed by
        (post-synaptic type, pre-synaptic type[, channel]); channels are
        summed per sector
    den_borders : np.array
        class ranges (5 entries, see class_ranges) of post-synaptic cells
    ax_borders : np.array
        class ranges (5 entries, see class_ranges) of pre-synaptic cells
    confidence_lvl : float
        used in the output file name
    binary : bool
        selects colorbar label and output file name suffix
    wd : str
        working directory; image is written to ``wd/figures/``
    add_fname : str
        supplement of image file name
    """
    wiring = np.asarray(wiring, dtype=float)
    if wiring.ndim == 3:
        # collapse channels (e.g. the (4, 4, 3) output of type_sorted_wiring)
        # into one total per sector; transpose(1, 0) below needs 2D input
        wiring = np.sum(wiring, axis=2)
    # plot intensities, averaged per sector
    intensity_plot = np.zeros((4, 4))
    for i in range(4):
        for j in range(4):
            diff_den = den_borders[i + 1] - den_borders[i]
            diff_ax = ax_borders[j + 1] - ax_borders[j]
            nb_cells = float(diff_den * diff_ax)
            if nb_cells != 0:
                intensity_plot[i, j] = wiring[i, j] / nb_cells
    ind = np.arange(4)
    intensity_plot = intensity_plot.transpose(1, 0)[::-1]
    wiring = wiring.transpose(1, 0)[::-1]
    intensity_plot[intensity_plot > 1.0] = 1.0
    max_val = np.max(intensity_plot)
    row_sum = np.sum(wiring, axis=1)
    col_sum = np.sum(wiring, axis=0)
    value_range = intensity_plot.max() - intensity_plot.min()
    if value_range > 0:
        intensity_plot = (intensity_plot - intensity_plot.min()) / value_range
    else:
        # constant matrix: min-max scaling would be 0/0 -> NaN
        intensity_plot = np.zeros_like(intensity_plot)
    matplotlib.rcParams.update({'font.size': 14})
    fig = pp.figure()
    gs = gridspec.GridSpec(2, 3, width_ratios=[10, 1, 0.5],
                           height_ratios=[1, 10])
    gs.update(wspace=0.05, hspace=0.08)
    ax = pp.subplot(gs[1, 0], frameon=False)
    cax = ax.matshow(intensity_plot, cmap='gray_r', extent=[0, 4, 0, 4])
    ax.grid(color='k', linestyle='-')
    cbar_ax = pp.subplot(gs[1, 2])
    cbar_ax.yaxis.set_ticks_position('left')
    cb = fig.colorbar(cax, cax=cbar_ax, ticks=[0, 1])
    cb.ax.set_yticklabels(['0', '%0.4f' % max_val], rotation=90)
    if not binary:
        cb.set_label(u'Average Area of Contact Sites [µm$^2$]')
    else:
        cb.set_label(u'Average Number of Contact Sites')
    # row-sum bar plot to the right of the matrix
    axr = pp.subplot(gs[1, 1], sharey=ax, yticks=[], xticks=[0, max(row_sum)],
                     frameon=True, xlim=(np.min(row_sum), np.max(row_sum)),
                     ylim=(0, 4))
    axr.tick_params(axis='x', which='major', right=False, top=False, pad=10,
                    labelsize=12, direction='out', length=4, width=1)
    axr.spines['top'].set_visible(False)
    axr.spines['right'].set_visible(False)
    axr.get_xaxis().tick_bottom()
    axr.get_yaxis().tick_left()
    axr.barh(ind, row_sum[::-1], 1, color='0.6', linewidth=0)
    # column-sum bar plot above the matrix
    axt = pp.subplot(gs[0, 0], sharex=ax, xticks=[], yticks=[0, max(col_sum)],
                     frameon=True, xlim=(0, 4),
                     ylim=(np.min(col_sum), np.max(col_sum)))
    axt.tick_params(axis='y', which='major', right=False, bottom=False,
                    pad=10, labelsize=12, direction='out', length=4, width=1)
    axt.spines['top'].set_visible(False)
    axt.spines['right'].set_visible(False)
    axt.get_xaxis().tick_bottom()
    axt.get_yaxis().tick_left()
    axt.bar(ind, col_sum, 1, color='0.6', linewidth=0)
    suffix = '_binary' if binary else ''
    fig.savefig(wd + '/figures/type_wiring_cum%s_conflvl%d%s.png' %
                (add_fname, int(confidence_lvl * 10), suffix), dpi=600)
    # the module forces the non-interactive Agg backend; release the figure
    # instead of calling plt.show()
    plt.close(fig)
def make_colormap(seq):
    """Return a LinearSegmentedColormap built from colors and anchors.

    seq: a sequence of floats and RGB-tuples. The floats should be increasing
    and in the interval (0,1); at each float the colormap switches from the
    RGB tuple preceding it to the RGB tuple following it. The ends 0.0 and
    1.0 are added automatically.
    """
    blank = (None,) * 3
    padded = [blank, 0.0]
    padded.extend(seq)
    padded.extend([1.0, blank])
    cdict = {'red': [], 'green': [], 'blue': []}
    for idx, anchor in enumerate(padded):
        if not isinstance(anchor, float):
            continue
        r_lo, g_lo, b_lo = padded[idx - 1]
        r_hi, g_hi, b_hi = padded[idx + 1]
        for key, lo, hi in (('red', r_lo, r_hi), ('green', g_lo, g_hi),
                            ('blue', b_lo, b_hi)):
            cdict[key].append([anchor, lo, hi])
    return mcolors.LinearSegmentedColormap('CustomMap', cdict)
def diverge_map(low=(239/255., 65/255., 50/255.),
                high=(39/255., 184/255., 148/255.)):
    """Create a diverging colormap running low -> white -> high.

    Low and high are colors that will be used for the two ends of the
    spectrum. They can be either color strings or RGB color tuples; the
    centre of the map (0.5) is white.
    """
    try:
        string_types = basestring  # Python 2: covers str and unicode
    except NameError:
        string_types = str  # Python 3: basestring no longer exists
    c = mcolors.ColorConverter().to_rgb
    if isinstance(low, string_types):
        low = c(low)
    if isinstance(high, string_types):
        high = c(high)
    return make_colormap([low, c('white'), 0.5, c('white'), high])
def class_ranges(pred_arr, nb_classes=4):
    """Helper function to get extent of cell types in sorted prediction

    Parameters
    ----------
    pred_arr : np.array
        array of integer cell type predictions in ``[0, nb_classes)``,
        sorted in ascending order
    nb_classes : int
        number of cell type classes

    Returns
    -------
    np.array
        ``nb_classes + 1`` indices; class ``k`` occupies
        ``pred_arr[ranges[k]:ranges[k + 1]]``. Classes absent from
        ``pred_arr`` get an empty range.
    """
    pred_arr = np.asarray(pred_arr)
    # First index whose label is >= k. Unlike np.argmax(pred_arr == k),
    # which returns 0 when class k is absent, this stays correct for
    # missing classes.
    starts = np.searchsorted(pred_arr, np.arange(1, nb_classes), side='left')
    return np.concatenate(([0], starts, [len(pred_arr)])).astype(int)