Source code for gibbs

""" 
Run Gibbs samplers and process the resulting data.

This module provides the `ParallelGibbs` class, which parallelizes the creation
of Gibbs samplers for each residue in the contact map. This module also provides
the `Gibbs` class, which allows for the loading and processing of the gibbs
sampler data, as well as plotting and saving processed results.
"""
import os
import gc
import pickle
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from numpy.random import default_rng
from tqdm import tqdm
from MDAnalysis.analysis.base import Results
from basicrta.util import confidence_interval
import multiprocessing
from multiprocessing import Pool, Lock
import MDAnalysis as mda
from basicrta import istarmap

# Explicitly turn on automatic garbage collection (CPython enables it by
# default, so this is a safeguard rather than a behavior change).
gc.enable()
# Font type 42 makes matplotlib embed TrueType fonts in PDF output instead of
# the default Type 3 fonts.
mpl.rcParams['pdf.fonttype'] = 42
# Module-wide random generator used by the sampler and plotting code below.
# It is created without a seed, so results differ from run to run.
rng = default_rng()


class ParallelGibbs(object):
    """
    Take a contact map and run a Gibbs sampler for each residue.

    The cutoff is parsed from the contact file name: it is the text after
    the last underscore, e.g. ``contacts_7.0.pkl`` gives ``7.0``.

    :param contacts: Contact pickle file (``contacts_{cutoff}.pkl``).
    :type contacts: str
    :param nproc: Number of processes to use in running Gibbs samplers.
    :type nproc: int
    :param ncomp: Number of mixture components to use in the Gibbs sampler.
    :type ncomp: int
    :param niter: Number of iterations of the Gibbs sampler to perform.
    :type niter: int
    """

    def __init__(self, contacts, nproc=1, ncomp=15, niter=110000):
        # NOTE: str.strip('.pkl') removes any leading/trailing characters
        # from the set {'.', 'p', 'k', 'l'}, not the literal suffix. That is
        # harmless here because only the trailing numeric token is used.
        self.cutoff = float(contacts.strip('.pkl').split('/')[-1].split('_')
                            [-1])
        self.niter = niter
        self.nproc = nproc
        self.ncomp = ncomp
        self.contacts = contacts

    def run(self, run_resids=None, g=100):
        """
        Execute the Gibbs samplers for all residues of `sel1` present in the
        contact map, or only for the resid(s) provided.

        :param run_resids: Resid(s) for which to run a Gibbs sampler. ``None``
                           (or an empty sequence) selects every residue found
                           in the contact map.
        :type run_resids: int or list or numpy.ndarray, optional
        :param g: Gibbs skip parameter forwarded to each sampler.
        :type g: int
        """
        # The file is only read, so open it read-only; 'r+b' would require
        # write permission on the contact file.
        # NOTE: pickle can execute arbitrary code on load; only open contact
        # files from a trusted source.
        with open(self.contacts, 'rb') as f:
            contacts = pickle.load(f)

        # Check if this is a combined contact file
        metadata = contacts.dtype.metadata
        is_combined = bool(metadata and 'n_trajectories' in metadata
                           and metadata['n_trajectories'] > 1)

        if is_combined:
            print(f"WARNING: Using combined contact file with {metadata['n_trajectories']} trajectories.")
            print("WARNING: Kinetic clustering is not yet supported for combined contacts.")
            print("WARNING: The Gibbs sampler will pool all residence times together.")

        protids = np.unique(contacts[:, 0])
        # `not run_resids` is ambiguous for numpy arrays and would also
        # treat resid 0 as "nothing selected"; test for None/empty instead.
        if run_resids is None or np.size(run_resids) == 0:
            run_resids = protids
        # Accept a scalar, list, tuple or array uniformly.
        run_resids = np.atleast_1d(run_resids)

        rg = metadata['ag1'].residues
        resids = rg.resids
        reslets = np.array([mda.lib.util.convert_aa_code(name) for name in
                            rg.resnames])
        residues = np.array([f'{reslet}{resid}' for reslet, resid in
                             zip(reslets, resids)])
        # Column 0 is compared against resids and column 3 is taken as the
        # residence time of each contact.
        times = [contacts[contacts[:, 0] == i][:, 3] for i in run_resids]
        inds = np.array([np.where(resids == resid)[0][0] for resid in
                         run_resids], dtype=int)
        residues = residues[inds]
        input_list = [[residue, time.copy(), i % self.nproc, self.ncomp,
                       self.niter, self.cutoff, g, is_combined]
                      for i, (residue, time) in enumerate(zip(residues,
                                                              times))]

        # Release the (potentially large) contact array before forking
        # worker processes.
        del contacts, times
        gc.collect()

        # NOTE(review): `istarmap` is not a standard Pool method; presumably
        # it is added by the `from basicrta import istarmap` import at the
        # top of the file -- confirm.
        with Pool(self.nproc, initializer=tqdm.set_lock,
                  initargs=(Lock(),)) as p:
            try:
                for _ in tqdm(p.istarmap(run_residue, input_list),
                              total=len(residues), position=0,
                              desc='overall progress'):
                    pass
            except KeyboardInterrupt:
                # Deliberate: let Ctrl-C end the run quietly.
                pass
