"""This module provides the ProcessProtein class, which collects and processes
Gibbs sampler data.
"""
import os
import gc
import warnings
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool, Lock
from glob import glob
import MDAnalysis as mda
from MDAnalysis.analysis.base import Results
from basicrta import istarmap
from basicrta.gibbs import Gibbs
# Make sure automatic garbage collection is switched on; `reprocess` also
# triggers explicit `gc.collect()` calls while gathering per-residue results.
gc.enable()
class ProcessProtein(object):
    r"""Collect and process Gibbs sampler data for every residue.

    This class collects results for all residues in the
    `basicrta-{cutoff}` directory and can write out a :math:`\tau` vs resid
    numpy array or plot :math:`\tau` vs resid. If a structure is provided,
    :math:`\tau` will be written as b-factors for visualization.

    :param niter: Number of iterations used in the Gibbs samplers
    :type niter: int
    :param prot: Name of protein in `tm_dict.txt`, used to draw TM bars in
                 :math:`\tau` vs resid plot.
    :type prot: str, optional
    :param cutoff: Cutoff used in contact analysis.
    :type cutoff: float
    :param gskip: Gibbs skip parameter for decorrelated samples;
                  only save every `gskip` samples from full Gibbs sampler
                  chain; default from
                  https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
                  When the sampled Markov chain is loaded, then the output is
                  already saved at every `Gibbs.g` samples. We calculate a new
                  `gskip` value to get close to the desired `gskip` value.
    :type gskip: int
    :param burnin: Burn-in parameter, drop first `burnin` samples as
                   equilibration; default from
                   https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
    :type burnin: int
    :param taus: Precomputed :math:`\tau` values (one per residue). When left
                 as ``None`` they are computed on demand by :meth:`get_taus`.
    :type taus: array-like, optional
    :param bars: Precomputed interval bounds belonging to `taus`, indexed as
                 ``bars[0]`` and ``bars[1]`` with residues along the last axis.
    :type bars: array-like, optional
    """

    def __init__(self, niter, prot, cutoff,
                 gskip=100, burnin=10000,
                 taus=None, bars=None):
        # `residues` and `files` are filled in by `reprocess` / `get_taus`.
        self.residues = None
        self.files = None
        self.niter = niter
        self.prot = prot
        self.cutoff = cutoff
        self.gskip = gskip
        self.burnin = burnin
        self.taus = taus
        self.bars = bars

    def __getitem__(self, item):
        """Allow dictionary-style access to attributes (``obj['taus']``)."""
        return getattr(self, item)

    def _single_residue(self, adir, process=False):
        """Load (and optionally re-process) the Gibbs result of one residue.

        :param adir: Residue directory, e.g. ``basicrta-7.0/R123``.
        :type adir: str
        :param process: If true, re-run :meth:`Gibbs.process_gibbs` with this
                        instance's `gskip` and `burnin` before estimating
                        :math:`\\tau`.
        :type process: bool
        :returns: ``(residue, tau, result)`` where `residue` is the last path
                  component of `adir`, `tau` is the output of
                  :meth:`Gibbs.estimate_tau` and `result` is the path of the
                  pickle file. If the pickle file is missing, cannot be loaded
                  or cannot be processed, `tau` is ``[0, 0, 0]`` and `result`
                  is ``None``.
        :rtype: tuple
        """
        residue = adir.split('/')[-1]
        result = f'{adir}/gibbs_{self.niter}.pkl'
        failed = (residue, [0, 0, 0], None)

        if not os.path.exists(result):
            return failed

        try:
            g = Gibbs().load(result)
        except Exception as err:
            # A corrupt or incompatible pickle must not abort the whole run,
            # but it should not be hidden either.
            warnings.warn(f"could not load {result} ({err!r}); "
                          f"skipping residue {residue}")
            return failed

        if process:
            # Gibbs samples are saved every g.g steps and then sub-sampled by
            # g.gskip, so the total skip interval is g.g * g.gskip. Choose
            # g.gskip so that this gets close to the requested self.gskip.
            ggskip = self.gskip // g.g
            if ggskip < 1:
                ggskip = 1
                warnings.warn(f"WARNING: gskip={self.gskip} is less than g={g.g}, setting gskip to 1")
            g.gskip = ggskip
            g.burnin = self.burnin
            try:
                g.process_gibbs()
            except ValueError as err:
                # HACK: triggered when we do not have enough samples for
                # clustering.
                # TODO: make sure elsewhere that we do not save pickle files
                #       with insufficient data
                warnings.warn(f"could not process {result} ({err!r}); "
                              f"skipping residue {residue}")
                return failed

        return residue, g.estimate_tau(), result

    def _residue_dirs(self):
        """Return the residue directories below ``basicrta-{cutoff}``, sorted
        by residue number.

        :raises FileNotFoundError: if no residue directory is found.
        """
        dirs = glob(f'basicrta-{self.cutoff}/?[0-9]*')
        if not dirs:
            raise FileNotFoundError(
                f"no residue directories found in basicrta-{self.cutoff}/")
        # Directory names look like '<letter><resid>'; sort numerically.
        return sorted(dirs, key=lambda adir: int(adir.split('/')[-1][1:]))

    def _store_results(self, residues, taus, results):
        """Convert the collected per-residue results to arrays and store them
        on the instance (`taus`, `bars`, `residues`, `files`)."""
        from basicrta.util import get_bars

        taus = np.array(taus)
        bars = get_bars(taus)
        self.taus = taus[:, 1]
        self.bars = bars
        self.residues = np.array(residues)
        self.files = np.array(results)

    def _run_residues(self, nproc, process):
        """Run :meth:`_single_residue` for all residue directories in a
        process pool and store the collected results on the instance."""
        dirs = self._residue_dirs()
        residues, taus, results = [], [], []
        with Pool(nproc, initializer=tqdm.set_lock,
                  initargs=(Lock(),)) as pool:
            try:
                if process:
                    # Plain tuples keep `process` a real boolean; putting it
                    # into a numpy array would coerce it to the string 'True'.
                    jobs = pool.istarmap(self._single_residue,
                                         [(adir, True) for adir in dirs])
                else:
                    jobs = pool.imap(self._single_residue, dirs)
                for residue, tau, result in tqdm(jobs, total=len(dirs),
                                                 position=0,
                                                 desc='overall progress'):
                    residues.append(residue)
                    taus.append(tau)
                    results.append(result)
                    if process:
                        gc.collect()
            except KeyboardInterrupt:
                # Keep whatever has been collected so far.
                pass
        self._store_results(residues, taus, results)

    def reprocess(self, nproc=1):
        """Rerun processing and clustering on :class:`Gibbs` data.

        :param nproc: Number of processes to use in clustering results for all
                      residues.
        :type nproc: int
        :raises FileNotFoundError: if ``basicrta-{cutoff}`` contains no
                                   residue directories.
        """
        self._run_residues(nproc, process=True)

    def get_taus(self, nproc=1):
        r"""Get :math:`\tau` and 95\% confidence interval bounds for the slowest
        process for each residue.

        :param nproc: Number of processes used to load the residue results.
        :type nproc: int
        :returns: Returns a tuple of the form (:math:`\tau`, [CI lower bound,
                  CI upper bound])
        :rtype: tuple
        :raises FileNotFoundError: if ``basicrta-{cutoff}`` contains no
                                   residue directories.
        """
        self._run_residues(nproc, process=False)
        return self.taus, self.bars

    def write_data(self, fname='tausout'):
        r"""Write :math:`\tau` values with 95\% confidence interval to a numpy
        file with the format [`sel1` resid, :math:`\tau`, CI lower bound, CI
        upper bound].

        :param fname: Filename to save data to.
        :type fname: str, optional
        """
        if self.taus is None:
            self.get_taus()

        # Residue names look like '<letter><resid>'; strip the letter.
        # TODO: double-check that we need to use res[1:] and can't get this
        #       easier
        residues = np.array([int(res[1:]) for res in self.residues])
        data = np.stack((residues, self.taus, self.bars[0], self.bars[1]))
        np.save(fname, data.T)

    def plot_protein(self, **kwargs):
        r"""Plot :math:`\tau` vs resid. kwargs are passed to the
        :meth:`plot_protein` method of `util.py`. These can be used to change
        the labeling cutoff, y-limit of the plot, scale the figure, and set
        major and minor ticks.
        """
        from basicrta.util import plot_protein

        if self.taus is None:
            self.get_taus()

        residues = [res.split('/')[-1] for res in self.residues]
        # Drop every residue that has a negative entry in either row of bars.
        exclude_inds = np.where(self.bars < 0)[1]
        taus = np.delete(self.taus, exclude_inds)
        bars = np.delete(self.bars, exclude_inds, axis=1)
        residues = np.delete(residues, exclude_inds)

        plot_protein(residues, taus, bars, self.prot, **kwargs)

    def b_color_structure(self, structure):
        r"""Add :math:`\tau` to b-factors in the specified structure. Saves
        structure with b-factors to `tau_bcolored.pdb`.

        :math:`\tau` is written to the `tempfactors` of each residue and
        ``tau / (bars[0] + bars[1])`` to its `occupancies`.

        :param structure: Structure file readable by
                          :class:`MDAnalysis.Universe`.
        :type structure: str
        """
        if self.taus is None:
            self.get_taus()

        taus = np.asarray(self.taus)
        bars = np.asarray(self.bars)
        cis = bars[1] + bars[0]
        # Residues without data have tau == 0 and cis == 0; the resulting
        # 0/0 is expected and replaced by 0 below.
        with np.errstate(divide='ignore', invalid='ignore'):
            errs = taus / cis
        errs[np.isnan(errs)] = 0

        u = mda.Universe(structure)
        u.add_TopologyAttr('tempfactors')
        u.add_TopologyAttr('occupancies')
        for tau, err, residue in tqdm(zip(taus, errs, self.residues),
                                      total=len(self.residues)):
            # Residue names look like '<letter><resid>'.
            res = u.select_atoms(f'protein and resid {residue[1:]}')
            res.tempfactors = np.round(tau, 2)
            res.occupancies = np.round(err, 2)

        u.select_atoms('protein').write('tau_bcolored.pdb')
