From f4d9e37e6d760b36e79d26f981fcf74b3d99f82b Mon Sep 17 00:00:00 2001
From: Milot Mirdita
Date: Sun, 14 Jul 2024 16:24:42 +0900
Subject: [PATCH] colabfold_search: allow continuing interrupted runs, disabling unpacking databases
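
Two changes:

- Continuing interrupted runs: run_mmseqs() now returns early when the output
  database of the requested module (looked up via MODULE_OUTPUT_POS) already
  has a .dbtype file, and the UniRef, environmental and template stages in
  mmseqs_search_monomer() are additionally guarded by checks on
  uniref.a3m.dbtype, bfd.mgnify30.metaeuk30.smag30.a3m.dbtype and
  res_pdb.m8.dbtype. Re-running the same command therefore resumes after the
  last completed step instead of starting over.
- Disabling unpacking: a new --unpack flag (default 1) controls whether
  final.a3m, pair.a3m and res_pdb.m8 are unpacked into per-query .a3m/.m8
  files. With --unpack 0 the results are kept as MMseqs2 databases and the
  per-query rename logic in main() is skipped.

Example invocation (a sketch only: --threads and --unpack are flags from this
patch, while the query/database/output positional arguments are assumed from
the existing CLI and are not shown in this diff):

    # first run is interrupted, e.g. by a job time limit
    colabfold_search --threads 32 --unpack 0 queries.fasta /path/to/db results/
    # re-running the identical command skips every finished MMseqs2 module
    colabfold_search --threads 32 --unpack 0 queries.fasta /path/to/db results/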
run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), - base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode", - "6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) - subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign")]) - subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp")]) - subprocess.run([mmseqs] + ["rmdb", base.joinpath("res")]) - subprocess.run([mmseqs] + ["rmdb", base.joinpath("res_exp_realign_filter")]) + if not base.joinpath("uniref.a3m").with_suffix('.a3m.dbtype').exists(): + run_mmseqs(mmseqs, ["search", base.joinpath("qdb"), dbbase.joinpath(uniref_db), base.joinpath("res"), base.joinpath("tmp"), "--threads", str(threads)] + search_param) + run_mmseqs(mmseqs, ["mvdb", base.joinpath("tmp/latest/profile_1"), base.joinpath("prof_res")]) + run_mmseqs(mmseqs, ["lndb", base.joinpath("qdb_h"), base.joinpath("prof_res_h")]) + run_mmseqs(mmseqs, ["expandaln", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res"), dbbase.joinpath(f"{uniref_db}{dbSuffix2}"), base.joinpath("res_exp"), "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + expand_param) + run_mmseqs(mmseqs, ["align", base.joinpath("prof_res"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), base.joinpath("res_exp"), base.joinpath("res_exp_realign"), "--db-load-mode", str(db_load_mode), "-e", str(align_eval), "--max-accept", str(max_accept), "--threads", str(threads), "--alt-ali", "10", "-a"]) + run_mmseqs(mmseqs, ["filterresult", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), + base.joinpath("res_exp_realign"), base.joinpath("res_exp_realign_filter"), "--db-load-mode", + str(db_load_mode), "--qid", "0", "--qsc", str(qsc), "--diff", "0", "--threads", + str(threads), "--max-seq-id", "1.0", "--filter-min-enable", "100"]) + run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{uniref_db}{dbSuffix1}"), + base.joinpath("res_exp_realign_filter"), base.joinpath("uniref.a3m"), "--msa-format-mode", + "6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_filter")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")]) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")]) + else: + logger.info(f"Skipping {uniref_db} search because uniref.a3m already exists") - if use_env: + if use_env and not base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m").with_suffix('.a3m.dbtype').exists(): run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(metagenomic_db), base.joinpath("res_env"), base.joinpath("tmp3"), "--threads", str(threads)] + search_param) run_mmseqs(mmseqs, ["expandaln", base.joinpath("prof_res"), dbbase.joinpath(f"{metagenomic_db}{dbSuffix1}"), base.joinpath("res_env"), @@ -133,45 +153,49 @@ def mmseqs_search_monomer( base.joinpath("res_env_exp_realign_filter"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m"), "--msa-format-mode", "6", "--db-load-mode", str(db_load_mode), "--threads", str(threads)] + filter_param) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign_filter")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp_realign")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env_exp")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_env")]) + elif use_env: + logger.info(f"Skipping {metagenomic_db} search because 
 
-        run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
-        run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
-    else:
-        run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")])
-
-    if use_templates:
+    if use_templates and not base.joinpath("res_pdb.m8").with_suffix('.m8.dbtype').exists():
         run_mmseqs(mmseqs, ["search", base.joinpath("prof_res"), dbbase.joinpath(template_db), base.joinpath("res_pdb"),
                             base.joinpath("tmp2"), "--db-load-mode", str(db_load_mode), "--threads", str(threads), "-s", "7.5",
                             "-a", "-e", "0.1", "--prefilter-mode", str(prefilter_mode)])
         run_mmseqs(mmseqs, ["convertalis", base.joinpath("prof_res"), dbbase.joinpath(f"{template_db}{dbSuffix3}"), base.joinpath("res_pdb"),
-                            base.joinpath(f"{template_db}"), "--format-output",
+                            base.joinpath("res_pdb.m8"), "--format-output",
                             "query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar",
                             "--db-output", "1",
                             "--db-load-mode", str(db_load_mode), "--threads", str(threads)])
-        run_mmseqs(mmseqs, ["unpackdb", base.joinpath(f"{template_db}"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"])
         run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb")])
-        run_mmseqs(mmseqs, ["rmdb", base.joinpath(f"{template_db}")])
+    elif use_templates:
+        logger.info(f"Skipping {template_db} search because res_pdb.m8 already exists")
 
-    run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"])
-    run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")])
-    run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
-    run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")])
-    # @formatter:on
-    # fmt: on
+    if use_env:
+        run_mmseqs(mmseqs, ["mergedbs", base.joinpath("qdb"), base.joinpath("final.a3m"), base.joinpath("uniref.a3m"), base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
+        run_mmseqs(mmseqs, ["rmdb", base.joinpath("bfd.mgnify30.metaeuk30.smag30.a3m")])
+        run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
+    else:
+        run_mmseqs(mmseqs, ["mvdb", base.joinpath("uniref.a3m"), base.joinpath("final.a3m")])
+        run_mmseqs(mmseqs, ["rmdb", base.joinpath("uniref.a3m")])
 
-    for file in base.glob("prof_res*"):
-        file.unlink()
+    if unpack:
+        run_mmseqs(mmseqs, ["unpackdb", base.joinpath("final.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".a3m"])
+        run_mmseqs(mmseqs, ["rmdb", base.joinpath("final.a3m")])
+
+        if use_templates:
+            run_mmseqs(mmseqs, ["unpackdb", base.joinpath("res_pdb.m8"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", ".m8"])
+            run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_pdb.m8")])
+
+    run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res")])
+    run_mmseqs(mmseqs, ["rmdb", base.joinpath("prof_res_h")])
     shutil.rmtree(base.joinpath("tmp"))
     if use_templates:
         shutil.rmtree(base.joinpath("tmp2"))
     if use_env:
         shutil.rmtree(base.joinpath("tmp3"))
 
-
 def mmseqs_search_pair(
     dbbase: Path,
     base: Path,
@@ -184,6 +208,7 @@
     threads: int = 64,
     db_load_mode: int = 2,
     pairing_strategy: int = 0,
+    unpack: bool = True,
 ):
     if not dbbase.joinpath(f"{uniref_db}.dbtype").is_file():
         raise FileNotFoundError(f"Database {uniref_db} does not exist")
@@ -225,14 +250,15 @@ def mmseqs_search_pair(
dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_exp_realign_pair"), base.joinpath("res_exp_realign_pair_bt"), "--db-load-mode", str(db_load_mode), "-e", "inf", "-a", "--threads", str(threads), ],) run_mmseqs(mmseqs, ["pairaln", base.joinpath("qdb"), dbbase.joinpath(f"{db}"), base.joinpath("res_exp_realign_pair_bt"), base.joinpath("res_final"), "--db-load-mode", str(db_load_mode), "--pairing-mode", str(pairing_strategy), "--pairing-dummy-mode", "1", "--threads", str(threads),],) run_mmseqs(mmseqs, ["result2msa", base.joinpath("qdb"), dbbase.joinpath(f"{db}{dbSuffix1}"), base.joinpath("res_final"), base.joinpath("pair.a3m"), "--db-load-mode", str(db_load_mode), "--msa-format-mode", "5", "--threads", str(threads),],) - run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],) + if unpack: + run_mmseqs(mmseqs, ["unpackdb", base.joinpath("pair.a3m"), base.joinpath("."), "--unpack-name-mode", "0", "--unpack-suffix", output,],) + run_mmseqs(mmseqs, ["rmdb", base.joinpath("pair.a3m")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_exp_realign_pair_bt")]) run_mmseqs(mmseqs, ["rmdb", base.joinpath("res_final")]) - run_mmseqs(mmseqs, ["rmdb", base.joinpath("pair.a3m")]) shutil.rmtree(base.joinpath("tmp")) # @formatter:on # fmt: on @@ -340,6 +366,9 @@ def main(): default=0, help="Database preload mode 0: auto, 1: fread, 2: mmap, 3: mmap+touch", ) + parser.add_argument( + "--unpack", type=int, default=1, choices=[0, 1], help="Unpack results to loose files or keep MMseqs2 databases." + ) parser.add_argument( "--threads", type=int, default=64, help="Number of threads to use." 
     )
@@ -416,6 +445,7 @@ def main():
         s=args.s,
         db_load_mode=args.db_load_mode,
         threads=args.threads,
+        unpack=args.unpack,
     )
     if is_complex is True:
         mmseqs_search_pair(
@@ -429,6 +459,7 @@
             threads=args.threads,
             pairing_strategy=args.pairing_strategy,
             pair_env=False,
+            unpack=args.unpack,
         )
         if args.use_env_pairing:
             mmseqs_search_pair(
@@ -443,63 +474,66 @@
                 threads=args.threads,
                 pairing_strategy=args.pairing_strategy,
                 pair_env=True,
+                unpack=args.unpack,
             )
 
-    id = 0
-    for job_number, (
-        raw_jobname,
-        query_sequences,
-        query_seqs_cardinality,
-    ) in enumerate(queries_unique):
-        unpaired_msa = []
-        paired_msa = None
-        if len(query_seqs_cardinality) > 1:
-            paired_msa = []
-        for seq in query_sequences:
-            with args.base.joinpath(f"{id}.a3m").open("r") as f:
-                unpaired_msa.append(f.read())
-            args.base.joinpath(f"{id}.a3m").unlink()
-
-            if args.use_env_pairing:
-                with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair:
-                    with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env:
-                        while chunk := file_pair_env.read(10 * 1024 * 1024):
-                            file_pair.write(chunk)
-                args.base.joinpath(f"{id}.env.paired.a3m").unlink()
-
+    if args.unpack:
+        id = 0
+        for job_number, (
+            raw_jobname,
+            query_sequences,
+            query_seqs_cardinality,
+        ) in enumerate(queries_unique):
+            unpaired_msa = []
+            paired_msa = None
             if len(query_seqs_cardinality) > 1:
-                with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
-                    paired_msa.append(f.read())
-                args.base.joinpath(f"{id}.paired.a3m").unlink()
-            id += 1
-        msa = msa_to_str(
-            unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
+                paired_msa = []
+            for seq in query_sequences:
+                with args.base.joinpath(f"{id}.a3m").open("r") as f:
+                    unpaired_msa.append(f.read())
+                args.base.joinpath(f"{id}.a3m").unlink()
+
+                if args.use_env_pairing:
+                    with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair:
+                        with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env:
+                            while chunk := file_pair_env.read(10 * 1024 * 1024):
+                                file_pair.write(chunk)
+                    args.base.joinpath(f"{id}.env.paired.a3m").unlink()
+
+                if len(query_seqs_cardinality) > 1:
+                    with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
+                        paired_msa.append(f.read())
+                    args.base.joinpath(f"{id}.paired.a3m").unlink()
+                id += 1
+            msa = msa_to_str(
+                unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
+            )
+            args.base.joinpath(f"{job_number}.a3m").write_text(msa)
+
+    if args.unpack:
+        # rename a3m files
+        for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique):
+            os.rename(
+                args.base.joinpath(f"{job_number}.a3m"),
+                args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"),
             )
-        args.base.joinpath(f"{job_number}.a3m").write_text(msa)
-
-    # rename a3m files
-    for job_number, (raw_jobname, query_sequences, query_seqs_cardinality) in enumerate(queries_unique):
-        os.rename(
-            args.base.joinpath(f"{job_number}.a3m"),
-            args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"),
-        )
-    # rename m8 files
-    if args.use_templates:
-        id = 0
-        for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
-            with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open(
-                "w"
-            ) as f:
-                for _ in range(len(query_seqs_cardinality)):
-                    with args.base.joinpath(f"{id}.m8").open("r") as g:
-                        f.write(g.read())
-                    os.remove(args.base.joinpath(f"{id}.m8"))
-                    id += 1
+        # rename m8 files
+        if args.use_templates:
+            id = 0
+            for raw_jobname, query_sequences, query_seqs_cardinality in queries_unique:
+                with args.base.joinpath(f"{safe_filename(raw_jobname)}_{args.db2}.m8").open(
+                    "w"
+                ) as f:
+                    for _ in range(len(query_seqs_cardinality)):
+                        with args.base.joinpath(f"{id}.m8").open("r") as g:
+                            f.write(g.read())
+                        os.remove(args.base.joinpath(f"{id}.m8"))
+                        id += 1
 
+    run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
+    run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])
     query_file.unlink()
-    run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
-    run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])
 
 
 if __name__ == "__main__":