def run_residue(residue, time, proc, ncomp, niter, cutoff, g,
                from_combined=False):
    """Create and run one Gibbs sampler for a single residue.

    Nothing happens when `time` is empty. The progress-bar slot is derived
    from the trailing number of the current worker process name; the `proc`
    argument is accepted but replaced by that value (or by 1 when the name
    has no numeric suffix).

    :param residue: Residue name
    :type residue: str
    :param time: Residence times data
    :type time: array-like
    :param proc: Process number for progress bar positioning
    :type proc: int
    :param ncomp: Number of mixture components
    :type ncomp: int
    :param niter: Number of iterations
    :type niter: int
    :param cutoff: Cutoff value used in contact analysis
    :type cutoff: float
    :param g: Gibbs skip parameter
    :type g: int
    :param from_combined: Whether data comes from combined contacts
    :type from_combined: bool
    """
    samples = np.array(time)
    if len(samples) == 0:
        return

    worker_name = multiprocessing.current_process().name
    try:
        proc = int(worker_name.split('-')[-1])
    except ValueError:
        # No numeric suffix on the process name.
        proc = 1

    sampler = Gibbs(times=samples, residue=residue, loc=proc, ncomp=ncomp,
                    niter=niter, cutoff=cutoff, g=g)
    sampler._from_combined_contacts = from_combined
    sampler.run()