def get_parser():
    """Build the command-line parser for the clustering script.

    :returns: Parser with the required ``--cutoff`` option, the optional
              ``--nproc``, ``--niter``, ``--prot``, ``--label_cutoff``,
              ``--structure``, ``--gskip`` and ``--burnin`` options, and a
              hidden optional positional ``cluster``.
    :rtype: argparse.ArgumentParser
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="""perform clustering for each
                    residue located in basicrta-{cutoff}/""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    required = parser.add_argument_group('required arguments')
    required.add_argument('--cutoff', required=True, type=float,
                          help="""cutoff
                          used in contact analysis, will cluster results in
                          basicrta-{cutoff}/""")

    parser.add_argument('--nproc', type=int, default=1,
                        help="""number of
                        processes to use in multiprocessing""")
    parser.add_argument('--niter', type=int, default=110000,
                        help="""number of
                        iterations used in the gibbs sampler, used to load
                        gibbs_{niter}.pkl""")
    parser.add_argument('--prot', type=str, nargs='?',
                        help="""name of protein
                        in tm_dict.txt, used to draw TM bars in tau vs resid
                        plot""")
    parser.add_argument('--label_cutoff', type=float, default=3,
                        dest='label_cutoff',
                        help="""Only label residues with tau >
                        LABEL-CUTOFF * <tau>.""")
    parser.add_argument('--structure', type=str, nargs='?',
                        help="""will add tau
                        as bfactors to the structure if provided""")

    # use for default values
    # (note the space after ';' so the concatenated pieces do not run together)
    parser.add_argument('--gskip', type=int, default=100,
                        help='Gibbs skip parameter for decorrelated samples; '
                             'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')
    parser.add_argument('--burnin', type=int, default=10000,
                        help='Burn-in parameter, drop first N samples as equilibration; '
                             'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')

    # this is to make the cli work, should be just a temporary solution
    parser.add_argument('cluster', nargs='?', help=argparse.SUPPRESS)
    return parser
def main(argv=None):
    """Command-line entry point: re-process all residues, write the tau
    table and plot tau vs resid.

    :param argv: Argument list to parse; ``None`` (the default) makes
                 argparse read ``sys.argv[1:]``.
    :type argv: list of str, optional
    """
    parser = get_parser()
    args = parser.parse_args(argv)

    pp = ProcessProtein(args.niter, args.prot, args.cutoff,
                        gskip=args.gskip, burnin=args.burnin)
    pp.reprocess(nproc=args.nproc)
    pp.write_data()
    pp.plot_protein(label_cutoff=args.label_cutoff)
    # NOTE(review): `--structure` is accepted by the parser but is not acted
    # on here; confirm whether `b_color_structure` should be called.
if __name__ == "__main__":  # pragma: no cover
    # the script is tested in the test_cluster.py but cannot be accounted for
    # in the coverage report
    # `exit` is a helper injected by the `site` module and is not guaranteed
    # to exist; raising SystemExit directly has the same effect.
    raise SystemExit(main())