Source code for combine

#!/usr/bin/env python

"""
Combine contact timeseries from multiple repeat runs.

This module provides functionality to combine contact files from multiple 
trajectory repeats, enabling pooled analysis of binding kinetics.
"""

import argparse
import os
import pickle

import numpy as np

class CombineContacts(object):
    """Class to combine contact timeseries from multiple repeat runs.

    This class enables pooling data from multiple trajectory repeats and
    calculating posteriors from all data together, rather than analyzing
    each run separately. The combined array holds the rows of every input
    file stacked in the order given, plus one extra trailing column that
    stores the index of the file each row came from.

    :param contact_files: List of contact pickle files to combine
    :type contact_files: list of str
    :param output_name: Name for the combined output file
                        (default: 'combined_contacts.pkl')
    :type output_name: str, optional
    :param validate_compatibility: Whether to validate that files are
                                   compatible (default: True)
    :type validate_compatibility: bool, optional
    :raises ValueError: if fewer than two contact files are given
    """

    def __init__(self, contact_files, output_name='combined_contacts.pkl',
                 validate_compatibility=True):
        if len(contact_files) < 2:
            raise ValueError("At least 2 contact files are required for combining")

        self.contact_files = contact_files
        self.output_name = output_name
        self.validate_compatibility = validate_compatibility

    def _load_contact_file(self, filename):
        """Load a contact pickle file and return data and metadata.

        :param filename: Path of the pickled contact array
        :type filename: str
        :return: The unpickled array and the metadata attached to its dtype
                 (``None`` when the dtype carries no metadata)
        :rtype: tuple
        :raises FileNotFoundError: if `filename` does not exist
        """
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Contact file not found: {filename}")

        # NOTE: unpickling can execute arbitrary code; only combine contact
        # files that come from a trusted source.
        with open(filename, 'rb') as f:
            contacts = pickle.load(f)

        return contacts, contacts.dtype.metadata

    def _validate_compatibility(self, metadatas):
        """Validate that contact files are compatible for combining.

        Every file is compared against the first one: the ``cutoff`` values
        must be equal and the residue ids of ``ag1`` and ``ag2`` must form
        the same sets. Differing ``ts`` values only produce a printed
        warning.

        :param metadatas: Metadata mappings, one per contact file, in the
                          same order as ``self.contact_files``
        :type metadatas: list of dict
        :raises ValueError: if a cutoff or a residue set differs from the
                            first file
        """
        reference = metadatas[0]

        # The reference residue sets do not change while iterating, so
        # build them once instead of once per compared file.
        ref_ag1_resids = set(reference['ag1'].residues.resids)
        ref_ag2_resids = set(reference['ag2'].residues.resids)

        for i, meta in enumerate(metadatas[1:], 1):
            if meta['cutoff'] != reference['cutoff']:
                raise ValueError(f"Incompatible cutoffs: file 0 has {reference['cutoff']}, "
                                 f"file {i} has {meta['cutoff']}")

            if set(meta['ag1'].residues.resids) != ref_ag1_resids:
                raise ValueError(f"Incompatible ag1 residues between file 0 and file {i}")

            if set(meta['ag2'].residues.resids) != ref_ag2_resids:
                raise ValueError(f"Incompatible ag2 residues between file 0 and file {i}")

        # Different timesteps are tolerated, but the user is warned.
        timesteps = [meta['ts'] for meta in metadatas]
        if not all(abs(ts - timesteps[0]) < 1e-6 for ts in timesteps):
            print("WARNING: Different timesteps detected across runs:")
            for i, (filename, ts) in enumerate(zip(self.contact_files, timesteps)):
                print(f" File {i} ({filename}): dt = {ts} ns")
            print("This may affect residence time estimates, especially for fast events.")

    def run(self):
        """Combine contact files and save the result.

        :return: Name of the file the combined contacts were written to
        :rtype: str
        :raises FileNotFoundError: if one of the contact files is missing
        :raises ValueError: if a file carries no dtype metadata, does not
                            hold a 2D array, has a different number of
                            columns than the first file, or fails the
                            compatibility validation
        """
        print(f"Combining {len(self.contact_files)} contact files...")

        all_contacts = []
        all_metadatas = []

        # Load all contact files
        for i, filename in enumerate(self.contact_files):
            print(f"Loading file {i+1}/{len(self.contact_files)}: {filename}")
            contacts, metadata = self._load_contact_file(filename)
            if metadata is None:
                # Without metadata there is nothing to validate or to carry
                # over; fail with a clear message instead of a later
                # AttributeError/TypeError on None.
                raise ValueError(f"Contact file {filename} has no dtype metadata "
                                 f"and cannot be combined")
            all_contacts.append(contacts)
            all_metadatas.append(metadata)

        # Validate compatibility if requested
        if self.validate_compatibility:
            print("Validating file compatibility...")
            self._validate_compatibility(all_metadatas)

        # Combine contact data
        print("Combining contact data...")

        # Rows can only be stacked when every file holds a 2D array with
        # the same number of columns; check before allocating anything.
        for i, (filename, contacts) in enumerate(zip(self.contact_files, all_contacts)):
            if contacts.ndim != 2:
                raise ValueError(f"File {i} ({filename}) does not contain a 2D contact "
                                 f"array (shape {contacts.shape})")

        # Determine number of columns (5 for raw contacts, 4 for processed)
        n_cols = all_contacts[0].shape[1]
        for i, contacts in enumerate(all_contacts[1:], 1):
            if contacts.shape[1] != n_cols:
                raise ValueError(f"Incompatible number of columns: file 0 has {n_cols}, "
                                 f"file {i} has {contacts.shape[1]}")

        total_size = sum(len(contacts) for contacts in all_contacts)

        # Start from the first file's metadata and extend it with
        # trajectory source information.
        reference_metadata = dict(all_metadatas[0])
        reference_metadata['source_files'] = list(self.contact_files)
        reference_metadata['n_trajectories'] = len(self.contact_files)

        # Create dtype with extended metadata
        combined_dtype = np.dtype(np.float64, metadata=reference_metadata)

        # Add trajectory source column (will be last column)
        combined_contacts = np.zeros((total_size, n_cols + 1), dtype=np.float64)

        # Combine data and add trajectory source information
        offset = 0
        for traj_idx, contacts in enumerate(all_contacts):
            stop = offset + len(contacts)
            # Copy original contact data
            combined_contacts[offset:stop, :n_cols] = contacts
            # Add trajectory source index
            combined_contacts[offset:stop, n_cols] = traj_idx
            offset = stop

        # Re-view the same buffer with the metadata-carrying dtype so the
        # metadata is pickled together with the data.
        final_contacts = combined_contacts.view(combined_dtype)

        # Save combined contacts
        print(f"Saving combined contacts to {self.output_name}...")
        with open(self.output_name, 'wb') as f:
            pickle.dump(final_contacts, f, protocol=5)

        print(f"Successfully combined {len(self.contact_files)} files into {self.output_name}")
        print(f"Total contacts: {total_size}")
        print(f"Added trajectory source column (index {n_cols}) for kinetic clustering support")

        return self.output_name
