# Source code for singlecellmultiomics.universalBamTagger.4SUtagger

#!/usr/bin/env python

import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 300

import matplotlib.pyplot as plt
import seaborn as sns
import pysam
from pysamiterators import CachedFasta, MatePairIterator

# Molecule modules:
from singlecellmultiomics.molecule import TranscriptMolecule, MoleculeIterator
from singlecellmultiomics.fragment import SingleEndTranscriptFragment
from singlecellmultiomics.features import FeatureContainer

# Conversion modules:
from singlecellmultiomics.variants.substitutions import conversion_dict_stranded
from singlecellmultiomics.variants import substitution_plot, vcf_to_position_set
from singlecellmultiomics.utils import reverse_complement, complement
from collections import defaultdict, Counter
from singlecellmultiomics.utils import is_main_chromosome
from singlecellmultiomics.bamProcessing import sorted_bam_file, merge_bams


from scipy import stats
from multiprocessing import Pool
import os

import argparse

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pickle
import gzip
from uuid import uuid4

def substitution_plot_stranded(pattern_counts: dict,
                               figsize: tuple = (12, 4),
                               conversion_colors: tuple = ('b', 'k', 'r', 'grey', 'g', 'pink',
                                                           'b', 'k', 'r', 'k', 'w', 'g'),
                               ylabel: str = '# conversions per molecule',
                               add_main_group_labels: bool = True,
                               ax=None, fig=None,
                               **plot_args
                               ):
    """
    Create a stranded 3bp substitution plot.

    One bar is drawn per ``(context, to)`` key of ``pattern_counts``, in the
    iteration order of the dictionary. Bars are colored by the single
    nucleotide conversion ``context[1] -> to``; all twelve stranded
    conversions (A>C, A>G, ... T>G) have their own color slot.

    Args:
        pattern_counts(dict) : Dictionary containing the substitutions to plot,
            keyed by ``(3bp context, observed base)``. Format:
            ```OrderedDict([(('ACA', 'A'), 0),
                 (('ACC', 'A'), 1),
                 (('ACG', 'A'), 0),
                 ...
                 (('TTG', 'G'), 0),
                 (('TTT', 'G'), 0)])```

        figsize(tuple) : size of the figure to create, only used when no
            ``ax`` is supplied

        conversion_colors(tuple) : colors for the twelve conversion groups, in
            the order AC, AG, AT, CA, CG, CT, GA, GC, GT, TA, TC, TG

        ylabel(str) : y axis label

        add_main_group_labels(bool) : Add conversion group labels to top of plot

        ax : existing axes to draw into; a new figure is created when None

        fig : figure handle belonging to ``ax``; when omitted it is taken
            from ``ax``

        **plot_args : Additional arguments to pass to .plot()

    Returns
        fig : handle to the figure
        ax : handle to the axis

    Raises:
        ValueError : when ``pattern_counts`` is empty

    Example:
        >>> import matplotlib.pyplot as plt
        >>> # conversions: dict of (context, base) -> observations
        >>> fig, ax = substitution_plot_stranded(conversions)
        >>> ax.set_title('my library')
        >>> plt.show()
    """
    if len(pattern_counts) == 0:
        # Without any bar there is nothing to lay out (and the tick offset
        # below would divide by zero)
        raise ValueError('pattern_counts is empty, there is nothing to plot')

    conversions_single_nuc = ('AC', 'AG', 'AT', 'CA', 'CG', 'CT',
                              'GA', 'GC', 'GT', 'TA', 'TC', 'TG')

    # Colors for the conversion groups; the group of a bar is determined by
    # the center base of the context and the base it was converted to:
    color_d = dict(zip(conversions_single_nuc, conversion_colors))
    colors = [color_d.get(f'{context[1]}{to}') for context, to in pattern_counts]

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    elif fig is None:
        # Make sure a usable figure handle is returned when only ax was given
        fig = ax.figure

    # A single-row frame: every (context, to) key becomes one bar
    substitution_dataframe = pd.DataFrame(
        list(pattern_counts.values()),
        index=list(pattern_counts.keys())).T
    substitution_dataframe.plot(kind='bar', color=colors, legend=False,
                                width=1.0, ax=ax, edgecolor='k', **plot_args)

    offset = (1 / len(pattern_counts)) * 0.5  # Amount of distance for a half bar

    # Add 3bp context ticks, one centered under every bar:
    ax.set_xticks(np.linspace(-0.5 + offset, 0.5 - offset, len(pattern_counts)))
    ax.set_xticklabels([context for context, to in pattern_counts],
                       rotation=90, size=5)
    ax.set_ylabel(ylabel)
    ax.set_xlim((-0.5, 0.5))

    # Restrict despine to the axes we draw on, instead of whatever figure
    # happens to be the current pyplot figure
    sns.despine(ax=ax)

    if add_main_group_labels:
        n_groups = len(conversions_single_nuc)
        for i, (u, v) in enumerate(conversions_single_nuc):
            ax.text(  # position text relative to Axes
                (i + 0.5) / n_groups, 1.0, f'{u}>{v}', fontsize=8,
                ha='center', va='top', transform=ax.transAxes,
                bbox=dict(facecolor='white', alpha=1, lw=0)
            )

    return fig, ax
