Source code for contacts

"""
Create contact maps between two atom groups.

This module provides the `MapContacts` class, which creates the initial contact
map between the two atom groups using a maximum cutoff (`max_cutoff`), which
provides for quicker processing if creating results for multiple cutoffs. The 
`ProcessContacts` class takes the initial contact map and creates the processed
contact map based on the prescribed cutoff.
"""

from tqdm import tqdm
from MDAnalysis.lib import distances
from multiprocessing import Pool, Lock
from basicrta import istarmap
import numpy as np
import multiprocessing
import collections
import MDAnalysis as mda
import pickle
import glob
import os
os.environ['MKL_NUM_THREADS'] = '1'


[docs] class MapContacts(object): """This class is used to create the map of contacts between two groups of atoms. A single cutoff is used to define a contact between the two groups, where if any atomic distance between the two groups is less than the cutoff, a contact is considered formed. :param u: Universe containing the topology and trajectory for which the contacts will be computed. :type u: `MDAnalysis Universe` :param ag1: Primary AtomGroup for which contacts will be computed, typically a protein. :type ag1: MDAnalysis AtomGroup :param ag2: Secondary AtomGroup which forms contacts with `ag1`, typically lipids, ions, or other small molecules. Each residue of `ag2` must have the same number of atoms. :type ag2: MDAnalysis AtomGroup :param nproc: Number of processes to use in computing contacts (default is 1). :type nproc: int, optional :param frames: List of frames to use in computing contacts (default is None, meaning all frames are used). :type frames: list or np.array, optional :param max_cutoff: Maximum cutoff to use in computing contacts. A primary contact map is created upon which multiple cutoffs can be imposed, i.e. in the case where a proper cutoff is being determined. This can typically be left at its default value, unless a greater value is needed (default is 10.0). :type max_cutoff: float, optional :param nslices: Number of slices to break the trajectory into for processing. If device memory is limited, try increasing `nslices` (default is 100). :type nslices: int, optional """ def __init__(self, u, ag1, ag2, nproc=1, frames=None, max_cutoff=10.0, nslices=100): self.u, self.nproc = u, nproc self.ag1, self.ag2 = ag1, ag2 self.max_cutoff = max_cutoff self.frames, self.nslices = frames, nslices self.contacts_filename = f"contacts_max{self.max_cutoff}.pkl"
[docs] def run(self): """Run contact analysis and save to `contacts_max{max_cutoff}.pkl` """ if self.frames is not None: sliced_frames = np.array_split(self.frames, self.nslices) else: sliced_frames = np.array_split(np.arange(len(self.u.trajectory)), self.nslices) input_list = [[i, self.u.trajectory[aslice]] for i, aslice in enumerate(sliced_frames)] lens = [] with (Pool(self.nproc, initializer=tqdm.set_lock, initargs=(Lock(),)) as p): for alen in tqdm(p.istarmap(self._run_contacts, input_list), total=self.nslices, position=0, desc='overall progress'): lens.append(alen) lens = np.array(lens) mapsize = sum(lens) bounds = np.concatenate([[0], np.cumsum(lens)]) dtype = np.dtype(np.float64, metadata={'top': self.u.filename, 'traj': self.u.trajectory.filename, 'ag1': self.ag1, 'ag2': self.ag2, 'ts': self.u.trajectory.dt/1000, 'max_cutoff': self.max_cutoff}) contact_map = np.memmap('.tmpmap', mode='w+', shape=(mapsize, 5), dtype=dtype) for i in range(self.nslices): contact_map[bounds[i]:bounds[i+1]] = np.genfromtxt(f'.contacts_' f'{i:04}', delimiter=',') contact_map.flush() contact_map.dump(self.contacts_filename, protocol=5) os.remove('.tmpmap') cfiles = glob.glob('.contacts*') [os.remove(f) for f in cfiles] print(f'\nSaved contacts as "{self.contacts_filename}"')
def _run_contacts(self, i, sliced_traj): from basicrta.util import get_dec try: proc = int(multiprocessing.current_process().name.split('-')[-1]) except ValueError: proc = 1 with open(f'.contacts_{i:04}', 'w+') as f: dec = get_dec(self.u.trajectory.ts.dt/1000) # convert to ns text = f'slice {i+1} of {self.nslices}' data_len = 0 for ts in tqdm(sliced_traj, desc=text, position=proc, total=len(sliced_traj), leave=False): dset = [] b = distances.capped_distance(self.ag1.positions, self.ag2.positions, max_cutoff=self.max_cutoff) pairlist = [(self.ag1.resids[b[0][i, 0]], self.ag2.resids[b[0][i, 1]]) for i in range(len(b[0]))] pairdir = collections.Counter(a for a in pairlist) lsum = 0 for j in pairdir: temp = pairdir[j] dset.append([int(ts.frame), int(j[0]), int(j[1]), float(min(b[1][lsum:lsum+temp])), float(np.round(ts.time, dec)/1000)]) # convert to ns lsum += temp [f.write(f"{line}".strip('[]') + "\n") for line in dset] data_len += len(dset) f.flush() return data_len
[docs] class ProcessContacts(object): """The :class:`ProcessProtein` class takes the primary contact map (i.e. `contacts_max10.0.pkl`) and collects contacts based on a prescribed cutoff. :param cutoff: Collect all contacts between `ag1` and `ag2` within this value. :type cutoff: float :param map_name: Name of primary contact map. The default produced by MapContacts is `contacts_max10.0.pkl`. :type map_name: str :param nproc: Number of processes to use in collecting contacts (default is 1). :type nproc: int, optional """ def __init__(self, cutoff, map_name, nproc=1): self.nproc = nproc self.map_name = map_name self.cutoff = cutoff
[docs] def run(self): """Process contacts using the prescribed cutoff and write to contacts-{cutoff}.pkl """ if os.path.exists(self.map_name): with open(self.map_name, 'r+b') as f: memmap = pickle.load(f) # memmap = np.load(self.map_name, mmap_mode='r') dtype = memmap.dtype new_metadata = dtype.metadata.copy() new_metadata['cutoff'] = self.cutoff new_dtype = np.dtype(np.float64, metadata=new_metadata) memmap = memmap[memmap[:, -2] <= self.cutoff] else: raise FileNotFoundError(f'{self.map_name} not found. Specify the ' 'contacts file using the "map_name" ' 'argument') self.ts = dtype.metadata['ts'] lresids = np.unique(memmap[:, 2]) params = [[res, memmap[memmap[:, 2] == res], i] for i, res in enumerate(lresids)] pool = Pool(self.nproc, initializer=tqdm.set_lock, initargs=(Lock(),)) try: lens = pool.starmap(self._lipswap, params) except KeyboardInterrupt: pool.terminate() pool.close() bounds = np.concatenate([[0], np.cumsum(lens)]).astype(int) mapsize = sum(lens) contact_map = np.memmap('.tmpmap', mode='w+', shape=(mapsize, 4), dtype=new_dtype) for i in range(len(lresids)): contact_map[bounds[i]:bounds[i+1]] = np.load(f'.contacts_{i:04}.' f'npy') contact_map.flush() contact_map.dump(f'contacts_{self.cutoff}.pkl', protocol=5) # os.remove('.tmpmap') # cfiles = glob.glob('.contacts*') # [os.remove(f) for f in cfiles] print(f'\nSaved contacts to "contacts_{self.cutoff}.pkl"')
def _lipswap(self, lip, memarr, i): from basicrta.util import get_dec try: # proc = int(multiprocessing.current_process().name[-1]) proc = int(multiprocessing.current_process().name.split('-')[-1]) except ValueError: proc = 1 presids = np.unique(memarr[:, 1]) dset = [] dec, ts = get_dec(self.ts), self.ts for pres in tqdm(presids, desc=f'lipID {int(lip)}', position=proc, leave=False): stimes = np.round(memarr[:, -1][memarr[:, 1] == pres], dec) if len(stimes) == 0: continue stimes = np.concatenate([np.array([-1]), stimes, np.array([stimes[-1] + 1])]) diff = np.round(stimes[1:] - stimes[:-1], dec) singles = stimes[ np.where((diff[1:] > ts) & (diff[:-1] > ts))[0] + 1] diff[diff > ts] = 0 inds = np.where(diff == 0)[0] sums = [sum(diff[inds[i]:inds[i + 1]]) for i in range(len(inds) - 1)] clens = np.round(np.array(sums), dec) minds = np.where(clens != 0)[0] clens = clens[minds] + ts strt_times = stimes[inds[minds] + 1] [dset.append([pres, lip, time, ts]) for time in singles] [dset.append([pres, lip, time, clen]) for time, clen in zip(strt_times, clens)] np.save(f'.contacts_{i:04}', np.array(dset)) return len(dset)
[docs] class CombineContacts(object): """Class to combine contact timeseries from multiple repeat runs. This class enables pooling data from multiple trajectory repeats and calculating posteriors from all data together, rather than analyzing each run separately. :param contact_files: List of contact pickle files to combine :type contact_files: list of str :param output_name: Name for the combined output file (default: 'combined_contacts.pkl') :type output_name: str, optional :param validate_compatibility: Whether to validate that files are compatible (default: True) :type validate_compatibility: bool, optional """ def __init__(self, contact_files, output_name='combined_contacts.pkl', validate_compatibility=True): self.contact_files = contact_files self.output_name = output_name self.validate_compatibility = validate_compatibility if len(contact_files) < 2: raise ValueError("At least 2 contact files are required for combining") def _load_contact_file(self, filename): """Load a contact pickle file and return data and metadata.""" if not os.path.exists(filename): raise FileNotFoundError(f"Contact file not found: {filename}") with open(filename, 'rb') as f: contacts = pickle.load(f) metadata = contacts.dtype.metadata return contacts, metadata def _validate_compatibility(self, metadatas): """Validate that contact files are compatible for combining.""" reference = metadatas[0] # Check that all files have the same atom groups for i, meta in enumerate(metadatas[1:], 1): # Compare cutoff if meta['cutoff'] != reference['cutoff']: raise ValueError(f"Incompatible cutoffs: file 0 has {reference['cutoff']}, " f"file {i} has {meta['cutoff']}") # Compare atom group selections by checking if resids match ref_ag1_resids = set(reference['ag1'].residues.resids) ref_ag2_resids = set(reference['ag2'].residues.resids) meta_ag1_resids = set(meta['ag1'].residues.resids) meta_ag2_resids = set(meta['ag2'].residues.resids) if ref_ag1_resids != meta_ag1_resids: raise ValueError(f"Incompatible ag1 residues between file 0 and file {i}") if ref_ag2_resids != meta_ag2_resids: raise ValueError(f"Incompatible ag2 residues between file 0 and file {i}") # Check timesteps and warn if different timesteps = [meta['ts'] for meta in metadatas] if not all(abs(ts - timesteps[0]) < 1e-6 for ts in timesteps): print("WARNING: Different timesteps detected across runs:") for i, (filename, ts) in enumerate(zip(self.contact_files, timesteps)): print(f" File {i} ({filename}): dt = {ts} ns") print("This may affect residence time estimates, especially for fast events.")
[docs] def run(self): """Combine contact files and save the result.""" print(f"Combining {len(self.contact_files)} contact files...") all_contacts = [] all_metadatas = [] # Load all contact files for i, filename in enumerate(self.contact_files): print(f"Loading file {i+1}/{len(self.contact_files)}: {filename}") contacts, metadata = self._load_contact_file(filename) all_contacts.append(contacts) all_metadatas.append(metadata) # Validate compatibility if requested if self.validate_compatibility: print("Validating file compatibility...") self._validate_compatibility(all_metadatas) # Combine contact data print("Combining contact data...") # Calculate total size and create combined array total_size = sum(len(contacts) for contacts in all_contacts) reference_metadata = all_metadatas[0].copy() # Extend metadata to include trajectory source information reference_metadata['source_files'] = self.contact_files reference_metadata['n_trajectories'] = len(self.contact_files) # Determine number of columns (5 for raw contacts, 4 for processed) n_cols = all_contacts[0].shape[1] # Create dtype with extended metadata combined_dtype = np.dtype(np.float64, metadata=reference_metadata) # Add trajectory source column (will be last column) combined_contacts = np.zeros((total_size, n_cols + 1), dtype=np.float64) # Combine data and add trajectory source information offset = 0 for traj_idx, contacts in enumerate(all_contacts): n_contacts = len(contacts) # Copy original contact data combined_contacts[offset:offset+n_contacts, :n_cols] = contacts[:] # Add trajectory source index combined_contacts[offset:offset+n_contacts, n_cols] = traj_idx offset += n_contacts # Create final memmap with proper dtype final_contacts = combined_contacts.view(combined_dtype) # Save combined contacts print(f"Saving combined contacts to {self.output_name}...") final_contacts.dump(self.output_name, protocol=5) print(f"Successfully combined {len(self.contact_files)} files into {self.output_name}") print(f"Total contacts: {total_size}") print(f"Added trajectory source column (index {n_cols}) for kinetic clustering support") return self.output_name
[docs] def main(): parser = get_parser() args = parser.parse_args() u = mda.Universe(args.top, args.traj) cutoff, nproc, nslices = args.cutoff, args.nproc, args.nslices ag1 = u.select_atoms(args.sel1) ag2 = u.select_atoms(args.sel2) mc = MapContacts(u, ag1, ag2, nproc=nproc, nslices=nslices) mapname = mc.contacts_filename if not os.path.exists(mapname): print(f"running MapContacts to generate {mapname}") mc.run() else: print(f"using existing {mapname}") ProcessContacts(cutoff, mapname, nproc=nproc).run()
[docs] def get_parser(): import argparse parser = argparse.ArgumentParser(description="""Create the initial contact map and process it using a prescribed cutoff""") required = parser.add_argument_group('required arguments') required.add_argument('--top', type=str, help="Topology") required.add_argument('--traj', type=str, help="Trajectory") required.add_argument('--sel1', type=str, help="Primary atom selection, based \ on MDAnalysis atom selection. basicrta will produce \ tau for each residue in this atom group.") required.add_argument('--sel2', type=str, help="Secondary atom selection, \ based on MDAnalysis atom selection. basicrta will \ collect contacts between each residue of this group \ with each residue of `sel1`.") required.add_argument('--cutoff', type=float, help="""Value to use for defining a contact (in Angstrom). Any atom of `sel2` that is at a distance less than or equal to `cutoff` of any atom in `sel1` will be considered in contact.""", required=True) parser.add_argument('--nproc', type=int, default=1, help="""Number of processes to use in multiprocessing""") parser.add_argument('--nslices', type=int, default=100, help="""Number of slices to break the trajectory into. Increase this to reduce the amount of memory needed for each process.""") # this is to make the cli work, should be just a temporary solution parser.add_argument('contacts', nargs='?', help=argparse.SUPPRESS) return parser
if __name__ == '__main__': exit(main()) """DOCSSS """