def get_parser():
    """Build the command line parser for combining contact files.

    The parser is only constructed here; parsing the arguments is left to
    the caller.

    :return: An ArgumentParser instance with the ``--contacts``,
             ``--output`` and ``--no-validate`` options registered.
    :rtype: `ArgumentParser` object
    """
    summary = (
        "Combine contact timeseries from multiple repeat runs. "
        "This enables pooling data from multiple trajectory repeats "
        "and calculating posteriors from all data together."
    )
    cli = argparse.ArgumentParser(description=summary)

    # Listed under its own heading so the mandatory option stands out in
    # the generated help text.
    mandatory = cli.add_argument_group('required arguments')
    mandatory.add_argument(
        '--contacts',
        nargs='+',
        required=True,
        help="List of contact pickle files to combine (e.g., contacts_7.0.pkl "
             "from different runs)",
    )

    cli.add_argument(
        '--output',
        type=str,
        default='combined_contacts.pkl',
        help="Output filename for combined contacts (default: combined_contacts.pkl)",
    )

    cli.add_argument(
        '--no-validate',
        action='store_true',
        help="Skip compatibility validation (use with caution)",
    )

    # Hidden optional positional kept to make the cli work; meant to be a
    # temporary solution only.
    cli.add_argument('combine', nargs='?', help=argparse.SUPPRESS)

    return cli
def main():
    """Run the combine tool from the command line.

    The input files and the output path are checked before any combining
    is attempted: every contact file must exist, at least two must be
    given, and the output file must not exist yet.

    :return: 0 when the files were combined, 1 when a check failed or an
             exception was raised while combining
    :rtype: int
    """
    args = get_parser().parse_args()

    # Report all missing inputs at once rather than stopping at the first.
    missing_files = [name for name in args.contacts if not os.path.exists(name)]
    if missing_files:
        print("ERROR: The following contact files were not found:")
        for name in missing_files:
            print(f" {name}")
        return 1

    if len(args.contacts) < 2:
        print("ERROR: At least 2 contact files are required for combining")
        return 1

    # Refuse to overwrite an existing output file.
    if os.path.exists(args.output):
        print(f"ERROR: Output file {args.output} already exists")
        return 1

    try:
        output_file = CombineContacts(
            contact_files=args.contacts,
            output_name=args.output,
            validate_compatibility=not args.no_validate,
        ).run()

        print("\nCombination successful!")
        print(f"Combined contact file saved as: {output_file}")
        print("\nYou can now use this file with the Gibbs sampler:")
        print(f" python -m basicrta.gibbs --contacts {output_file} --nproc <N>")
        return 0
    except Exception as exc:
        # Top-level boundary of the command line tool: turn any failure
        # into a message and a non-zero exit status.
        print(f"ERROR: {exc}")
        return 1
if __name__ == "__main__":
    # Raise SystemExit directly: the bare ``exit`` helper is installed by
    # the ``site`` module and is not guaranteed to exist (e.g. ``python -S``
    # or frozen builds). The status returned by main() becomes the exit code.
    raise SystemExit(main())