class Gibbs(object):
    r"""Gibbs sampler to estimate parameters of an exponential mixture for a
    set of data.

    The raw chain (``mcweights``, ``mcrates``, ``indicator``) is created by
    :meth:`run`. Processed results are stored in :attr:`processed_results`,
    an instance of :class:`MDAnalysis.analysis.base.Results`, which is
    populated by :meth:`process_gibbs`.

    :param times: Set of residence times to analyze
    :type times: array, optional
    :param residue: Residue name associated with the set of residence times
    :type residue: str
    :param loc: Used for progress bar in parallel applications
    :type loc: int
    :param ncomp: Number of exponential components to use in the mixture model
    :type ncomp: int
    :param niter: Number of iterations to run the Gibbs sampler
    :type niter: int
    :param cutoff: Cutoff value used in contact analysis, used to determine
                   directory to load/save results. Allows for multiple cutoffs
                   to be tested in directory containing contacts.
    :type cutoff: float
    :param g: Gibbs skip parameter for decorrelated samples; only save every
              `g` samples from full Gibbs sampler chain; default from
              https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
              (NOTE: this value is called *gskip* in cluster.py)
    :type g: int
    :param burnin: Burn-in parameter, drop first `burnin` samples as
                   equilibration; default from
                   https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
    :type burnin: int
    :param gskip: Process data from the subsampled chain (every `g` samples)
                  at a coarser skip interval of `gskip` samples. Thus, in
                  total, samples are taken at ``g * gskip`` steps from the
                  full chain. (This is useful for sensitivity analysis where
                  we run the chain with a small `g` value and save many
                  samples and then use `gskip` to process samples at
                  increasingly larger intervals without having to re-run the
                  chain.) The default value of 1 means that the samples are
                  processed at every `g` samples from the full chain.
    :type gskip: int

    EXAMPLE
    -------

    >>> from basicrta.gibbs import Gibbs
    >>> from basicrta.tests.datafiles import times
    >>> g = Gibbs(times=times, residue='W313', cutoff=7.0)
    >>> g.run()
    >>> g.process_gibbs()
    >>> g.estimate_tau()
    [1, 2, 3]

    To load a Gibbs sampler that has already been executed use the
    :meth:`load` method

    >>> g = Gibbs().load('results.pkl')

    The Gibbs sampler can be executed using the :meth:`run` method without
    processing the resulting data. Once the :meth:`process_gibbs` method is
    called, the :attr:`processed_results` attribute will be populated.
    """

    def __init__(self, times=None, residue=None, loc=0, ncomp=15,
                 niter=110000, cutoff=None, g=100, burnin=10000, gskip=1):
        self.times = times
        self.residue = residue
        self.niter = niter
        self.loc = loc
        self.ncomp = ncomp
        self.g = g
        self.gskip = gskip
        self.burnin = burnin
        self.cutoff = cutoff
        self.processed_results = Results()
        # Clusters whose largest indicator value is below this threshold are
        # treated as noise by the tau/survival methods.
        self._noise_cutoff = 0.4

        if times is not None:
            # ts is the first non-zero gap between consecutive sorted times;
            # fall back to the smallest time when all values are identical
            # (or only one value is given).
            diff = np.diff(np.sort(times))
            nonzero = diff[diff != 0]
            self.ts = nonzero[0] if nonzero.size else np.min(times)
        else:
            self.ts = None

        self.keys = {'times', 'residue', 'loc', 'ncomp', 'niter', 'g',
                     'burnin', 'processed_results', 'ts', 'mcweights',
                     'mcrates', 't', 's', 'cutoff', 'indicator'}

    def __getitem__(self, item):
        """Return the attribute named `item` (``obj['x']`` is ``obj.x``)."""
        return getattr(self, item)

    def _prepare(self):
        """Allocate the chain arrays and set hyperparameters before a run."""
        from basicrta.util import get_s
        self.t, self.s = get_s(self.times, self.ts)

        # initialize arrays: one row per saved sample (every `g` iterations)
        nsaved = (self.niter + 1) // self.g
        self.indicator = np.zeros((nsaved, self.times.shape[0]),
                                  dtype=np.uint8)
        self.mcweights = np.zeros((nsaved, self.ncomp))
        self.mcrates = np.zeros((nsaved, self.ncomp))

        self._set_hyperparameters()

    def _set_hyperparameters(self):
        """Set the prior hyperparameters without touching the chain arrays."""
        # guess hyperparameters
        self.whypers = np.ones(self.ncomp) / self.ncomp
        self.rhypers = np.ones((self.ncomp, 2)) * [1, 3]

    def _slowest_label(self):
        """Return the label of the non-noise cluster with the smallest rate."""
        rp = self.processed_results
        imaxs = rp.indicator.max(axis=0)
        noise_inds = np.where(imaxs < self._noise_cutoff)[0]
        labels = np.delete(np.unique(rp.labels), noise_inds)
        # argmin is a position within `labels`; map it back to the label
        # itself so the right cluster is selected when a noise cluster
        # precedes it.
        return labels[rp.parameters[labels, 1].argmin()]

    def run(self):
        r"""
        Execute the Gibbs sampler and save the raw data to the instance of
        :class:`Gibbs`.
        """
        # initialize weights and rates
        self._prepare()
        os.makedirs(f'basicrta-{self.cutoff}/{self.residue}', exist_ok=True)

        inrates = 0.5 * 10 ** np.arange(-self.ncomp + 2, 2, dtype=float)
        tmpw = 9 * 10 ** (-np.arange(1, self.ncomp + 1, dtype=float))
        weights, rates = tmpw / tmpw.sum(), inrates[::-1]

        # gibbs sampler
        for j in tqdm(range(1, self.niter+1),
                      desc=f'{self.residue}-K{self.ncomp}',
                      position=self.loc, leave=False):

            # compute probabilities (equation 7)
            tmp = weights*rates*np.exp(np.outer(-rates, self.times)).T
            psample = (tmp.T/tmp.sum(axis=1)).T

            # sample indicator
            z = np.argmax(rng.multinomial(1, psample), axis=1)

            # number of points and total time for each component, in a
            # single pass over the data instead of one scan per component
            Ns = np.bincount(z, minlength=self.ncomp)
            Ts = np.bincount(z, weights=self.times, minlength=self.ncomp)

            # sample posteriors (equations 8 and 9)
            weights = rng.dirichlet(self.whypers+Ns)
            rates = rng.gamma(self.rhypers[:, 0]+Ns,
                              1/(self.rhypers[:, 1]+Ts))

            # save every g steps
            if j % self.g == 0:
                ind = j//self.g-1
                self.mcweights[ind], self.mcrates[ind] = weights, rates
                self.indicator[ind] = z

        self.save()

    def cluster(self, method="GaussianMixture", **kwargs):
        r"""
        Cluster the processed results using the methods available in
        :class:`sklearn.mixture`

        The model is fitted on the log of the (weight, rate) pairs from the
        samples whose number of populated components equals the most common
        count, then used to label every populated (weight, rate) pair. The
        per-data-point cluster frequencies and the labels are written to
        :attr:`processed_results`.

        :param method: Mixture method to use
        :type method: str
        :param kwargs: Passed to the constructor of the chosen mixture class.
        """
        # Check if this Gibbs result was created from combined contact data
        if getattr(self, '_from_combined_contacts', False):
            print("INFO: Using combined contact data for clustering. "
                  "Trajectory source information is pooled together.")

        from sklearn import mixture
        from scipy import stats

        clu = getattr(mixture, method)

        burnin_ind = self.burnin // self.g
        data_len = len(self.times)
        # a component counts as populated when its weight exceeds this
        wcutoff = 10 / data_len

        weights = self.mcweights[burnin_ind::self.gskip]
        rates = self.mcrates[burnin_ind::self.gskip]
        lens = np.array([len(row[row > wcutoff]) for row in weights])
        # Older SciPy returns a length-1 array here, newer SciPy a scalar;
        # normalise to a plain int.
        lmode = int(np.ravel(stats.mode(lens).mode)[0])
        train_param = lmode

        train_inds = np.where(lens == train_param)[0]
        train_mask = weights[train_inds] > wcutoff
        train_weights = weights[train_inds][train_mask].reshape(-1,
                                                                train_param)
        train_rates = rates[train_inds][train_mask].reshape(-1, train_param)

        inds = np.where(weights > wcutoff)
        aweights, arates = weights[inds], rates[inds]
        data = np.stack((aweights, arates), axis=1)

        tweights, trates = train_weights.flatten(), train_rates.flatten()
        train_data = np.stack((tweights, trates), axis=1)

        r = clu(**kwargs)
        r.fit(np.log(train_data))
        all_labels = r.predict(np.log(data))

        if self.indicator is not None:
            indicator = self.indicator[burnin_ind::self.gskip]
        else:
            indicator = self._sample_indicator()

        # Count, for every data point, how often it was assigned to a
        # component that was mapped to each cluster.
        pindicator = np.zeros((self.times.shape[0], lmode))
        for j in np.unique(inds[0]):
            mapinds = all_labels[inds[0] == j]
            for i, indx in enumerate(inds[1][inds[0] == j]):
                tmpind = np.where(indicator[j] == indx)[0]
                pindicator[tmpind, mapinds[i]] += 1

        # normalise counts to per-data-point frequencies
        pindicator = (pindicator.T / pindicator.sum(axis=1)).T
        self.processed_results.indicator = pindicator
        self.processed_results.labels = all_labels

    def process_gibbs(self, show=False):
        r"""
        Process the samples collected from the Gibbs sampler.
        :meth:`process_gibbs` can be called multiple times to check the
        robustness of the results.

        :param show: Forwarded to :func:`basicrta.util.mixture_and_plot`.
        :type show: bool
        """
        from basicrta.util import mixture_and_plot
        from scipy import stats

        data_len = len(self.times)
        wcutoff = 10/data_len
        burnin_ind = self.burnin//self.g
        weights = self.mcweights[burnin_ind::self.gskip]
        rates = self.mcrates[burnin_ind::self.gskip]
        inds = np.where(weights > wcutoff)
        indices = (np.arange(self.burnin, self.niter + 1, self.g*self.gskip)
                   [inds[0]] // self.g)
        fweights, frates = weights[inds], rates[inds]

        lens = [len(row[row > wcutoff]) for row in weights]
        # see cluster(): normalise across SciPy versions
        lmode = int(np.ravel(stats.mode(lens).mode)[0])

        self.cluster(n_init=117, n_components=lmode)
        labels, presorts = mixture_and_plot(self, show=show)
        self.processed_results.labels = labels
        self.processed_results.indicator = (
            self.processed_results.indicator[:, presorts])

        attrs = ["weights", "rates", "ncomp", "residue", "iteration", "niter"]
        values = [fweights, frates, lmode, self.residue, indices, self.niter]

        for attr, val in zip(attrs, values):
            setattr(self.processed_results, attr, val)

        self._estimate_params()
        self.save()

    def result_plot(self, remove_noise=False, **kwargs):
        """
        Generate the combined result plot with option to change kwargs without
        re-clustering.

        :param remove_noise: Option to remove noise clusters
        :type remove_noise: bool
        :param kwargs: Forwarded to :func:`basicrta.util.mixture_and_plot`.
        """
        from basicrta.util import mixture_and_plot
        mixture_and_plot(self, remove_noise=remove_noise, **kwargs)

    def _sample_indicator(self):
        """Re-sample the indicator for every saved (weights, rates) pair.

        Stores the full array on ``self.indicator`` and returns the rows
        after burn-in, taken every `gskip` rows.
        """
        # One row per saved sample. Sizing by g*gskip would be too small
        # whenever gskip > 1, because the loop below fills a row for every
        # saved sample and the gskip stride is applied only on return.
        indicator = np.zeros((len(self.mcweights), self.times.shape[0]),
                             dtype=np.uint8)
        burnin_ind = self.burnin//self.g
        for i, (w, r) in enumerate(zip(self.mcweights, self.mcrates)):
            # compute probabilities
            probs = w*r*np.exp(np.outer(-r, self.times)).T
            z = (probs.T/probs.sum(axis=1)).T

            # sample indicator
            indicator[i] = np.argmax(rng.multinomial(1, z), axis=1)

        self.indicator = indicator
        return indicator[burnin_ind::self.gskip]

    def save(self):
        """
        Save current state of the :class:`Gibbs` instance.

        The instance is pickled to
        ``basicrta-{cutoff}/{residue}/gibbs_{niter}.pkl``; an existing file
        of that name is first moved to ``<name>.bak``.

        :raises OSError: If the output directory does not exist.
        """
        savedir = f'basicrta-{self.cutoff}/{self.residue}/'
        filename = f'gibbs_{self.niter}.pkl'
        if not os.path.exists(savedir):
            raise OSError(f'No such directory: {savedir}')

        path = savedir + filename
        if os.path.exists(path):
            # os.replace overwrites an older backup on every platform
            # (os.rename fails on Windows if the target exists).
            os.replace(path, path + '.bak')
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    @staticmethod
    def load(file):
        """
        Load an instance of :class:`Gibbs`.

        Attributes missing from the pickled object are set to ``None``.

        :param file: Path to instance of :class:`Gibbs`
        :type file: str
        :return: A new :class:`Gibbs` populated from the pickled object.
        :rtype: Gibbs
        """
        from basicrta.util import get_s

        keys = ['times', 'residue', 'loc', 'ncomp', 'niter', 'g', 'burnin',
                'processed_results', 'ts', 'mcweights', 'mcrates', 't', 's',
                'cutoff', 'indicator', 'whypers', 'rhypers']

        # Read-only access is sufficient; 'r+b' would fail on read-only
        # files. NOTE: pickle can execute arbitrary code on load; only open
        # files from a trusted source.
        with open(file, 'rb') as f:
            r = pickle.load(f)

        g = Gibbs()
        for attr in keys:
            try:
                setattr(g, attr, r[attr])
            except AttributeError:
                setattr(g, attr, None)

        if isinstance(g.residue, np.ndarray):
            g.residue = g.residue[0]

        if g.t is None:
            g.t, g.s = get_s(g.times, g.ts)

        return g

    def plot_tau_hist(self, scale=1, save=False):
        r"""
        Plot histogram of tau values for the non-noise cluster with the
        smallest rate. The figure aspect ratio is 4:3, and can be made
        larger/smaller using the `scale` argument.

        :param scale: Increase plot size by this factor
        :type scale: float
        :param save: Save plot to file
        :type save: bool
        """
        from matplotlib.ticker import MaxNLocator

        cmap = mpl.colormaps['tab10']
        rp = self.processed_results
        i = self._slowest_label()
        taus = 1/rp.rates[rp.labels == i]

        fig, ax = plt.subplots(1, figsize=(4*scale, 3*scale))
        ax.hist(taus, label=f'{i}', alpha=0.5, color=cmap(i))
        ax.set_xlabel(r'$\tau$ [ns]')
        ax.set_ylabel('count')
        ax.set_xlim(taus.min(), taus.max())
        ax.xaxis.set_major_locator(MaxNLocator(4))
        ax.xaxis.set_minor_locator(MaxNLocator(12))
        ax.yaxis.set_major_locator(MaxNLocator(3))
        ax.yaxis.set_minor_locator(MaxNLocator(12))
        plt.tight_layout()
        if save:
            plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                        f'tau_hist.png', bbox_inches='tight')
            plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                        f'tau_hist.pdf', bbox_inches='tight')
        plt.show()

    def plot_hist(self, scale=1, save=False, component=None, bins=15):
        """
        Plot posterior histograms of weights, rates and tau together with
        their priors.

        :param scale: Increase plot size by this factor
        :type scale: float
        :param save: Save plot to file
        :type save: bool
        :param component: Cluster label, or sequence of labels, to plot;
                          ``None`` selects all clusters (see the review note
                          in the body about that branch).
        :type component: int or sequence of int, optional
        :param bins: Number of bins for the posterior histograms
        :type bins: int
        """
        from matplotlib.ticker import MaxNLocator
        from scipy import stats
        from matplotlib.gridspec import GridSpec

        cmap = mpl.colormaps['tab10']
        rp = self.processed_results

        if component is None:
            comps = np.arange(rp.ncomp)
        elif isinstance(component, int):
            comps = [component]
        else:
            comps = component

        # Only the prior hyperparameters are needed here. Calling _prepare()
        # would also reset mcweights/mcrates/indicator to zeros and discard
        # the sampled chain.
        if (getattr(self, 'whypers', None) is None
                or getattr(self, 'rhypers', None) is None):
            self._set_hyperparameters()

        # label used for the prior curves below
        i = comps[0]
        fig = plt.figure(figsize=(9*scale, 3*scale))
        gs = GridSpec(4, 12, figure=fig, hspace=0.2, wspace=0.2, bottom=0.28,
                      left=0.05, right=0.98, top=0.93)
        ax0 = fig.add_subplot(gs[:, :4])
        ax1 = np.array([[fig.add_subplot(gs[:-1, 4:7]),
                         fig.add_subplot(gs[:-1, 7])],
                        [fig.add_subplot(gs[-1, 4:7]),
                         fig.add_subplot(gs[-1, 7])]])
        ax2 = fig.add_subplot(gs[0, 8:]), fig.add_subplot(gs[1:, 8:])

        # plot posteriors (loop variable is deliberately not `i`, which must
        # keep the value comps[0] for the prior labels)
        for comp in comps:
            sel = rp.labels == comp
            ax0.hist(rp.weights[sel], label='posterior', alpha=0.5,
                     color=cmap(comp), density=True, bins=bins)
            ax1[0, 0].hist(rp.rates[sel], label=f'{comp}', alpha=0.5,
                           color=cmap(comp), density=True, bins=bins)
            ax1[1, 0].hist(rp.rates[sel], label=f'{comp}', alpha=0.5,
                           color=cmap(comp), density=True, bins=bins)
            ax2[1].hist(1/rp.rates[sel], label=f'{comp}', alpha=0.5,
                        color=cmap(comp), density=True, bins=bins)

        # plot priors
        rx = np.linspace(0, 10, 10000)
        tx = np.linspace(0, 500, 10000)
        ax0.hist(rng.dirichlet(self.whypers, size=1000000).flatten(),
                 density=True, bins=20000, label='prior', alpha=0.5)
        rys = (stats.gamma(self.rhypers[0, 0], scale=1/self.rhypers[0, 1]).
               pdf(rx))
        tys = (stats.invgamma(self.rhypers[0, 0], scale=self.rhypers[0, 1]).
               pdf(tx))
        ax1[1, 0].plot(rx, rys, label=f'{i}', alpha=0.5)
        ax1[1, 0].fill_between(rx, rys, alpha=0.5)
        ax1[1, 1].plot(rx, rys, label=f'{i}', alpha=0.5)
        ax1[1, 1].fill_between(rx, rys, alpha=0.5)
        ax2[0].plot(tx, tys, label=f'{i}', alpha=0.5)
        ax2[0].fill_between(tx, tys, alpha=0.5)
        ax2[1].plot(tx, tys, label=f'{i}', alpha=0.5)
        ax2[1].fill_between(tx, tys, alpha=0.5)

        # hide the spines and tick labels at the breaks between sub-axes
        ax1[0, 0].spines['bottom'].set_visible(False)
        ax1[0, 1].spines['bottom'].set_visible(False)
        ax1[1, 0].spines['top'].set_visible(False)
        ax1[1, 1].spines['top'].set_visible(False)
        ax1[0, 0].spines['right'].set_visible(False)
        ax1[1, 0].spines['right'].set_visible(False)
        ax1[0, 1].spines['left'].set_visible(False)
        ax1[1, 1].spines['left'].set_visible(False)
        ax1[0, 0].tick_params(axis='x', labelbottom=False)
        ax1[0, 1].tick_params(axis='x', labelbottom=False)
        ax1[0, 1].tick_params(axis='y', labelleft=False)
        ax1[1, 1].tick_params(axis='y', labelleft=False)

        ax2[0].spines['bottom'].set_visible(False)
        ax2[1].spines['top'].set_visible(False)
        ax2[0].tick_params(axis='x', labelbottom=False)
        ax2[0].set_xticks([])

        # slanted markers drawn at the axis breaks
        d = 0.15
        kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
                      linestyle="none", color='k', mec='k', mew=1,
                      clip_on=False)
        kwargs2 = dict(marker=[(1+d, 0), (0, 1+d)], markersize=12,
                       linestyle="none", color='k', mec='k', mew=1,
                       clip_on=False)
        ax1[0, 0].plot([0], transform=ax1[0, 0].transAxes, **kwargs)
        ax1[1, 0].plot([1], transform=ax1[1, 0].transAxes, **kwargs)
        ax1[0, 1].plot([1], [0], transform=ax1[0, 1].transAxes, **kwargs)
        ax1[1, 1].plot([1], [1], transform=ax1[1, 1].transAxes, **kwargs)
        ax1[0, 0].plot([1], [1], transform=ax1[0, 0].transAxes, **kwargs2)
        ax2[0].plot([0, 1], [0, 0], transform=ax2[0].transAxes, **kwargs)
        ax2[1].plot([0, 1], [1, 1], transform=ax2[1].transAxes, **kwargs)

        ax0.set_xlabel(r'$\pi_k$')
        ax1[1, 0].set_xlabel(r'$\lambda_k$ [ns$^{-1}$]')
        ax2[1].set_xlabel(r'$\tau$ [ns]')
        ax0.set_ylabel('p')

        if component is None:
            # NOTE(review): ax1 is a 2x2 ndarray of Axes, so ax1[0] and
            # ax1[1] are ndarray rows without set_xlim/legend/set_xscale;
            # this branch looks like it raises AttributeError as written.
            # The intended target axes cannot be determined from this file,
            # so the calls are left unchanged -- TODO confirm and fix.
            ax1[0].set_xlim(1e-4, 1)
            ax1[1].set_xlim(1e-3, 10)
            ax1[0].legend(title='component')
            ax1[1].legend(title='component')
            ax1[0].set_xscale('log')
            ax1[1].set_xscale('log')
        else:
            ax0.set_xlim(1e-5, 1e-3)
            ax1[0, 0].set_xlim(1e-4, 1e-2)
            ax1[1, 0].set_xlim(1e-4, 1e-2)
            ax1[0, 1].set_xlim(1e-2, 10)
            ax1[1, 1].set_xlim(1e-2, 10)
            ax1[0, 0].set_ylim(5, 1200)
            ax1[0, 1].set_ylim(5, 1200)
            ax1[1, 0].set_ylim(0, 5)
            ax1[1, 1].set_ylim(0, 5)
            ax2[0].set_xlim(-5, 500)
            ax2[1].set_xlim(-5, 500)
            ax2[0].set_ylim(0.02, 0.2)
            ax2[1].set_ylim(0, 0.015)

        ax0.xaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                prune='both'))
        ax0.xaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                prune='both'))
        ax0.yaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                prune='both'))
        ax0.yaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                prune='both'))
        ax0.ticklabel_format(style='sci', axis='both', scilimits=(0, 0),
                             useMathText=True)

        ax1[1, 0].xaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                      prune='both'))
        ax1[1, 0].xaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                      prune='both'))
        ax1[1, 1].xaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                      prune='both'))
        ax1[1, 1].xaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                      prune='both'))
        ax1[0, 0].yaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                      prune='both'))
        ax1[0, 0].yaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                      prune='both'))
        ax1[1, 0].yaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                      prune='both'))
        ax1[1, 0].yaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                      prune='both'))
        ax1[0, 0].ticklabel_format(style='sci', axis='y', scilimits=(1, 1),
                                   useMathText=True)
        ax1[1, 0].ticklabel_format(style='sci', axis='y', scilimits=(1, 1),
                                   useMathText=True)
        ax1[1, 0].ticklabel_format(style='sci', axis='x', scilimits=(-3, -3),
                                   useMathText=True)
        ax1[0, 1].ticklabel_format(style='sci', axis='x', scilimits=(0, 0),
                                   useMathText=True)
        ax1[0, 1].ticklabel_format(style='sci', axis='y', scilimits=(1, 1),
                                   useMathText=True)

        ax2[0].yaxis.set_major_locator(MaxNLocator(3, min_n_ticks=2,
                                                   prune='both'))
        ax2[0].yaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                   prune='both'))
        ax2[1].yaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                   prune='both'))
        ax2[1].yaxis.set_minor_locator(MaxNLocator(15, min_n_ticks=9,
                                                   prune='both'))
        ax2[1].xaxis.set_major_locator(MaxNLocator(3, min_n_ticks=3,
                                                   prune='both'))
        ax2[1].xaxis.set_minor_locator(MaxNLocator(12, min_n_ticks=9,
                                                   prune='both'))
        ax2[0].ticklabel_format(style='sci', axis='x', scilimits=(0, 0),
                                useMathText=True)
        ax2[0].ticklabel_format(style='sci', axis='y', scilimits=(-1, -1),
                                useMathText=True)
        ax2[1].ticklabel_format(style='sci', axis='y', scilimits=(-1, -1),
                                useMathText=True)
        ax2[1].ticklabel_format(style='sci', axis='x', scilimits=(2, 2),
                                useMathText=True)

        ax1[0, 0].set_xticks([])
        ax1[0, 1].set_xticks([])
        ax1[0, 1].set_yticks([])
        ax1[1, 1].set_yticks([])

        handles, labels = ax0.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', ncols=2)

        if save:
            if component is not None:
                plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                            f'hist_results_{component}.png',
                            bbox_inches='tight')
                plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                            f'hist_results_{component}.pdf',
                            bbox_inches='tight')
            else:
                plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                            'hist_results.png', bbox_inches='tight')
                plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                            'hist_results.pdf', bbox_inches='tight')
        plt.show()

    def plot_gibbs(self, scale=1.5, sparse=1, save=False):
        """
        Plot the sampled weights and rates of each cluster against the
        sample index, both on a logarithmic y-axis.

        :param scale: Increase plot size by this factor
        :type scale: float
        :param sparse: Plot only every `sparse`-th point of each cluster
        :type sparse: int
        :param save: Save plot to file
        :type save: bool
        """
        cmap = mpl.colormaps['tab10']
        rp = self.processed_results

        fig, ax = plt.subplots(2, figsize=(4*scale, 3*scale), sharex=True)
        for label in np.unique(rp.labels):
            sel = rp.labels == label
            ax[0].plot(rp.iteration[sel][::sparse], rp.weights[sel][::sparse],
                       '.', label=f'{label}', color=cmap(label))
            ax[1].plot(rp.iteration[sel][::sparse], rp.rates[sel][::sparse],
                       '.', label=f'{label}', color=cmap(label))
        ax[0].set_yscale('log')
        ax[0].set_ylabel(r'$\pi_k$')
        ax[1].set_yscale('log')
        # \lambda_k must be inside math mode to render as a symbol
        ax[1].set_ylabel(r'$\lambda_k$ (ns$^{-1}$)')
        ax[1].set_xlabel('sample')
        ax[0].legend(title='component')
        ax[1].legend(title='component')
        plt.tight_layout()
        if save:
            plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                        'plot_results.png', bbox_inches='tight')
            plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                        'plot_results.pdf', bbox_inches='tight')
        plt.show()

    def _estimate_params(self):
        """Estimate per-cluster weight/rate values and their intervals.

        For each cluster the weights and rates are histogrammed on 20
        log-spaced bin edges; the left edge of the most populated bin is
        stored in ``processed_results.parameters`` and the output of
        :func:`basicrta.util.confidence_interval` in
        ``processed_results.intervals``.
        """
        rp = self.processed_results

        ws = [rp.weights[rp.labels == i] for i in range(rp.ncomp)]
        rs = [rp.rates[rp.labels == i] for i in range(rp.ncomp)]
        wbins = [np.exp(np.linspace(np.log(w.min()), np.log(w.max()), 20))
                 for w in ws]
        rbins = [np.exp(np.linspace(np.log(r.min()), np.log(r.max()), 20))
                 for r in rs]
        wbounds = np.array([confidence_interval(d) for d in ws])
        rbounds = np.array([confidence_interval(d) for d in rs])
        whists = [np.histogram(w, bins=bins) for w, bins in zip(ws, wbins)]
        rhists = [np.histogram(r, bins=bins) for r, bins in zip(rs, rbins)]

        params = np.array([[wh[1][np.argmax(wh[0])], rh[1][np.argmax(rh[0])]]
                           for wh, rh in zip(whists, rhists)])

        rp.parameters = params
        rp.intervals = np.array([wbounds, rbounds])

    def estimate_tau(self):
        r"""
        Estimate the posterior maximum and confidence interval (CI) for the
        :math:`\tau` distribution of the slowest process.

        The slowest process is the non-noise cluster with the smallest rate.
        The maximum is the centre of the most populated bin of a ``'sqrt'``
        histogram of :math:`\tau = 1/\lambda`; the interval comes from
        :func:`basicrta.util.confidence_interval`.

        NOTE: In the future this will return an array containing :math:`\tau`
        and CI for all clusters.

        :return: An array containing the posterior maximum and bounds of the
                 95% confidence interval in the format [LB, max, UB].
        :rtype: list
        """
        rp = self.processed_results
        index = self._slowest_label()
        taus = 1 / rp.rates[rp.labels == index]
        ci = confidence_interval(taus)

        counts, edges = np.histogram(taus, bins='sqrt')
        indmax = counts.argmax()
        val = 0.5 * (edges[:-1][indmax] + edges[1:][indmax])

        return [ci[0], val, ci[1]]

    def plot_surv(self, scale=1, remove_noise=False, save=False, xlim=None,
                  ylim=(1e-6, 5), xmajor=None, xminor=None, xscale='linear',
                  yscale='log'):
        """
        Plot the survival function with the exponential mixture components
        where parameters are determined from the clustering results.

        :param scale: Modify the size of the figure by this factor
        :type scale: float
        :param remove_noise: Whether to remove noise clusters
        :type remove_noise: bool
        :param save: Whether to save the figure
        :type save: bool
        :param xlim: X-axis limits
        :type xlim: tuple
        :param ylim: Y-axis limits
        :type ylim: tuple
        :param xmajor: X-axis major tick
        :type xmajor: int
        :param xminor: X-axis minor tick
        :type xminor: int
        :param xscale: X-axis scale passed to ``set_xscale``
        :type xscale: str
        :param yscale: Y-axis scale passed to ``set_yscale``
        :type yscale: str
        """
        from matplotlib.ticker import MultipleLocator, MaxNLocator

        if xmajor is None:
            maj_loc = MaxNLocator(nbins=3)
        else:
            maj_loc = MultipleLocator(xmajor)

        if xminor is None:
            min_loc = MaxNLocator(nbins=12)
        else:
            min_loc = MultipleLocator(xminor)

        cmap = mpl.colormaps['tab10']
        rp = self.processed_results
        imaxs = self.processed_results.indicator.max(axis=0)
        noise_inds = np.where(imaxs < self._noise_cutoff)[0]
        uniq_labels = np.unique(rp.labels)

        if remove_noise:
            uniq_labels = np.delete(uniq_labels, noise_inds)

        ws, rs = rp.parameters[:, 0], rp.parameters[:, 1]
        fig, ax = plt.subplots(1, figsize=(4*scale, 3*scale))
        ax.plot(self.t, self.s, '.')
        for i in np.unique(uniq_labels):
            ax.plot(self.t, ws[i]*np.exp(-rs[i]*self.t), label=f'{i}',
                    color=cmap(i))
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)
        ax.set_yscale(yscale)
        ax.set_xscale(xscale)
        ax.set_ylabel('survival function $s$')
        ax.set_xlabel(r'$t$ [ns]')
        ax.set_yticks([1, 1e-2, 1e-4])
        ax.xaxis.set_major_locator(maj_loc)
        ax.xaxis.set_minor_locator(min_loc)
        ax.legend(title='cluster')
        plt.tight_layout()
        if save:
            plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                        's_vs_t.png', bbox_inches='tight')
            plt.savefig(f'basicrta-{self.cutoff}/{self.residue}/'
                        's_vs_t.pdf', bbox_inches='tight')
        plt.show()
