From 0c29834ae2c9c31059d59865de242a6909f586f8 Mon Sep 17 00:00:00 2001 From: rostro36 Date: Sat, 28 Sep 2024 14:01:34 +0200 Subject: [PATCH 1/3] Changed wrong filename in README. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 29d773b..aa573d9 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ either by randomly sampling the sequences (“Random”) or by greedily maximizi It is possible to unconditionally generate an entire MSA, using the following script: ``` -python evodiff/generate-msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --n-sequences 256 --subsampling MaxHamming +python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --n-sequences 256 --subsampling MaxHamming ``` The default model type is `msa_oa_dm_maxsub`, which is EvoDiff-MSA-OADM trained on Max subsampled sequences, and the other available @@ -193,14 +193,14 @@ thus generating new members of a protein family without needing to train family- To generate a new query sequence, given an alignment, use the following with the `--start-msa` flag. This starts conditional generation by sampling from a validation MSA. To run this script you must have the Openfold dataset and splits downloaded. ``` -python evodiff/generate-msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --n-sequences 256 --subsampling MaxHamming --start-msa +python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --n-sequences 256 --subsampling MaxHamming --start-msa ``` If you want to generate on a custom MSA, it is possible to retrofit existing code. Additionally, the code is capable of generating an alignment given a query sequence, use the following `--start-query` flag. This starts with the query and generates the alignment. ``` -python evodiff/generate-msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --n-sequences 256 --subsampling MaxHamming --start-query +python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --n-sequences 256 --subsampling MaxHamming --start-query ``` NOTE: you can only specify one of the above flags at a time. You cannot specify both (`--start-query` & `--start-msa`) together. Please look at `generate.py` for more information. From 78c316af73a28572321cfcb08caf7413ac0c7f71 Mon Sep 17 00:00:00 2001 From: rostro36 Date: Sat, 28 Sep 2024 14:04:46 +0200 Subject: [PATCH 2/3] Use the first available GPU, not the second. --- evodiff/generate_msa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/evodiff/generate_msa.py b/evodiff/generate_msa.py index d1757ca..3d57bee 100644 --- a/evodiff/generate_msa.py +++ b/evodiff/generate_msa.py @@ -20,8 +20,8 @@ def main(): #parser.add_argument('config_fpath') #parser.add_argument('out_fpath', type=str, nargs='?', # default=os.getenv('AMLT_OUTPUT_DIR', '/tmp') + '/') - parser.add_argument('-g', '--gpus', default=1, type=int, - help='number of gpus per node') + parser.add_argument('-g', '--gpus', default=0, type=int, + help='Number of gpus per node') parser.add_argument('-off', '--offset', default=0, type=int, help='Number of GPU devices to skip.') parser.add_argument('--model-type', type=str, default='msa_oa_dm_maxsub') From e09bdf7e476e36d4393810c90efd5da52a6e4a72 Mon Sep 17 00:00:00 2001 From: rostro36 Date: Sat, 28 Sep 2024 18:17:46 +0200 Subject: [PATCH 3/3] Cleaned generate_msa.py --- README.md | 4 +- data/openfold/test.a3m | 20 +++++++++ evodiff/data.py | 91 +++++++++++++++++++++++++---------------- evodiff/generate_msa.py | 55 +++++++++++++------------ 4 files changed, 106 insertions(+), 64 deletions(-) create mode 100644 data/openfold/test.a3m diff --git a/README.md b/README.md index aa573d9..4c407b2 100644 --- a/README.md +++ b/README.md @@ -200,8 +200,10 @@ If you want to generate on a custom MSA, it is possible to retrofit existing cod Additionally, the code is capable of generating an alignment given a query sequence, use the following `--start-query` flag. This starts with the query and generates the alignment. ``` -python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --n-sequences 256 --subsampling MaxHamming --start-query +python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 4 --n-sequences 2 --gpus 0 --subsampling MaxHamming --start-query --dataset openfold/test.a3m --out_fpath out ``` +This command takes a .a3m file in data/`--dataset` as input. You have to provide additional non-all-gaps lines, which are higher or equal to the amount of sequences you wan to generate. Special pre-processing behaviour for the default `openfold` dataset. + NOTE: you can only specify one of the above flags at a time. You cannot specify both (`--start-query` & `--start-msa`) together. Please look at `generate.py` for more information. diff --git a/data/openfold/test.a3m b/data/openfold/test.a3m new file mode 100644 index 0000000..c9fea98 --- /dev/null +++ b/data/openfold/test.a3m @@ -0,0 +1,20 @@ +>Test_MSA +ACDEFGHIKLMNPQRSTVY +>Buffer_1 +ACDEFGHIKLMNPQRSTVY +>Buffer_2 +ACDEFGHIKLMNPQRSTVY +>Buffer_3 +ACDEFGHIKLMNPQRSTVY +>Buffer_4 +ACDEFGHIKLMNPQRSTVY +>Buffer_5 +ACDEFGHIKLMNPQRSTVY +>Buffer_6 +ACDEFGHIKLMNPQRSTVY +>Buffer_7 +ACDEFGHIKLMNPQRSTVY +>Buffer_8 +ACDEFGHIKLMNPQRSTVY +>Buffer_9 +ACDEFGHIKLMNPQRSTVY \ No newline at end of file diff --git a/evodiff/data.py b/evodiff/data.py index 05f9eda..0ee7fc2 100644 --- a/evodiff/data.py +++ b/evodiff/data.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from tqdm import tqdm from scipy.spatial.distance import hamming, cdist @@ -308,7 +309,7 @@ def __getitem__(self, idx): class A3MMSADataset(Dataset): """Build dataset for A3M data: MSA Absorbing Diffusion model""" - def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_depth=None): + def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_depth=None, openfold=True): """ Args: selection_type: str, @@ -319,11 +320,16 @@ def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_ maximum MSA sequence length data_dir: str, if you have a specified data directory + min_depth: int, + filter out shallower MSAs + openfold: bool, + use openfold dataset or custom dataset at data_dir """ alphabet = PROTEIN_ALPHABET self.tokenizer = Tokenizer(alphabet) self.alpha = np.array(list(alphabet)) self.gap_idx = self.tokenizer.alphabet.index(GAP) + self.openfold=openfold # Get npz_data dir if data_dir is not None: @@ -331,40 +337,51 @@ def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_ else: raise FileNotFoundError(data_dir) - [print("Excluding", x) for x in os.listdir(self.data_dir) if x.endswith('.npz')] - all_files = [x for x in os.listdir(self.data_dir) if not x.endswith('.npz')] - all_files = sorted(all_files) + [print(f"Excluding {x}") for x in Path(self.data_dir).glob("*.npz")] + if Path(self.data_dir).is_dir(): + all_files = [x for x in Path(self.data_dir).glob("*[!.npz]")] + all_files = sorted(all_files) + else: + all_files = [self.data_dir] print("unfiltered length", len(all_files)) - - ## Filter based on depth (keep > 64 seqs/MSA) - if not os.path.exists(data_dir + 'openfold_lengths.npz'): - raise Exception("Missing openfold_lengths.npz in openfold/") - if not os.path.exists(data_dir + 'openfold_depths.npz'): - #get_msa_depth_openfold(data_dir, sorted(all_files), 'openfold_depths.npz') - raise Exception("Missing openfold_depths.npz in openfold/") - if min_depth is not None: # reindex, filtering out MSAs < min_depth - _depths = np.load(data_dir+'openfold_depths.npz')['arr_0'] - depths = pd.DataFrame(_depths, columns=['depth']) - depths = depths[depths['depth'] >= min_depth] - keep_idx = depths.index - - _lengths = np.load(data_dir+'openfold_lengths.npz')['ells'] - lengths = np.array(_lengths)[keep_idx] - all_files = np.array(all_files)[keep_idx] - print("filter MSA depth > 64", len(all_files)) - - # Re-filter based on high gap-contining rows - if not os.path.exists(data_dir + 'openfold_gap_depths.npz'): - #get_sliced_gap_depth_openfold(data_dir, all_files, 'openfold_gap_depths.npz', max_seq_len=max_seq_len) - raise Exception("Missing openfold_gap_depths.npz in openfold/") - _gap_depths = np.load(data_dir + 'openfold_gap_depths.npz')['arr_0'] - gap_depths = pd.DataFrame(_gap_depths, columns=['gapdepth']) - gap_depths = gap_depths[gap_depths['gapdepth'] >= min_depth] - filter_gaps_idx = gap_depths.index - lengths = np.array(lengths)[filter_gaps_idx] - all_files = np.array(all_files)[filter_gaps_idx] - print("filter rows with GAPs > 512", len(all_files)) - + if openfold: + ## Filter based on depth (keep > 64 seqs/MSA) + if not os.path.exists(data_dir + 'openfold_lengths.npz'): + raise Exception("Missing openfold_lengths.npz in openfold/") + if not os.path.exists(data_dir + 'openfold_depths.npz'): + #get_msa_depth_openfold(data_dir, sorted(all_files), 'openfold_depths.npz') + raise Exception("Missing openfold_depths.npz in openfold/") + if min_depth is not None: # reindex, filtering out MSAs < min_depth + _depths = np.load(data_dir+'openfold_depths.npz')['arr_0'] + depths = pd.DataFrame(_depths, columns=['depth']) + print(depths) + depths = depths[depths['depth'] >= min_depth] + keep_idx = depths.index + + _lengths = np.load(data_dir+'openfold_lengths.npz')['ells'] + print(np.array(_lengths)) + lengths = np.array(_lengths)[keep_idx] + all_files = np.array(all_files)[keep_idx] + print("filter MSA depth > 64", len(all_files)) + + # Re-filter based on high gap-contining rows + if not os.path.exists(data_dir + 'openfold_gap_depths.npz'): + #get_sliced_gap_depth_openfold(data_dir, all_files, 'openfold_gap_depths.npz', max_seq_len=max_seq_len) + raise Exception("Missing openfold_gap_depths.npz in openfold/") + _gap_depths = np.load(data_dir + 'openfold_gap_depths.npz')['arr_0'] + gap_depths = pd.DataFrame(_gap_depths, columns=['gapdepth']) + gap_depths = gap_depths[gap_depths['gapdepth'] >= min_depth] + filter_gaps_idx = gap_depths.index + lengths = np.array(lengths)[filter_gaps_idx] + all_files = np.array(all_files)[filter_gaps_idx] + print("filter rows with GAPs > 512", len(all_files)) + else: + all_files = np.array(all_files) #maybe expand to whole path + lengths = [] + for file in all_files: + parsed_msa = parse_fasta(file) + lengths.append(max([len(line) for line in parsed_msa])) + lengths = np.array(lengths) self.filenames = all_files # IDs of samples to include self.lengths = lengths # pass to batch sampler self.n_sequences = n_sequences @@ -376,7 +393,10 @@ def __len__(self): def __getitem__(self, idx): filename = self.filenames[idx] - path = read_openfold_files(self.data_dir, filename) + if self.openfold: + path = read_openfold_files(self.data_dir, filename) + else: + path = filename parsed_msa = parse_fasta(path) aligned_msa = [[char for char in seq if (char.isupper() or char == '-') and not char == '.'] for seq in parsed_msa] @@ -522,7 +542,6 @@ def __len__(self): def __getitem__(self, idx): filename = self.filenames[idx] - print(filename) path = read_idr_files(self.data_dir, filename) parsed_msa = parse_fasta(path) aligned_msa = [[char for char in seq if (char.isupper() or char == '-') and not char == '.'] for seq in parsed_msa] diff --git a/evodiff/generate_msa.py b/evodiff/generate_msa.py index 3d57bee..eaa18d6 100644 --- a/evodiff/generate_msa.py +++ b/evodiff/generate_msa.py @@ -18,7 +18,7 @@ def main(): parser = argparse.ArgumentParser() #parser.add_argument('config_fpath') - #parser.add_argument('out_fpath', type=str, nargs='?', + parser.add_argument('--out_fpath', type=str, default=None)# nargs='?', # default=os.getenv('AMLT_OUTPUT_DIR', '/tmp') + '/') parser.add_argument('-g', '--gpus', default=0, type=int, help='Number of gpus per node') @@ -113,11 +113,11 @@ def main(): if args.amlt: home = os.getenv('AMLT_OUTPUT_DIR', '/tmp') + '/' - out_fpath = home + out_fpath = home if args.out_fpath is None else args.out_fpath else: home = str(pathlib.Path.home()) + '/Desktop/DMs/' top_dir = home - out_fpath = home + args.model_type + '/gen-'+str(args.run) + '/' + out_fpath = home + args.model_type + '/gen-'+str(args.run) + '/' if args.out_fpath is None else args.out_fpath if not os.path.exists(out_fpath): os.makedirs(out_fpath) @@ -133,23 +133,21 @@ def main(): print("Penalizing GAPS by factor of", 1+args.penalty_value) else: print("Not penalizing GAPS") - + batch_size = args.batch_size if pathlib.Path(data_dir).is_dir() else 1 if scheme == 'mask': - sample, _string = generate_msa(model, tokenizer, args.batch_size, args.n_sequences, args.seq_length, + sample, _string = generate_msa(model, tokenizer, batch_size, args.n_sequences, args.seq_length, penalty_value=args.penalty_value, device=device, start_query=args.start_query, start_msa=args.start_msa, - data_top_dir=data_top_dir, selection_type=args.subsampling, out_path=out_fpath) + data_top_dir=data_top_dir, selection_type=args.subsampling, out_path=out_fpath, openfold=args.dataset=="openfold", data_dir=args.dataset) elif scheme == 'd3pm': - sample, _string = generate_msa_d3pm(model, args.batch_size, args.n_sequences, args.seq_length, + sample, _string = generate_msa_d3pm(model, batch_size, args.n_sequences, args.seq_length, Q_bar=Q_bar, Q=Q, tokenizer=Tokenizer(), data_top_dir=data_top_dir, selection_type=args.subsampling, out_path=out_fpath, max_timesteps=timestep, start_query=args.start_query, - no_step=False, penalty_value=args.penalty_value, device=device) - - + no_step=False, penalty_value=args.penalty_value, device=device, openfold=args.dataset=="openfold", data_dir=args.dataset) for count, msa in enumerate(_string): fasta_string = "" - with open(out_fpath + 'generated_msas.a3m', 'a') as f: + with open(pathlib.Path(out_fpath)/'generated_msas.a3m', 'a') as f: for seq in range(args.n_sequences): seq_num = seq * args.seq_length next_seq_num = (seq+1) * args.seq_length @@ -160,19 +158,19 @@ def main(): f.write(">tr \n" + str(seq_string) + "\n" ) f.write(fasta_string) f.close() - np.save(out_fpath+'generated_msas', np.array(sample.cpu())) + np.save(pathlib.Path(out_fpath)/'generated_msas', np.array(sample.cpu())) def generate_msa(model, tokenizer, batch_size, n_sequences, seq_length, penalty_value=2, device='gpu', - start_query=False, start_msa=False, data_top_dir='../data', selection_type='MaxHamming', out_path='../ref/'): + start_query=False, start_msa=False, data_top_dir='../data', selection_type='MaxHamming', out_path='../ref/', openfold=False, data_dir="openfold/"): mask_id = tokenizer.mask_id src = torch.full((batch_size, n_sequences, seq_length), fill_value=mask_id) masked_loc_x = np.arange(n_sequences) masked_loc_y = np.arange(seq_length) if start_query: - valid_msas, query_sequences, tokenizer =get_valid_data(data_top_dir, batch_size, 'autoreg', data_dir='openfold/', + valid_msas, query_sequences, tokenizer =get_valid_data(data_top_dir, batch_size, 'autoreg', data_dir=data_dir, selection_type=selection_type, n_sequences=n_sequences, max_seq_len=seq_length, - out_path=out_path) + out_path=out_path, openfold=openfold) # First row is query sequence for i in range(batch_size): seq_len = len(query_sequences[i]) @@ -184,10 +182,11 @@ def generate_msa(model, tokenizer, batch_size, n_sequences, seq_length, penalty_ y_indices = np.arange(seq_len) elif start_msa: valid_msas, query_sequences, tokenizer = get_valid_data(data_top_dir, batch_size, 'autoreg', - data_dir='openfold/', + data_dir=data_dir, selection_type=selection_type, n_sequences=n_sequences, max_seq_len=seq_length, - out_path=out_path) + out_path=out_path, + openfold=openfold) for i in range(batch_size): seq_len = len(query_sequences[i]) src[i, 1:n_sequences, :seq_len] = valid_msas[i][0, 1:n_sequences, :seq_len].squeeze() @@ -270,14 +269,14 @@ def generate_query_oadm_msa_simple(path_to_msa, model, tokenizer, n_sequences, s def generate_msa_d3pm(model, batch_size, n_sequences, seq_length, Q_bar=None, Q=None, tokenizer=Tokenizer(), start_query=False, data_top_dir='../data', selection_type='MaxHamming', out_path='../ref/', - max_timesteps=500, no_step=False, penalty_value=0, device='gpu'): + max_timesteps=500, no_step=False, penalty_value=0, device='gpu', openfold=False, data_dir="openfold/"): sample = torch.randint(0, tokenizer.K, (batch_size, n_sequences, seq_length)) if start_query: x_indices = [] y_indices = [] valid_msas, query_sequences, tokenizer =get_valid_data(data_top_dir, batch_size, 'autoreg', data_dir='openfold/', selection_type=selection_type, n_sequences=n_sequences, max_seq_len=seq_length, - out_path=out_path) + out_path=out_path, openfold=openfold) # First row is query sequence for i in range(batch_size): seq_len = len(query_sequences[i]) @@ -340,7 +339,7 @@ def generate_msa_d3pm(model, batch_size, n_sequences, seq_length, Q_bar=None, Q= def get_valid_data(data_top_dir, num_seqs, arg_mask, data_dir='openfold/', selection_type='MaxHamming', n_sequences=64, max_seq_len=512, - out_path='../DMs/ref/'): + out_path='../DMs/ref/', openfold=True): valid_msas = [] query_msas = [] seq_lens = [] @@ -348,14 +347,16 @@ def get_valid_data(data_top_dir, num_seqs, arg_mask, data_dir='openfold/', selec _ = torch.manual_seed(1) # same seeds as training np.random.seed(1) - dataset = A3MMSADataset(selection_type, n_sequences, max_seq_len, data_dir=os.path.join(data_top_dir,data_dir), min_depth=64) + dataset = A3MMSADataset(selection_type, n_sequences, max_seq_len, data_dir=os.path.join(data_top_dir,data_dir), min_depth=64, openfold=openfold) train_size = len(dataset) - random_ind = np.random.choice(train_size, size=(train_size - 10000), replace=False) - val_ind = np.delete(np.arange(train_size), random_ind) - - - ds_valid = Subset(dataset, val_ind) + openfold=False + if openfold: + random_ind = np.random.choice(train_size, size=(train_size - 10000), replace=False) + val_ind = np.delete(np.arange(train_size), random_ind) + ds_valid = Subset(dataset, val_ind) + else: + ds_valid = dataset if arg_mask == 'autoreg': tokenizer = Tokenizer() @@ -394,7 +395,7 @@ def get_valid_data(data_top_dir, num_seqs, arg_mask, data_dir='openfold/', selec print("LEN VALID MSAS", len(valid_msas)) untokenized = [[tokenizer.untokenize(msa.flatten())] for msa in valid_msas] fasta_string = "" - with open(out_path + 'valid_msas.a3m', 'a') as f: + with open(pathlib.Path(out_path)/'valid_msas.a3m', 'a') as f: for i, msa in enumerate(untokenized): for seq in range(n_sequences): seq_num = seq * seq_lens[i]