if __name__ == '__main__':
    # Command line tool: tag molecules in a BAM file with 4sU (T>C) conversion
    # counts, write a merged tagged BAM file and produce summary tables/plots.

    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Assign molecules')

    argparser.add_argument('bamin', type=str, help='Input BAM file')
    argparser.add_argument('-o', type=str, help="output bam file (.bam)", required=True)
    argparser.add_argument('-reference', type=str, help="Reference_path (.fasta)", required=True)
    argparser.add_argument('-known', type=str, help="Known variants (vcf)", required=True)
    argparser.add_argument('-exons', type=str, help="exons (gtf.gz)", required=True)
    argparser.add_argument('-introns', type=str, help="introns (gtf.gz)", required=True)
    argparser.add_argument(
        '--R2_based',
        help="The input is only R2 sequences, the molcule mapping direction will be inverted",
        action='store_true')
    argparser.add_argument('-temp_dir', type=str, help="scmo_temp", default=str(uuid4()))
    argparser.add_argument('-tagthreads', type=int, help="Amount of threads used (int)", required=True)

    args = argparser.parse_args()

    single_cell_bam_path = args.bamin
    reference_path = args.reference

    # Load known variation, to ignore for mut-spectrum
    known_vcf_path = args.known

    # Paths to gene models
    exons_gtf_path = args.exons
    introns_gtf_path = args.introns

    # Write a tagged bam file to this path:
    tagged_output_path = args.o

    # All secondary outputs are written next to the output bam file, using the
    # bam path without its extension as prefix. Only a trailing '.bam' is
    # stripped: str.replace would also hit '.bam' inside directory names, and
    # would leave the path untouched when the extension is missing, in which
    # case every secondary output would overwrite the tagged bam file itself.
    if tagged_output_path.endswith('.bam'):
        output_prefix = tagged_output_path[:-len('.bam')]
    else:
        output_prefix = tagged_output_path

    #####

    def obtain_conversions(contig: str):
        """ Create conversion dictionary for the supplied contig

        Reads all molecules of the contig from the input bam file, counts the
        observed conversions per library and writes the molecules, tagged with
        the amount of 4sU conversions ('4S'), the amount of positions where
        such a conversion could have been observed ('4c') and a read color
        ('YC'), to a temporary bam file.

        Args:
            contig (str)

        Returns:
            conversions_per_library (defaultdict( conversion_dict_stranded ) ) : Per library conversion dictionary
            n_molecules_per_library (Counter) : observed molecules per library
            contig(str) : the contig passed to the method
            temp_bam_path(str) : path to tagged bam file, tagged with gene annotations and 4su mutation count
        """
        conversions_per_library = defaultdict(conversion_dict_stranded)
        n_molecules_per_library = Counter()

        from singlecellmultiomics.molecule import might_be_variant

        # Create temp directory to write tagged bam file to. Multiple workers
        # do this concurrently, exist_ok makes that race harmless:
        temp_dir = args.temp_dir
        temp_bam_path = f'{temp_dir}/{contig}.bam'
        os.makedirs(temp_dir, exist_ok=True)

        # Load gene annotations for the selected contig:
        transcriptome_features = FeatureContainer()
        transcriptome_features.loadGTF(
            path=exons_gtf_path,
            select_feature_type=['exon'],
            identifierFields=('exon_id', 'gene_id'),
            store_all=True,
            contig=contig,
            head=None)

        transcriptome_features.loadGTF(
            path=introns_gtf_path,
            select_feature_type=['intron'],
            identifierFields=['transcript_id'],
            store_all=True,
            contig=contig,
            head=None)

        colormap = plt.get_cmap('RdYlBu_r')
        # Color (RGB) for molecules where no conversion rate can be calculated
        bad_color = (0, 0, 0)

        read_groups = {}
        try:
            with pysam.AlignmentFile(single_cell_bam_path, threads=4) as alignments, \
                 pysam.VariantFile(known_vcf_path) as known, \
                 sorted_bam_file(temp_bam_path,
                                 origin_bam=single_cell_bam_path,
                                 read_groups=read_groups,
                                 fast_compression=True) as out, \
                 pysam.FastaFile(reference_path) as reference_handle:

                # Cache the sequence of the contig: (faster)
                reference = CachedFasta(reference_handle)

                molecule_iterator = MoleculeIterator(
                    alignments,
                    TranscriptMolecule,
                    SingleEndTranscriptFragment,
                    fragment_class_args={
                        'stranded': True,
                        'features': transcriptome_features
                    },
                    molecule_class_args={
                        'reference': reference,
                        'features': transcriptome_features,
                        'auto_set_intron_exon_features': True
                    },
                    contig=contig)

                for molecule in molecule_iterator:
                    # Read out mut spectrum
                    consensus = molecule.get_consensus()
                    if args.R2_based:
                        molecule.strand = not molecule.strand  # Invert because its R2 based.

                    n_molecules_per_library[molecule.library] += 1

                    n_4su_mutations = 0
                    n_4su_contexts = 0

                    for (chrom, pos), base in consensus.items():
                        context = reference.fetch(chrom, pos - 1, pos + 2).upper()
                        # Positions at the contig edge have no full 3bp context
                        if len(context) != 3:
                            continue

                        # Count positions where a 4sU conversion could be
                        # observed: a T on the template strand of the molecule
                        if (context[1] == 'A' and not molecule.strand) or \
                           (context[1] == 'T' and molecule.strand):
                            n_4su_contexts += 1

                        # Check if the base matches or the reference contains N's
                        if context[1] == base or 'N' in context:
                            continue

                        # Ignore germline variants:
                        if might_be_variant(chrom, pos, known):
                            continue

                        if not molecule.strand:  # reverse template
                            context = reverse_complement(context)
                            base = complement(base)

                        # Count 4SU specific mutations, and write to molecule later
                        if context[1] == 'T' and base == 'C':
                            n_4su_mutations += 1

                        conversions_per_library[molecule.library][(context, base)] += 1

                    # Write 4su modification to molecule
                    molecule.set_meta('4S', n_4su_mutations)
                    molecule.set_meta('4c', n_4su_contexts)

                    # Set read color based on conversion rate:
                    if n_4su_contexts > 0:
                        # The max color value will be 10% modification rate
                        cfloat = colormap(
                            np.clip(10 * (n_4su_mutations / n_4su_contexts), 0, 1))[:3]
                    else:
                        # No informative positions, the rate is undefined
                        cfloat = bad_color
                    molecule.set_meta('YC', '%s,%s,%s' % tuple((int(x * 255) for x in cfloat)))

                    molecule.write_tags()

                    # Collect the read groups, these end up in the header of
                    # the temporary bam file
                    for fragment in molecule:
                        rgid = fragment.get_read_group()
                        if rgid not in read_groups:
                            read_groups[rgid] = fragment.get_read_group(True)[1]

                    # Write tagged molecule to output file
                    molecule.write_pysam(out)

        except KeyboardInterrupt:
            # This allows you to cancel the analysis (CTRL+C) and get the current result
            pass

        return conversions_per_library, n_molecules_per_library, contig, temp_bam_path

    n_molecules_per_library = Counter()
    conversions_per_library = defaultdict(conversion_dict_stranded)  # library : (context, query) : obs (int)

    # Obtain all contigs from the input bam file, exclude scaffolds:
    with pysam.AlignmentFile(single_cell_bam_path) as alignments:
        contigs = [contig for contig in alignments.references
                   if is_main_chromosome(contig) and contig not in ['MT', 'Y']]

    # Run conversion detection on all contigs in parallel:
    tagged_bams = []
    with Pool(args.tagthreads) as workers:
        for (conversions_for_contig,
             n_molecules_for_contig_per_lib,
             contig,
             temp_tagged_bam) in workers.imap_unordered(obtain_conversions, contigs):

            # Merge conversion dictionary:
            for library, library_convs in conversions_for_contig.items():
                for context, observations in library_convs.items():
                    conversions_per_library[library][context] += observations

            n_molecules_per_library += n_molecules_for_contig_per_lib

            print(f'finished {contig} ', end='\r')
            tagged_bams.append(temp_tagged_bam)

    # Merge:
    print(f'Merging ', end='\r')
    merge_bams(tagged_bams, tagged_output_path)

    # Normalize observed counts to the amount of molecules we saw:
    for library, library_convs in conversions_per_library.items():
        for context, observations in library_convs.items():
            library_convs[context] = observations / n_molecules_per_library[library]

    # Conversion spectrum, one panel per library. Plotting is best effort: a
    # failure here should not prevent the tables below from being written.
    try:
        fig, axes = plt.subplots(len(conversions_per_library), 1,
                                 figsize=(16, 4 * (len(conversions_per_library))),
                                 sharey=True)
        if len(conversions_per_library) == 1:
            axes = [axes]
        for ax, (library, conversions) in zip(axes, conversions_per_library.items()):
            # Export:
            substitution_dataframe = pd.DataFrame(
                list(conversions.values()),
                index=list(conversions.keys())).T
            substitution_dataframe.to_csv(f'{output_prefix}{library}_conversions.csv')

            substitution_plot_stranded(conversions, fig=fig, ax=ax,
                                       ylabel='conversions seen per molecule')
            ax.set_axisbelow(True)
            ax.grid(axis='y')
            ax.set_title(f'{library}, {n_molecules_per_library[library]} molecules')

        fig.tight_layout(pad=3.0)
        plt.savefig(f'{output_prefix}conversions.png')
    except Exception as e:
        print(e)

    # Count amount of 4sU conversions per cell, per gene
    def listionary():
        # Module level factory (instead of a lambda) to keep the nested
        # defaultdicts picklable
        return defaultdict(list)

    expression_per_cell_per_gene = defaultdict(Counter)  # gene -> cell -> obs
    four_su_per_cell_per_gene = defaultdict(listionary)  # cell -> gene -> [] 4_su observation counts per molecule
    four_su_per_gene_per_cell = defaultdict(listionary)  # gene -> cell -> [] 4_su observation counts per molecule

    with pysam.AlignmentFile(tagged_output_path) as reads:
        for R1, R2 in MatePairIterator(reads):

            # Count every fragment only once by selecting one of the two reads.
            read = R1 if R1 is not None else R2

            if read.has_tag('gn'):
                gene = read.get_tag('gn')
            elif read.has_tag('GN'):
                gene = read.get_tag('GN')
            else:
                continue

            if read.is_duplicate:
                continue

            cell = read.get_tag('SM')
            foursu = read.get_tag('4S')
            foursu_contexts = read.get_tag('4c')
            library = read.get_tag('LY')
            cell = cell.split('_')[1]  # Remove library part

            expression_per_cell_per_gene[gene][(library, cell)] += 1
            if foursu_contexts > 0:
                four_su_per_gene_per_cell[gene][(library, cell)].append(foursu / foursu_contexts)
                four_su_per_cell_per_gene[(library, cell)][gene].append(foursu / foursu_contexts)
            assert not (foursu > 0 and foursu_contexts == 0)

    # Store these dictionaries to disk
    with gzip.open(f'{output_prefix}4sU_per_gene_per_cell.dict.pickle.gz', 'wb') as o:
        pickle.dump(four_su_per_gene_per_cell, o)

    with gzip.open(f'{output_prefix}4sU_per_cell_per_gene.dict.pickle.gz', 'wb') as o:
        pickle.dump(four_su_per_cell_per_gene, o)

    with gzip.open(f'{output_prefix}expression_per_cell_per_gene.pickle.gz', 'wb') as o:
        pickle.dump(expression_per_cell_per_gene, o)

    four_su_per_gene_per_cell_mean = defaultdict(dict)
    four_su_per_gene_per_cell_total = defaultdict(dict)  # NOTE: calculated, but currently not written to disk
    for gene, observations_per_cell in four_su_per_gene_per_cell.items():
        for cell, fsu_obs in observations_per_cell.items():
            four_su_per_gene_per_cell_mean[gene][cell] = np.mean(fsu_obs)
            # Amount of molecules with at least one 4sU conversion
            four_su_per_gene_per_cell_total[gene][cell] = np.sum(np.array(fsu_obs) > 0)

    four_su_per_gene_per_cell_mean = pd.DataFrame(four_su_per_gene_per_cell_mean).T
    four_su_per_gene_per_cell_total = pd.DataFrame(four_su_per_gene_per_cell_total).T

    four_su_per_gene_per_cell_mean.to_csv(f'{output_prefix}4sU_labeled_ratio.csv.gz')

    # Genes x (library, cell) matrix of molecule counts. This has to be built
    # from the expression counts: the 4sU dictionaries hold a list of ratios
    # per entry, for which the mean expression per gene used below for the
    # gene selection is not defined.
    expression_matrix = pd.DataFrame(expression_per_cell_per_gene).T.fillna(0)
    libraries = expression_matrix.columns.get_level_values(0).unique()

    ############
    # 4sU conversion rate versus amount of molecules, one dot per cell
    fig, ax = plt.subplots(figsize=(7, 7))

    min_molecules = 100
    for library in sorted(libraries):
        if '4s' not in library:
            continue
        cell_efficiencies = {}
        cell_molecules = Counter()
        for cell, genes in four_su_per_cell_per_gene.items():
            # cell is a (library, cell) tuple
            if cell[0] != library:
                continue
            target_cell_name = cell
            if '100cells' in library:
                target_cell_name = 'bulk'
            conversions_total = []
            for gene, conversions in genes.items():
                conversions_total += conversions
                cell_molecules[target_cell_name] += len(conversions)
            cell_efficiencies[target_cell_name] = np.mean(conversions_total) * 100

        # Only show cells with sufficient molecules
        selected_cells = [cell for cell in cell_efficiencies
                          if cell_molecules[cell] > min_molecules]
        cell_efficiencies = {cell: cell_efficiencies[cell] for cell in selected_cells}

        scatter = plt.scatter(
            [cell_molecules[cell] for cell in selected_cells],
            list(cell_efficiencies.values()),
            label=library, s=2)

        # Mark the median of the library: a thick black cross with a thinner
        # cross in the color of the library on top
        plt.scatter(
            np.median([cell_molecules[cell] for cell in selected_cells]),
            np.median(list(cell_efficiencies.values())),
            facecolors='k', s=250, marker='+', edgecolors='black', lw=3)
        plt.scatter(
            np.median([cell_molecules[cell] for cell in cell_efficiencies]),
            np.median(list(cell_efficiencies.values())),
            facecolors=scatter.get_facecolors(), s=250, marker='+', edgecolors='black', lw=1)

    plt.ylabel('4sU conversion rate (%)')
    plt.xlabel('total molecules')
    plt.xscale('log')
    ax.set_axisbelow(True)
    ax.grid()
    plt.legend(bbox_to_anchor=(0.6, 1))
    sns.despine()
    plt.title('4sU conversion rate per cell')
    plt.savefig(f'{output_prefix}conversion_rate.png', dpi=200)

    ##########
    # Labeled versus total molecules per cell, for highly expressed genes
    fig, axes = plt.subplots(6, 4, figsize=(13, 17), squeeze=True)
    axes = axes.flatten()
    axes_index = 0

    # Consider the 100 genes with the highest mean expression
    for gene in expression_matrix.mean(1).sort_values()[-100:].index:

        if gene == 'MALAT1':
            continue

        labeled = defaultdict(list)  # group -> per cell amount of labeled molecules
        total = defaultdict(list)  # group -> per cell amount of molecules

        # .get prevents insertion of empty entries for genes without 4sU data
        for (library, cell), labeled_4su_fraction in four_su_per_gene_per_cell.get(gene, {}).items():
            if '4sU' in library:
                library = '4sU'
            else:
                library = 'unlabeled'
            labeled[library].append(sum([l > 0 for l in labeled_4su_fraction]))
            total[library].append(len(labeled_4su_fraction))

        # Only show genes with a significant fit for both groups, and a
        # positive slope for the 4sU group. Genes for which the fit cannot be
        # performed (missing group, not enough cells) are skipped.
        try:
            max_x = max((max(total[library]) for library in total))

            slope, intercept, r_value, p_value, std_err = stats.linregress(total['4sU'], labeled['4sU'])
            if slope < 0.001 or p_value > 0.05 or np.isnan(p_value):
                continue

            slope, intercept, r_value, p_value, std_err = stats.linregress(total['unlabeled'], labeled['unlabeled'])
            if p_value > 0.05 or np.isnan(p_value):
                continue
        except Exception as e:
            continue

        ax = axes[axes_index]
        axes_index += 1
        for library in total:
            slope, intercept, r_value, p_value, std_err = stats.linregress(total[library], labeled[library])
            ax.plot([0, max_x], [intercept, max_x * slope + intercept],
                    c='red' if '4sU' in library else 'k')

        for library in total:
            ax.scatter(total[library], labeled[library], label=library, s=10, alpha=0.5,
                       c='red' if '4sU' in library else 'k')

        # The slope of the 4sU group is shown in the title
        slope, intercept, r_value, p_value, std_err = stats.linregress(total['4sU'], labeled['4sU'])
        ax.legend()
        ax.set_xlabel('total molecules')
        ax.set_ylabel('4sU labeled molecules')
        ax.set_title(f'{gene}\nslope:{slope:.2f}')
        sns.despine()
        if axes_index >= len(axes):
            # All panels are in use
            break

    fig.tight_layout(pad=1.0)
    plt.savefig(f'{output_prefix}slopes.png', dpi=200)