def get_parser():
    """Build the command-line parser for the Gibbs sampler entry point.

    Options: ``--contacts`` (required), ``--resid``, ``--nproc``,
    ``--niter``, ``--ncomp``, plus a hidden optional positional ``gibbs``.

    :return: The configured argument parser.
    :rtype: argparse.ArgumentParser
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="""run gibbs samplers for all or a specified residue present in the contact map""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    needed = cli.add_argument_group('required arguments')
    needed.add_argument(
        '--contacts', required=True,
        help="""Contact file produced from `basicrta contacts`, default is contacts_{cutoff}.pkl""",
    )

    # (flag, default, help text) for the integer-valued options
    int_options = (
        ('--resid', None,
         """run gibbs sampler for this residue. Will collect cutoff from contact file name."""),
        ('--nproc', 1,
         """number of processes to use in multiprocessing"""),
        ('--niter', 110000,
         """number of iterations to use for the gibbs sampler"""),
        ('--ncomp', 15,
         """number of components to use for the exponential mixture model"""),
    )
    for flag, default, text in int_options:
        cli.add_argument(flag, type=int, default=default, help=text)

    # this is to make the cli work, should be just a temporary solution
    cli.add_argument('gibbs', nargs='?', help=argparse.SUPPRESS)
    return cli
def main():
    """Parse the command line and run the Gibbs samplers.

    The cutoff is derived from the contact file name by
    :class:`ParallelGibbs` itself, so it is not parsed again here.
    """
    args = get_parser().parse_args()
    contact_path = os.path.abspath(args.contacts)
    ParallelGibbs(contact_path, nproc=args.nproc, ncomp=args.ncomp,
                  niter=args.niter).run(run_resids=args.resid)
if __name__ == '__main__':
    # `exit` is a helper installed by the `site` module for interactive use
    # and is absent when Python runs with -S; raising SystemExit directly is
    # equivalent and always available.
    raise SystemExit(main())