From 97f7602b6e21ede36dbbb8a1b80335e5d73559c1 Mon Sep 17 00:00:00 2001 From: yangyxt Date: Mon, 16 Dec 2024 11:17:31 +0800 Subject: [PATCH 01/12] add parallel prescore extraction --- src/scripts/extract_scored.py | 275 +++++++++++++++++++++++++++------- 1 file changed, 218 insertions(+), 57 deletions(-) diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index e88f103..dc9b6e8 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -5,69 +5,230 @@ import os import pysam from optparse import OptionParser +import multiprocessing as mp +from functools import partial +import time -parser = OptionParser() -parser.add_option("-p", "--path", dest="path", help="Path to scored variants.") -parser.add_option("-i", "--input", dest="input", help="Read variants from vcf file (default stdin)", default=None) -parser.add_option("--found_out", dest="found_out", help="Write found variants to file (default: stdout)", default=None) -parser.add_option("--header", dest="header", help="Write full header to output (default none)", - default=False, action="store_true") -(options, args) = parser.parse_args() - -if options.input: - stdin = open(options.input, 'r') -else: - stdin = sys.stdin - -if options.found_out: - found_out = open(options.found_out, 'w') -else: - found_out = sys.stdout - -fpos, fref, falt = 1, 2, 3 -if os.path.exists(options.path) and os.path.exists(options.path+".tbi"): - filename = options.path - sys.stderr.write("Opening %s...\n" % (filename)) - regionTabix = pysam.Tabixfile(filename, 'r') - header = list(regionTabix.header) - for line in header: - if options.header: + +def setup_output_dir(output_base, chrom): + """Create chromosome-specific output directory""" + chrom_dir = os.path.join(output_base, chrom) + os.makedirs(chrom_dir, exist_ok=True) + return chrom_dir + + +def extract_prescored_chromosome(input_file, output_base, chrom): + """Extract records for a single chromosome from prescored TSV file""" + try: + # Setup output directory + chrom_dir = setup_output_dir(output_base, chrom) + input_file_name = os.path.basename(input_file) + input_file_name_base = input_file_name.rsplit('.', 1)[0] + output_file = os.path.join(chrom_dir, f"{input_file_name_base}.{chrom}.tsv") + compressed_file = f"{output_file}.gz" + + # Check if extraction is needed + if os.path.exists(compressed_file): + if os.path.getmtime(compressed_file) > os.path.getmtime(input_file): + return compressed_file + + # Consider the edge case where the chromsome does not exist in the input file + chromosomes = get_chromosomes(input_file) + if chrom not in chromosomes: + return None + + # Extract records for this chromosome using tabix + tbx = pysam.TabixFile(input_file) + with open(output_file, 'w') as f: + for row in tbx.fetch(chrom): + f.write(f"{row}\n") + + # Compress and index the output file + pysam.tabix_compress(output_file, compressed_file, force=True) + pysam.tabix_index(compressed_file, + preset=None, + force=True, + seq_col=0, + start_col=1, + end_col=1, + zerobased=False) + + # Remove uncompressed file + os.remove(output_file) + return compressed_file + + except Exception as e: + raise Exception(f"Error extracting prescored chromosome {chrom}: {str(e)}") + + +def buffer_vcf_by_chromosome(stdin): + """Read VCF from stdin and buffer by chromosome""" + vcf_by_chrom = {} + header_lines = [] + + for line in stdin: + if line.startswith('#'): + header_lines.append(line) + continue + + fields = line.strip().split('\t') + chrom = fields[0] + + if chrom not in vcf_by_chrom: + vcf_by_chrom[chrom] = [] + vcf_by_chrom[chrom].append(line) + + return header_lines, vcf_by_chrom + + +def process_chromosome(args): + """Process a single chromosome""" + chrom, vcf_lines, prescored_file, temp_dir, fpos, fref, falt = args + try: + # First extract prescored records for this chromosome + prescored_chrom_file = extract_prescored_chromosome( + prescored_file, + os.path.join(temp_dir, "prescored"), + chrom + ) + + # Setup output files for this chromosome + found_file = os.path.join(temp_dir, "matches", f"found.{chrom}.tmp") + notfound_file = os.path.join(temp_dir, "matches", f"notfound.{chrom}.tmp") + os.makedirs(os.path.dirname(found_file), exist_ok=True) + + if prescored_chrom_file is None: + # Create empty found file and output all records to notfound file + with open(notfound_file, 'w') as f_notfound: + for line in vcf_lines: + f_notfound.write(line) + return chrom, True + + # Open prescored tabix file + pre_tbx = pysam.TabixFile(prescored_chrom_file) + + with open(found_file, 'w') as f_found, open(notfound_file, 'w') as f_notfound: + # Process each variant + for line in vcf_lines: + fields = line.strip().split('\t') + pos = int(fields[1]) + lref, allele = fields[-2], fields[-1].strip() + found = False + + # Look for matches in prescored file + for pre_line in pre_tbx.fetch(chrom, pos-1, pos): + vfields = pre_line.rstrip().split('\t') + if (vfields[fref] == lref) and (vfields[falt] == allele) and (vfields[fpos] == fields[1]): + f_found.write(pre_line + '\n') + found = True + break + + if not found: + f_notfound.write(line) + + return chrom, True + except Exception as e: + sys.stderr.write(f'Error processing chromosome {chrom}: {str(e)}\n') + return chrom, False + + +def get_chromosomes(vcf_file): + """Get chromosomes from input VCF file""" + vcf_tbx = pysam.TabixFile(vcf_file) + return sorted(vcf_tbx.contigs) + + +def main(): + parser = OptionParser() + parser.add_option("-p", "--path", dest="path", help="Path to scored variants.") + parser.add_option("-i", "--input", dest="input", help="Read variants from vcf file (default stdin)", default=None) + parser.add_option("--found_out", dest="found_out", help="Write found variants to file (default: stdout)", default=None) + parser.add_option("--header", dest="header", help="Write full header to output (default none)", + default=False, action="store_true") + (options, args) = parser.parse_args() + + # Setup input/output files + stdin = open(options.input, 'r') if options.input else sys.stdin + found_out = open(options.found_out, 'w') if options.found_out else sys.stdout + + # Create temporary directory + temp_dir = "temp_extract_scored" + os.makedirs(temp_dir, exist_ok=True) + + # Initialize column indices + fpos, fref, falt = 1, 2, 3 + + # Check prescored file + if not (os.path.exists(options.path) and os.path.exists(options.path+".tbi")): + raise IOError("No valid file with pre-scored variants.\n") + + # Get header and column indices from prescored file + pre_tbx = pysam.TabixFile(options.path, 'r') + header = list(pre_tbx.header) + + # Write headers to output files if requested + if options.header: + for line in header: found_out.write(line+"\n") + + # Get column indices from header + for line in header: try: fref = line.split('\t').index('Ref') falt = line.split('\t').index('Alt') except ValueError: pass -else: - raise IOError("No valid file with pre-scored variants.\n") - -for line in stdin: - line = line.rstrip('\n\r') - if line.startswith('#'): - sys.stdout.write(line + '\n') - continue + # Buffer VCF data and get chromosomes + header_lines, vcf_by_chrom = buffer_vcf_by_chromosome(stdin) + chromosomes = sorted(vcf_by_chrom.keys()) + + # Write VCF headers to stdout + for line in header_lines: + sys.stdout.write(line) + + # Setup parallel processing args + process_args = [ + (chrom, tuple(vcf_by_chrom[chrom]), options.path, temp_dir, fpos, fref, falt) + for chrom in chromosomes + ] + + # Process chromosomes in parallel + # Get number of threads from Snakemake + threads = int(os.environ.get("SNAKEMAKE_THREADS", "1")) + threads = min(threads, len(chromosomes)) + print(f"Using {threads} threads to extract the scored variants across all chromosomes", file=sys.stderr) + with mp.Pool(threads) as pool: + results = pool.map(process_chromosome, process_args) + + # Combine results + for chrom, success in results: + if success: + found_file = os.path.join(temp_dir, "matches", f"found.{chrom}.tmp") + notfound_file = os.path.join(temp_dir, "matches", f"notfound.{chrom}.tmp") + + if os.path.exists(found_file): + with open(found_file) as f: + for line in f: + found_out.write(line) + os.remove(found_file) + + if os.path.exists(notfound_file): + with open(notfound_file) as f: + for line in f: + sys.stdout.write(line) + os.remove(notfound_file) + + # Cleanup try: - fields = line.split('\t') - found = False - chrom = fields[0] - pos = int(fields[1]) - lref, allele = fields[-2], fields[-1] - for regionHit in regionTabix.fetch(chrom, pos-1, pos): - vfields = regionHit.rstrip().split('\t') - if (vfields[fref] == lref) and (vfields[falt] == allele) and (vfields[fpos] == fields[1]): - found_out.write(regionHit+"\n") - found = True - - if not found: - sys.stdout.write(line + '\n') - - except ValueError: - sys.stderr.write('Encountered uncovered chromosome\n') - sys.stdout.write(line + '\n') - -if options.input: - stdin.close() - -if options.found_out: - found_out.close() + import shutil + shutil.rmtree(temp_dir) + except: + pass + + # Close files + if options.found_out: + found_out.close() + +if __name__ == "__main__": + main() From b4e2c89059d6219c1581c0c96d85d2fac5716c4b Mon Sep 17 00:00:00 2001 From: yangyxt Date: Mon, 16 Dec 2024 11:41:56 +0800 Subject: [PATCH 02/12] close stdin if needed --- src/scripts/extract_scored.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index dc9b6e8..77823d0 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -6,8 +6,7 @@ import pysam from optparse import OptionParser import multiprocessing as mp -from functools import partial -import time + def setup_output_dir(output_base, chrom): @@ -230,5 +229,8 @@ def main(): if options.found_out: found_out.close() + if options.input: + stdin.close() + if __name__ == "__main__": main() From ceaa15c787e87514be823b726aaef2f6333d1953 Mon Sep 17 00:00:00 2001 From: yangyxt Date: Mon, 16 Dec 2024 15:30:57 +0800 Subject: [PATCH 03/12] migrate to python 2.7 --- src/scripts/extract_scored.py | 252 ++++++++++++++++++---------------- 1 file changed, 131 insertions(+), 121 deletions(-) diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index 77823d0..20b3383 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -7,15 +7,35 @@ from optparse import OptionParser import multiprocessing as mp - +def buffer_vcf_by_chromosome(input_stream): + """Read VCF from input stream and buffer by chromosome""" + vcf_by_chrom = {} + header_lines = [] + + for line in input_stream: + if line.startswith('#'): + header_lines.append(line) + continue + + fields = line.strip().split('\t') + chrom = fields[0] + + if chrom not in vcf_by_chrom: + vcf_by_chrom[chrom] = [] + vcf_by_chrom[chrom].append(line) + + return header_lines, vcf_by_chrom def setup_output_dir(output_base, chrom): """Create chromosome-specific output directory""" chrom_dir = os.path.join(output_base, chrom) - os.makedirs(chrom_dir, exist_ok=True) + try: + os.makedirs(chrom_dir) + except OSError: + if not os.path.isdir(chrom_dir): + raise return chrom_dir - def extract_prescored_chromosome(input_file, output_base, chrom): """Extract records for a single chromosome from prescored TSV file""" try: @@ -23,24 +43,19 @@ def extract_prescored_chromosome(input_file, output_base, chrom): chrom_dir = setup_output_dir(output_base, chrom) input_file_name = os.path.basename(input_file) input_file_name_base = input_file_name.rsplit('.', 1)[0] - output_file = os.path.join(chrom_dir, f"{input_file_name_base}.{chrom}.tsv") - compressed_file = f"{output_file}.gz" + output_file = os.path.join(chrom_dir, "{0}.{1}.tsv".format(input_file_name_base, chrom)) + compressed_file = "{0}.gz".format(output_file) # Check if extraction is needed if os.path.exists(compressed_file): if os.path.getmtime(compressed_file) > os.path.getmtime(input_file): return compressed_file - - # Consider the edge case where the chromsome does not exist in the input file - chromosomes = get_chromosomes(input_file) - if chrom not in chromosomes: - return None # Extract records for this chromosome using tabix tbx = pysam.TabixFile(input_file) with open(output_file, 'w') as f: for row in tbx.fetch(chrom): - f.write(f"{row}\n") + f.write("{0}\n".format(row)) # Compress and index the output file pysam.tabix_compress(output_file, compressed_file, force=True) @@ -57,28 +72,7 @@ def extract_prescored_chromosome(input_file, output_base, chrom): return compressed_file except Exception as e: - raise Exception(f"Error extracting prescored chromosome {chrom}: {str(e)}") - - -def buffer_vcf_by_chromosome(stdin): - """Read VCF from stdin and buffer by chromosome""" - vcf_by_chrom = {} - header_lines = [] - - for line in stdin: - if line.startswith('#'): - header_lines.append(line) - continue - - fields = line.strip().split('\t') - chrom = fields[0] - - if chrom not in vcf_by_chrom: - vcf_by_chrom[chrom] = [] - vcf_by_chrom[chrom].append(line) - - return header_lines, vcf_by_chrom - + raise Exception("Error extracting prescored chromosome {0}: {1}".format(chrom, str(e))) def process_chromosome(args): """Process a single chromosome""" @@ -92,9 +86,13 @@ def process_chromosome(args): ) # Setup output files for this chromosome - found_file = os.path.join(temp_dir, "matches", f"found.{chrom}.tmp") - notfound_file = os.path.join(temp_dir, "matches", f"notfound.{chrom}.tmp") - os.makedirs(os.path.dirname(found_file), exist_ok=True) + found_file = os.path.join(temp_dir, "matches", "found.{0}.tmp".format(chrom)) + notfound_file = os.path.join(temp_dir, "matches", "notfound.{0}.tmp".format(chrom)) + try: + os.makedirs(os.path.dirname(found_file)) + except OSError: + if not os.path.isdir(os.path.dirname(found_file)): + raise if prescored_chrom_file is None: # Create empty found file and output all records to notfound file @@ -127,16 +125,10 @@ def process_chromosome(args): return chrom, True except Exception as e: - sys.stderr.write(f'Error processing chromosome {chrom}: {str(e)}\n') + sys.stderr.write('Error processing chromosome {0}: {1}\n'.format(chrom, str(e))) return chrom, False -def get_chromosomes(vcf_file): - """Get chromosomes from input VCF file""" - vcf_tbx = pysam.TabixFile(vcf_file) - return sorted(vcf_tbx.contigs) - - def main(): parser = OptionParser() parser.add_option("-p", "--path", dest="path", help="Path to scored variants.") @@ -145,92 +137,110 @@ def main(): parser.add_option("--header", dest="header", help="Write full header to output (default none)", default=False, action="store_true") (options, args) = parser.parse_args() - - # Setup input/output files - stdin = open(options.input, 'r') if options.input else sys.stdin + + # Setup input stream + input_stream = sys.stdin + if options.input and options.input != "-": + try: + input_stream = open(options.input, 'r') + except IOError as e: + sys.stderr.write("Error opening input file: {0}\n".format(str(e))) + sys.exit(1) + + # Setup output stream found_out = open(options.found_out, 'w') if options.found_out else sys.stdout - # Create temporary directory - temp_dir = "temp_extract_scored" - os.makedirs(temp_dir, exist_ok=True) + try: + # Create temporary directory + temp_dir = "temp_extract_scored" + try: + os.makedirs(temp_dir) + except OSError: + if not os.path.isdir(temp_dir): + raise - # Initialize column indices - fpos, fref, falt = 1, 2, 3 - - # Check prescored file - if not (os.path.exists(options.path) and os.path.exists(options.path+".tbi")): - raise IOError("No valid file with pre-scored variants.\n") - - # Get header and column indices from prescored file - pre_tbx = pysam.TabixFile(options.path, 'r') - header = list(pre_tbx.header) - - # Write headers to output files if requested - if options.header: + # Initialize column indices + fpos, fref, falt = 1, 2, 3 + + # Check prescored file + if not (os.path.exists(options.path) and os.path.exists(options.path+".tbi")): + raise IOError("No valid file with pre-scored variants.\n") + + # Get header and column indices from prescored file + pre_tbx = pysam.TabixFile(options.path, 'r') + header = list(pre_tbx.header) + + # Write headers to output files if requested + if options.header: + for line in header: + found_out.write(line+"\n") + + # Get column indices from header for line in header: - found_out.write(line+"\n") - - # Get column indices from header - for line in header: + try: + fref = line.split('\t').index('Ref') + falt = line.split('\t').index('Alt') + except ValueError: + pass + + # Buffer VCF data and get chromosomes + header_lines, vcf_by_chrom = buffer_vcf_by_chromosome(input_stream) + chromosomes = sorted(vcf_by_chrom.keys()) + + # Write VCF headers to stdout + for line in header_lines: + sys.stdout.write(line) + + # Get number of threads from Snakemake + threads = int(os.environ.get("SNAKEMAKE_THREADS", "1")) + threads = min(threads, len(chromosomes)) + print("Using {0} threads to extract the scored variants across all chromosomes".format(threads), file=sys.stderr) + + # Setup parallel processing args + process_args = [ + (chrom, vcf_by_chrom[chrom], options.path, temp_dir, fpos, fref, falt) + for chrom in chromosomes + ] + + # Process chromosomes in parallel + pool = mp.Pool(threads) + results = pool.map(process_chromosome, process_args) + pool.close() + pool.join() + + # Combine results + for chrom, success in results: + if success: + found_file = os.path.join(temp_dir, "matches", "found.{0}.tmp".format(chrom)) + notfound_file = os.path.join(temp_dir, "matches", "notfound.{0}.tmp".format(chrom)) + + if os.path.exists(found_file): + with open(found_file) as f: + for line in f: + found_out.write(line) + os.remove(found_file) + + if os.path.exists(notfound_file): + with open(notfound_file) as f: + for line in f: + sys.stdout.write(line) + os.remove(notfound_file) + + # Cleanup try: - fref = line.split('\t').index('Ref') - falt = line.split('\t').index('Alt') - except ValueError: + import shutil + shutil.rmtree(temp_dir) + except: pass - # Buffer VCF data and get chromosomes - header_lines, vcf_by_chrom = buffer_vcf_by_chromosome(stdin) - chromosomes = sorted(vcf_by_chrom.keys()) - - # Write VCF headers to stdout - for line in header_lines: - sys.stdout.write(line) - - # Setup parallel processing args - process_args = [ - (chrom, tuple(vcf_by_chrom[chrom]), options.path, temp_dir, fpos, fref, falt) - for chrom in chromosomes - ] - - # Process chromosomes in parallel - # Get number of threads from Snakemake - threads = int(os.environ.get("SNAKEMAKE_THREADS", "1")) - threads = min(threads, len(chromosomes)) - print(f"Using {threads} threads to extract the scored variants across all chromosomes", file=sys.stderr) - with mp.Pool(threads) as pool: - results = pool.map(process_chromosome, process_args) - - # Combine results - for chrom, success in results: - if success: - found_file = os.path.join(temp_dir, "matches", f"found.{chrom}.tmp") - notfound_file = os.path.join(temp_dir, "matches", f"notfound.{chrom}.tmp") - - if os.path.exists(found_file): - with open(found_file) as f: - for line in f: - found_out.write(line) - os.remove(found_file) + finally: + # Close input file if it's not stdin + if options.input and options.input != "-": + input_stream.close() - if os.path.exists(notfound_file): - with open(notfound_file) as f: - for line in f: - sys.stdout.write(line) - os.remove(notfound_file) - - # Cleanup - try: - import shutil - shutil.rmtree(temp_dir) - except: - pass - - # Close files - if options.found_out: - found_out.close() - - if options.input: - stdin.close() + # Close output file if it's not stdout + if options.found_out: + found_out.close() if __name__ == "__main__": main() From 3714d18f03cd53234a1cc2df411757623894086f Mon Sep 17 00:00:00 2001 From: yangyxt Date: Mon, 16 Dec 2024 15:50:16 +0800 Subject: [PATCH 04/12] add a log line to show the used thread number --- src/scripts/extract_scored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index 20b3383..48d61d9 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -194,7 +194,7 @@ def main(): # Get number of threads from Snakemake threads = int(os.environ.get("SNAKEMAKE_THREADS", "1")) threads = min(threads, len(chromosomes)) - print("Using {0} threads to extract the scored variants across all chromosomes".format(threads), file=sys.stderr) + sys.stderr.write("Using {0} threads to extract the scored variants across all chromosomes\n".format(threads)) # Setup parallel processing args process_args = [ From 039e99ce2a3ea21acc7c31347d509be1af0df88c Mon Sep 17 00:00:00 2001 From: yangyxt Date: Mon, 16 Dec 2024 16:45:48 +0800 Subject: [PATCH 05/12] add logging lines --- src/scripts/extract_scored.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index 48d61d9..044a38f 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -7,6 +7,7 @@ from optparse import OptionParser import multiprocessing as mp + def buffer_vcf_by_chromosome(input_stream): """Read VCF from input stream and buffer by chromosome""" vcf_by_chrom = {} @@ -26,6 +27,7 @@ def buffer_vcf_by_chromosome(input_stream): return header_lines, vcf_by_chrom + def setup_output_dir(output_base, chrom): """Create chromosome-specific output directory""" chrom_dir = os.path.join(output_base, chrom) @@ -36,6 +38,7 @@ def setup_output_dir(output_base, chrom): raise return chrom_dir + def extract_prescored_chromosome(input_file, output_base, chrom): """Extract records for a single chromosome from prescored TSV file""" try: @@ -49,6 +52,7 @@ def extract_prescored_chromosome(input_file, output_base, chrom): # Check if extraction is needed if os.path.exists(compressed_file): if os.path.getmtime(compressed_file) > os.path.getmtime(input_file): + sys.stderr.write("The prescored file {0} for chromosome {1} is up to date, skip the extraction\n".format(compressed_file, chrom)) return compressed_file # Extract records for this chromosome using tabix @@ -69,11 +73,13 @@ def extract_prescored_chromosome(input_file, output_base, chrom): # Remove uncompressed file os.remove(output_file) + sys.stderr.write("The prescored file {0} for chromosome {1} is extracted\n".format(compressed_file, chrom)) return compressed_file except Exception as e: raise Exception("Error extracting prescored chromosome {0}: {1}".format(chrom, str(e))) + def process_chromosome(args): """Process a single chromosome""" chrom, vcf_lines, prescored_file, temp_dir, fpos, fref, falt = args @@ -81,7 +87,7 @@ def process_chromosome(args): # First extract prescored records for this chromosome prescored_chrom_file = extract_prescored_chromosome( prescored_file, - os.path.join(temp_dir, "prescored"), + os.path.dirname(prescored_file), chrom ) @@ -151,14 +157,6 @@ def main(): found_out = open(options.found_out, 'w') if options.found_out else sys.stdout try: - # Create temporary directory - temp_dir = "temp_extract_scored" - try: - os.makedirs(temp_dir) - except OSError: - if not os.path.isdir(temp_dir): - raise - # Initialize column indices fpos, fref, falt = 1, 2, 3 @@ -186,16 +184,20 @@ def main(): # Buffer VCF data and get chromosomes header_lines, vcf_by_chrom = buffer_vcf_by_chromosome(input_stream) chromosomes = sorted(vcf_by_chrom.keys()) + sys.stderr.write("The chromosomes are {0}\n".format(chromosomes)) + sys.stderr.write("There are in total {} lines of records got from the buffer of the input VCF file\n".format(sum([len(vcf_by_chrom[chrom]) for chrom in chromosomes]))) # Write VCF headers to stdout for line in header_lines: sys.stdout.write(line) # Get number of threads from Snakemake - threads = int(os.environ.get("SNAKEMAKE_THREADS", "1")) + threads = int(os.environ.get("SNAKEMAKE_THREADS", "10")) threads = min(threads, len(chromosomes)) sys.stderr.write("Using {0} threads to extract the scored variants across all chromosomes\n".format(threads)) + temp_dir = os.environ.get("TMPDIR", "/tmp") + # Setup parallel processing args process_args = [ (chrom, vcf_by_chrom[chrom], options.path, temp_dir, fpos, fref, falt) From dbf1ec62372df275058252c3b072154ea1c377bb Mon Sep 17 00:00:00 2001 From: yangyxt Date: Mon, 16 Dec 2024 16:58:52 +0800 Subject: [PATCH 06/12] update Snakefile and adding a thread argument for checkpoint process --- Snakefile | 3 ++- src/scripts/extract_scored.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Snakefile b/Snakefile index 65ad44c..ecce445 100644 --- a/Snakefile +++ b/Snakefile @@ -77,6 +77,7 @@ checkpoint prescore: prescored=temp("{file}.pre.tsv"), log: "{file}.prescore.log", + threads: workflow.cores, params: cadd=os.environ["CADD"], shell: @@ -91,7 +92,7 @@ checkpoint prescore: do cat {input.vcf}.new \ | python {params.cadd}/src/scripts/extract_scored.py --header \ - -p $PRESCORED --found_out={output.prescored}.tmp \ + -p $PRESCORED --found_out={output.prescored}.tmp --threads {threads} \ > {input.vcf}.tmp 2>> {log}; cat {output.prescored}.tmp >> {output.prescored} mv {input.vcf}.tmp {input.vcf}.new &> {log}; diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index 044a38f..23117a5 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -45,7 +45,8 @@ def extract_prescored_chromosome(input_file, output_base, chrom): # Setup output directory chrom_dir = setup_output_dir(output_base, chrom) input_file_name = os.path.basename(input_file) - input_file_name_base = input_file_name.rsplit('.', 1)[0] + input_file_name_base = input_file_name.replace(".tsv.gz", "") + assert input_file_name_base != input_file_name, "The input file name {0} is not valid".format(input_file_name) output_file = os.path.join(chrom_dir, "{0}.{1}.tsv".format(input_file_name_base, chrom)) compressed_file = "{0}.gz".format(output_file) @@ -142,6 +143,7 @@ def main(): parser.add_option("--found_out", dest="found_out", help="Write found variants to file (default: stdout)", default=None) parser.add_option("--header", dest="header", help="Write full header to output (default none)", default=False, action="store_true") + parser.add_option("-t", "--threads", dest="threads", help="Number of threads to use (default: 1)", default=1) (options, args) = parser.parse_args() # Setup input stream @@ -192,7 +194,6 @@ def main(): sys.stdout.write(line) # Get number of threads from Snakemake - threads = int(os.environ.get("SNAKEMAKE_THREADS", "10")) threads = min(threads, len(chromosomes)) sys.stderr.write("Using {0} threads to extract the scored variants across all chromosomes\n".format(threads)) From 3097ec9f085444ee8f41516450b6bfd9ca61321f Mon Sep 17 00:00:00 2001 From: yangyxt Date: Mon, 16 Dec 2024 17:26:40 +0800 Subject: [PATCH 07/12] fix a syntax error --- src/scripts/extract_scored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index 23117a5..f2c5e62 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -194,7 +194,7 @@ def main(): sys.stdout.write(line) # Get number of threads from Snakemake - threads = min(threads, len(chromosomes)) + threads = min(options.threads, len(chromosomes)) sys.stderr.write("Using {0} threads to extract the scored variants across all chromosomes\n".format(threads)) temp_dir = os.environ.get("TMPDIR", "/tmp") From e479ad30c391464005e8a85d9dd1aba3e8d3becf Mon Sep 17 00:00:00 2001 From: yangyxt Date: Tue, 17 Dec 2024 15:19:52 +0800 Subject: [PATCH 08/12] remove the temp_dir clean part --- Snakefile | 12 +++++------- src/scripts/extract_scored.py | 7 ------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/Snakefile b/Snakefile index ecce445..97831cc 100644 --- a/Snakefile +++ b/Snakefile @@ -86,20 +86,18 @@ checkpoint prescore: echo '## Prescored variant file' > {output.prescored} 2> {log}; PRESCORED_FILES=`find -L {input.prescored} -maxdepth 1 -type f -name \\*.tsv.gz | wc -l` cp {input.vcf} {input.vcf}.new - if [ ${{PRESCORED_FILES}} -gt 0 ]; - then - for PRESCORED in $(ls {input.prescored}/*.tsv.gz) - do + if [ ${{PRESCORED_FILES}} -gt 0 ]; then + for PRESCORED in $(ls {input.prescored}/*.tsv.gz); do cat {input.vcf}.new \ | python {params.cadd}/src/scripts/extract_scored.py --header \ -p $PRESCORED --found_out={output.prescored}.tmp --threads {threads} \ > {input.vcf}.tmp 2>> {log}; cat {output.prescored}.tmp >> {output.prescored} - mv {input.vcf}.tmp {input.vcf}.new &> {log}; + mv {input.vcf}.tmp {input.vcf}.new 2>> {log}; done; - rm {output.prescored}.tmp &>> {log} + rm {output.prescored}.tmp 2>> {log} fi - mv {input.vcf}.new {output.novel} &>> {log} + mv {input.vcf}.new {output.novel} 2>> {log} """ diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index f2c5e62..9177530 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -228,13 +228,6 @@ def main(): for line in f: sys.stdout.write(line) os.remove(notfound_file) - - # Cleanup - try: - import shutil - shutil.rmtree(temp_dir) - except: - pass finally: # Close input file if it's not stdin From 1abe15af5713605ea7b4e16efdde0f1009a168dc Mon Sep 17 00:00:00 2001 From: yangyxt Date: Tue, 17 Dec 2024 15:57:26 +0800 Subject: [PATCH 09/12] optimize tabix indexing check --- src/scripts/extract_scored.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/scripts/extract_scored.py b/src/scripts/extract_scored.py index 9177530..6b49cc5 100755 --- a/src/scripts/extract_scored.py +++ b/src/scripts/extract_scored.py @@ -53,8 +53,24 @@ def extract_prescored_chromosome(input_file, output_base, chrom): # Check if extraction is needed if os.path.exists(compressed_file): if os.path.getmtime(compressed_file) > os.path.getmtime(input_file): - sys.stderr.write("The prescored file {0} for chromosome {1} is up to date, skip the extraction\n".format(compressed_file, chrom)) - return compressed_file + if os.path.exists(compressed_file + ".tbi"): + if os.path.getmtime(compressed_file + ".tbi") > os.path.getmtime(compressed_file): + sys.stderr.write("The prescored file {0} for chromosome {1} is up to date, skip the extraction\n".format(compressed_file, chrom)) + return compressed_file + else: + tabix_only=True + else: + tabix_only=True + + if tabix_only: + pysam.tabix_index(compressed_file, + preset=None, + force=True, + seq_col=0, + start_col=1, + end_col=1, + zerobased=False) + return compressed_file # Extract records for this chromosome using tabix tbx = pysam.TabixFile(input_file) From 00873b5ea8367108dce2d92b55cff80e312785a3 Mon Sep 17 00:00:00 2001 From: yangyxt Date: Thu, 19 Dec 2024 14:57:43 +0800 Subject: [PATCH 10/12] Refractor esmScore_inFrame_av.py to greatly improve the performance --- .../lib/tools/esmScore/esmScore_inFrame_av.py | 813 ++++++------------ 1 file changed, 264 insertions(+), 549 deletions(-) diff --git a/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py b/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py index a8ca043..d7497d1 100644 --- a/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py +++ b/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py @@ -1,25 +1,254 @@ """ Description: input is vep annotated vcf and a file containing all peptide sequences with Ensemble transcript IDs. -The script adds a score for frameshifts and stop gains to the info column of the vcf file. In brief, scores for inframe InDel variants were calculated for variants annotated with the Ensembl VEP tools' -inframe insertion or inframe deletion consequence annotation. Variants with missense consequence annotations were only used if multiple amino acids are substituted. -Variants with stop gain, stop lost, and stop retained consequence annotations were explicitly excluded. Amino acid sequences of reference alleles were obtained as described -above, with the only difference that a window of 250 amino acids was used. As for InDel variants multiple amino acids can be affected, the amino acid sequence corresponding to the -alternative allele can differ in more than one position from the sequence of the reference allele. To account for this, we generated the entire amino acid sequence of the alternative -allele using the Ensembl VEP tools' annotations and the reference sequence. To calculate scores for inframe InDel variants, log transformed probabilities of the entire reference and -alternative sequences were added up, respectively, and substracted from each other, yielding log odds ratios. -The log odds ratios resulting from each of the five models were than averaged and used as final score. -Author: thorben Maass +The script adds a score for frameshifts and stop gains to the info column of the vcf file. In brief, Scores for frameshift or stop gain variants were calculated +for variants annotated with the Ensembl VEP tools' frameshift +or stop gain consequence annotation. Amino acid sequences of reference alleles were obtained as described above, with the only difference that a window of 250 amino acids was used. +Log transformed probabilities were calculated and summed up for the entire reference amino acid sequence. +Calculation of a score for the alternative allele was carried out based on the entire reference amino acid sequence. Here, we summed up log probabilities of the reference +sequences amino acids up to the point where the frameshift or stop gain occurred as obtained from the Ensembl VEP tools' annotations. For every amino acid that is lost due +to the frameshift or stop gain, we used the median of log transformed probabilities calculated from all possible amino acids at each individual position in the remaining sequence +and added them to the sum corresponding to the alternative allele. +The average of logs odds ratios between the reference and alternative sequences from the five models was than used as a final score. + +Author: Thorben Maass, Max Schubach Contact: tho.maass@uni-luebeck.de Year:2023 -""" +Refractored by yangyxt (using numpy array instead of list appending, greatly improved the performance when dealing with huge VCF file) +""" +import warnings import numpy as np from Bio.bgzf import BgzfReader, BgzfWriter import torch from esm import pretrained import click +# Constants +WINDOW_SIZE = 250 +BATCH_SIZE = 20 + +def read_and_extract_vcf_data(input_file): + """Reads the VCF file and extracts relevant information.""" + vcf_data = [] + with BgzfReader(input_file, "r") as vcf_file: + for line in vcf_file: + vcf_data.append(line) + + info_pos = {} + for line in vcf_data: + if line.startswith("##INFO="): + info = line.split("|") + for i, item in enumerate(info): + if item in ("Feature", "Protein_position", "Amino_acids", "Consequence"): + info_pos[item] = i + if len(info_pos) == 4: + break + + # Preallocate NumPy arrays + num_variants = sum(1 for line in vcf_data if not line.startswith("#")) + variant_ids = np.empty(num_variants, dtype=object) + transcript_ids = np.empty(num_variants, dtype=object) + oAA = np.empty(num_variants, dtype=object) + nAA = np.empty(num_variants, dtype=object) + prot_pos_start = np.empty(num_variants, dtype=int) + prot_pos_end = np.empty(num_variants, dtype=int) + cons = np.empty(num_variants, dtype=object) + + idx = 0 + for variant in vcf_data: + if not variant.startswith("#"): + variant_entry = variant.split(",") + for i in range(len(variant_entry)): + variant_info = variant_entry[i].split("|") + consequences = variant_info[info_pos["Consequence"]].split("&") + if ("frameshift_variant" in consequences or "stop_gained" in consequences) and len(variant_info[info_pos["Amino_acids"]].split("/")) == 2: + variant_ids[idx] = variant_entry[0].split("|")[0] + transcript_ids[idx] = variant_info[info_pos["Feature"]] + cons[idx] = consequences + oAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[0] + nAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[1] + prot_pos_range = variant_info[info_pos["Protein_position"]].split("/")[0] + if "-" in prot_pos_range: + start, end = map(int, prot_pos_range.split("-")) + prot_pos_start[idx] = start + prot_pos_end[idx] = end + else: + pos = int(prot_pos_range) + prot_pos_start[idx] = pos + prot_pos_end[idx] = pos + idx += 1 + + # Trim arrays to actual size + variant_ids = variant_ids[:idx] + transcript_ids = transcript_ids[:idx] + cons = cons[:idx] + oAA = oAA[:idx] + nAA = nAA[:idx] + prot_pos_start = prot_pos_start[:idx] + prot_pos_end = prot_pos_end[:idx] + + return vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons + +def process_transcript_data(transcript_file, transcript_ids, prot_pos_start, prot_pos_end): + """Processes transcript data and creates aa_seq_ref.""" + with open(transcript_file, "r") as f: + transcript_info_entries = f.read().split(">")[1:] + + transcript_info = [] + transcript_info_id = [] + for entry in transcript_info_entries: + parts = entry.split(" ") + transcript_info.append(parts) + transcript_id_full = parts[4] + transcript_id = transcript_id_full.split(".")[0] + transcript_info_id.append(transcript_id) + + # Preallocate arrays + num_transcripts = len(transcript_ids) + aa_seq_ref = np.empty(num_transcripts, dtype=object) + total_stop_codons = np.zeros(num_transcripts, dtype=int) + stop_codons_before_mutation = np.zeros(num_transcripts, dtype=int) + stop_codons_in_indel = np.zeros(num_transcripts, dtype=int) + + for j, transcript_id in enumerate(transcript_ids): + transcript_found = False + for i, info_id in enumerate(transcript_info_id): + if info_id == transcript_id: + transcript_found = True + temp_seq = transcript_info[i][-1].replace("\n", "") + + stop_codons_before_mutation[j] = temp_seq[:prot_pos_start[j]].count("*") + stop_codons_in_indel[j] = temp_seq[prot_pos_start[j]:prot_pos_end[j]].count("*") + total_stop_codons[j] = temp_seq.count("*") + + aa_seq_ref[j] = temp_seq.replace("*", "") + break + + if not transcript_found: + aa_seq_ref[j] = "NA" + stop_codons_before_mutation[j] = 9999 + total_stop_codons[j] = 9999 + stop_codons_in_indel[j] = 9999 + + return aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel + +def prepare_data_for_esm(aa_seq, transcript_ids, prot_pos_start, stop_codons_before_mutation): + """Prepares data for the ESM model.""" + data = [] + prot_pos_mod = [] + for i in range(len(aa_seq)): + if aa_seq[i] == "NA": + continue + + adjusted_pos = prot_pos_start[i] - stop_codons_before_mutation[i] + seq_len = len(aa_seq[i]) + + if seq_len < WINDOW_SIZE: + data.append((transcript_ids[i], aa_seq[i])) + prot_pos_mod.append(adjusted_pos) + elif adjusted_pos + 1 + WINDOW_SIZE // 2 <= seq_len and adjusted_pos + 1 - WINDOW_SIZE // 2 >= 1: + start = adjusted_pos - WINDOW_SIZE // 2 + end = adjusted_pos + WINDOW_SIZE // 2 + data.append((transcript_ids[i], aa_seq[i][start:end])) + prot_pos_mod.append(WINDOW_SIZE // 2) + elif seq_len >= WINDOW_SIZE and adjusted_pos + 1 - WINDOW_SIZE // 2 < 1: + data.append((transcript_ids[i], aa_seq[i][:WINDOW_SIZE])) + prot_pos_mod.append(adjusted_pos) + else: + data.append((transcript_ids[i], aa_seq[i][-WINDOW_SIZE:])) + prot_pos_mod.append(adjusted_pos - (seq_len - WINDOW_SIZE)) + + return data, prot_pos_mod + +def calculate_esm_scores(data, prot_pos_mod, modelsToUse, conseq, batch_size=BATCH_SIZE): + """Runs the ESM model and calculates scores.""" + model_scores = [] + for model_name in modelsToUse: + torch.cuda.empty_cache() + model, alphabet = pretrained.load_model_and_alphabet(model_name) + model.eval() + batch_converter = alphabet.get_batch_converter() + + if torch.cuda.is_available(): + model = model.cuda() + + seq_scores = [] + for i in range(0, len(data), batch_size): + batch_data = data[i:i + batch_size] + batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) + + with torch.no_grad(): + if torch.cuda.is_available(): + batch_tokens = batch_tokens.cuda() + token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1).cpu() + + for j, (transcript_id, seq) in enumerate(batch_data): + idx = i + j + if conseq[idx] == "FS": + score = 0 + for y, aa in enumerate(seq): + if y < prot_pos_mod[idx]: + score += token_probs[j, y + 1, alphabet.get_idx(aa)].item() + else: + aa_scores = [token_probs[j, y + 1, k].item() for k in range(4, 24)] + aa_scores.append(token_probs[j, y + 1, 26].item()) + aa_scores.sort() + mid = len(aa_scores) // 2 + median = (aa_scores[mid] + aa_scores[~mid]) / 2 + score += median + seq_scores.append(score) + elif conseq[idx] == "NA": + seq_scores.append(0) + + model_scores.append(seq_scores) + + return np.array(model_scores) + +def annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref): + """Adds scores to the VCF file and writes the output.""" + header_end = 0 + for i, line in enumerate(vcf_data): + if line.startswith("#CHROM"): + vcf_data[i - 1] += '##INFO=\n' + header_end = i + break + + vcf_data_modified = vcf_data[:header_end + 1] + + for i in range(header_end + 1, len(vcf_data)): + line = vcf_data[i] + new_line = line + j = 0 + while j < len(variant_ids): + if line.split("|")[0] == variant_ids[j]: + num_scores = 0 + for l in range(j, len(variant_ids)): + if line.split("|")[0] == variant_ids[l]: + num_scores += 1 + else: + break + + new_line = new_line[:-1] + ";EsmScoreFrameshift" + "=" + new_line[-1:] + + for h in range(num_scores): + if aa_seq_ref[j + h] != "NA": + avg_score = np.mean(np_array_score_diff[:, j + h]) + new_line = new_line[:-1] + "{0}|{1:.3f}".format(transcript_ids[j + h][11:], avg_score) + new_line[-1:] + else: + new_line = new_line[:-1] + "{0}|NA".format(transcript_ids[j + h][11:]) + new_line[-1:] + + if h < num_scores - 1: + new_line = new_line[:-1] + "," + new_line[-1:] + + j += num_scores + else: + j += 1 + vcf_data_modified.append(new_line) + + with BgzfWriter(output_file, "w") as vcf_file_output: + for line in vcf_data_modified: + vcf_file_output.write(line) @click.command() @click.option( @@ -70,559 +299,45 @@ "--batch-size", "batch_size", type=int, - default=20, + default=BATCH_SIZE, help="Batch size for esm model, default is 20", ) -def cli(input_file, transcript_file, model_directory, modelsToUse, output_file, batch_size): +def cli( + input_file, transcript_file, model_directory, modelsToUse, output_file, batch_size +): + """Main CLI function.""" torch.hub.set_dir(model_directory) - # get information from vcf file with SNVs and write them into lists - vcf_file_data = BgzfReader(input_file, "r") # TM_example.vcf.gz - vcf_data = [] - for line in vcf_file_data: - vcf_data.append(line) - - info_pos_Feature = False # TranscriptID - info_pos_ProteinPosition = False # resdidue in protein that is mutated - info_pos_AA = False # mutation from aa (amino acid) x to y - info_pos_consequence = False - # identify positions of annotations importnat for esm score - for line in vcf_data: - if line[0:7] == "##INFO=": - info = line.split("|") - for i in range(0, len(info), 1): - if info[i] == "Feature": - info_pos_Feature = i - if info[i] == "Protein_position": - info_pos_ProteinPosition = i - if info[i] == "Amino_acids": - info_pos_AA = i - if info[i] == "Consequence": - info_pos_consequence = i - break - - # extract annotations important for esm score, "NA" for non-coding variants - variant_ids = [] - transcript_id = [] - oAA = [] - nAA = [] - protPosStart = [] - protPosEnd = [] - protPos_mod = [] - cons = [] - # protPos_mod=[]#falls protein laenger als 1024 aa - - for variant in vcf_data: - if variant[0:1] != "#": - variant_entry = variant.split(",") - for i in range(0, len(variant_entry), 1): - variant_info = variant_entry[i].split("|") - consequences = variant_info[info_pos_consequence].split("&") - if ( - ( - "inframe_insertion" in consequences - or "inframe_deletion" in consequences - ) - and len(variant_info[info_pos_AA].split("/")) == 2 - and "stop_gained" not in consequences - and "stop_lost" not in consequences - and "stop_retained_variant" not in consequences - ): - variant_ids.append(variant_entry[0].split("|")[0]) - transcript_id.append("transcript:" + variant_info[info_pos_Feature]) - cons.append(variant_info[info_pos_consequence].split("&")) - - oAA.append( - variant_info[info_pos_AA].split("/")[0] - ) # can also be "-" if there is an insertion - nAA.append(variant_info[info_pos_AA].split("/")[1]) - if ( - "-" in variant_info[info_pos_ProteinPosition].split("/")[0] - ): # in case of frameshifts, vep only gives X as the new aa - protPosStart.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[0] - ) - ) - protPosEnd.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[1] - ) - ) - else: - protPosStart.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPosEnd.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPos_mod.append(False) - elif ( - "missense_variant" in consequences - and "-" in variant_info[info_pos_ProteinPosition].split("/")[0] - and len(variant_info[info_pos_AA].split("/")) == 2 - and "stop_gained" not in consequences - and "stop_lost" not in consequences - and "stop_retained_variant" not in consequences - ): - variant_ids.append(variant_entry[0].split("|")[0]) - transcript_id.append("transcript:" + variant_info[info_pos_Feature]) - cons.append(variant_info[info_pos_consequence].split("&")) - oAA.append( - variant_info[info_pos_AA].split("/")[0] - ) # can also be "-" if there is an insertion - nAA.append(variant_info[info_pos_AA].split("/")[1]) - protPosStart.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[0] - ) - ) - protPosEnd.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[1] - ) - ) - protPos_mod.append(False) - - # dissect file with all aa seqs to entries - transcript_data = open( - transcript_file, "r" - ) # - transcript_info_entries = transcript_data.read().split( - ">" - ) # evtl erstes > in file weglöschen - transcript_data.close() - transcript_info = [] - transcript_info_id = [] - - # transcript info contains aa seqs, becomes processed later - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info.append(transcript_info_entries[i].split(" ")) - - # transcript ids - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info_tmp = transcript_info_entries[i].split(" ")[4] - pointAt = False - # remove version of ENST ID vor comparison with vep annotation - for p in range(0, len(transcript_info_tmp), 1): - if transcript_info_tmp[p] == ".": - pointAt = p - - transcript_info_tmp = transcript_info_tmp[:pointAt] - - transcript_info_id.append(transcript_info_tmp) - - if (len(transcript_info_id)) != len(transcript_info): - print("ERROR!!!!!!") - - # create list with aa_seq_refs of transcript_ids, mal gucken, ob man alle auf einmal uebergebenkann an esm model - aa_seq_ref = [] - totalNumberOfStopCodons = [] - numberOfStopCodons = [] - numberOfStopCodonsInIndel = [] - for j in range(0, len(transcript_id), 1): - transcript_found = False - for i in range( - 1, len(transcript_info_id), 1 - ): # start bei 1 statt 0 weil das inputfile mit ">" anfaengt und 0. element in aa_seq_ref einfach [] ist - if transcript_info_id[i] == transcript_id[j]: # -2 damit ".9" usw wegfaellt - transcript_found = True - # prepare Seq remove remainings of header - temp_seq = transcript_info[i][-1] - for k in range(0, len(temp_seq), 1): - if temp_seq[k] != "\n": - k = k + 1 - else: - k = k + 1 - temp_seq = temp_seq[k:] - break - - # prepare seq (remove /n) - forbidden_chars = "\n" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - # count stop codons in seq before site of mutation - numberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*" and k < protPosStart[j]: - numberOfStopCodons[j] = numberOfStopCodons[j] + 1 - - # count stop codons in Indel - numberOfStopCodonsInIndel.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if ( - temp_seq[k] == "*" - and k >= protPosStart[j] - and k < protPosEnd[j] - ): - numberOfStopCodonsInIndel[j] = ( - numberOfStopCodonsInIndel[j] + 1 - ) - - # count stop codons in seq - totalNumberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*": - totalNumberOfStopCodons[j] = totalNumberOfStopCodons[j] + 1 - - # remove additional stop codons (remove *) - forbidden_chars = "*" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - aa_seq_ref.append(temp_seq) - if transcript_found == False: - aa_seq_ref.append("NA") - numberOfStopCodons.append(9999) - totalNumberOfStopCodons.append(9999) - numberOfStopCodonsInIndel.append(9999) + vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons = read_and_extract_vcf_data(input_file) + aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel = process_transcript_data( + transcript_file, transcript_ids, prot_pos_start, prot_pos_end + ) conseq = [] aa_seq_alt = [] - for j in range(0, len(aa_seq_ref), 1): + for j in range(0, len(aa_seq_ref)): if aa_seq_ref[j] == "NA": aa_seq_alt.append("NA") conseq.append("NA") - - elif ( - len(nAA[j]) == len(oAA[j]) and "-" not in oAA[j] and "-" not in nAA[j] - ): # inframe ins wenn gleich viele weg kommen wie dazu ommen (zB AAA/GGG) - nAA_mod = nAA[j].replace( - "*", "" - ) # falls oAA und nAA ein nicht terminales stopp codon haben (A*G/P*K) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] - + nAA_mod - + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodons[j] - - numberOfStopCodonsInIndel[j] : - ] - ) - conseq.append("MultiMissense") - - elif ( - len(nAA[j]) >= len(oAA[j]) and "-" in oAA[j] - ): # inframe ins wenn keine alte aa weg kommt (zB -/GP) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] - + nAA[j] - + aa_seq_ref[j][protPosStart[j] - numberOfStopCodons[j] - 1 :] - ) - conseq.append("inFrame") - - elif len(nAA[j]) > len( - oAA[j] - ): # inframe ins wenn alte aa zerstoert wird (zB Q/PE) - nAA_mod = nAA[j].replace( - "*", "" - ) # falls aa vor eigentlichen stopp codon eingefuegt wird und altes stopp dabei zerstoert wird, zaehlt auch als inframe aber ohne stop gain (zB */Y*) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] - + nAA_mod - + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodons[j] - - numberOfStopCodonsInIndel[j] : - ] - ) - conseq.append("inFrame") - - elif ( - len(nAA[j]) <= len(oAA[j]) and "-" in nAA[j] - ): # inframe deletion wenn alte aa nicht zerstoert wird (zB QQ/-) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] - + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodonsInIndel[j] - - numberOfStopCodons[j] : - ] - ) - conseq.append("inFrame") - - elif len(nAA[j]) < len( - oAA[j] - ): # inframe deletion wenn alte aa zerstoert wird (zB KE/K) - nAA_mod = nAA[j].replace( - "*", "" - ) # falls alte aa und stop zerstoert wird und neues stopp eingefuegt wird, wird dann von vep nicht als stop gained bezeichnet (Y*/*) - aa_seq_alt.append( - aa_seq_ref[j][0 : protPosStart[j] - numberOfStopCodons[j] - 1] - + nAA_mod - + aa_seq_ref[j][ - protPosEnd[j] - - numberOfStopCodons[j] - - numberOfStopCodonsInIndel[j] : - ] - ) - conseq.append("inFrame") - - # prepare data array for esm model, Problem: only give the coding sequences i - - window = 250 - data_ref = [] - for i in range(0, len(transcript_id), 1): - if len(aa_seq_ref[i]) < window: - data_ref.append((transcript_id[i], aa_seq_ref[i])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - elif ( - (len(aa_seq_ref[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_ref[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_ref.append( - ( - transcript_id[i], - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - protPos_mod[i] = int( - len( - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ] - ) - / 2 - ) - - elif ( - len(aa_seq_ref[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_ref.append((transcript_id[i], aa_seq_ref[i][:window])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - else: - data_ref.append((transcript_id[i], aa_seq_ref[i][-window:])) - protPos_mod[i] = ( - protPosStart[i] - numberOfStopCodons[i] - (len(aa_seq_ref[i]) - window) - ) - - data_alt = [] - - for i in range(0, len(transcript_id), 1): - if len(aa_seq_alt[i]) < window: - data_alt.append((transcript_id[i], aa_seq_alt[i])) - - elif ( - (len(aa_seq_alt[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_alt[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_alt.append( - ( - transcript_id[i], - aa_seq_alt[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - - elif ( - len(aa_seq_alt[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_alt.append((transcript_id[i], aa_seq_alt[i][:window])) - + elif "*" in nAA[j] or "X" in nAA[j]: # stop codon gained or complete frameshift + aa_seq_alt.append(aa_seq_ref[j]) # add alt seq without stop codon + conseq.append("FS") else: - data_alt.append((transcript_id[i], aa_seq_alt[i][-window:])) - - ref_alt_scores = [] - # load esm model(s) - for o in range(0, len([data_ref, data_alt]), 1): - data = [data_ref, data_alt][o] - modelScores = [] # scores of different models - if len(data) >= 1: - for k in range(0, len(modelsToUse), 1): - torch.cuda.empty_cache() - model, alphabet = pretrained.load_model_and_alphabet(modelsToUse[k]) - model.eval() # disables dropout for deterministic results - batch_converter = alphabet.get_batch_converter() - - if torch.cuda.is_available(): - model = model.cuda() - # print("transferred to GPU") - - # apply es model to sequence, tokenProbs hat probs von allen aa an jeder pos basierend auf der seq in "data" - seq_scores = [] - for t in range(0, len(data), batch_size): - # print (t) - if t + batch_size > len(data): - batch_data = data[t:] - else: - batch_data = data[t : t + batch_size] - - batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) - with torch.no_grad(): # setzt irgeineine flag auf false - if torch.cuda.is_available(): - token_probs = torch.log_softmax( - model(batch_tokens.cuda())["logits"], dim=-1 - ) - else: - token_probs = torch.log_softmax( - model(batch_tokens)["logits"], dim=-1 - ) - - # test and extract scores from tokenProbs - if o == 1: # alt seqences - for i in range(0, len(batch_data), 1): - if ( - conseq[i + t] == "inFrame" - or conseq[i + t] == "MultiMissense" - ): - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - seq_scores.append(float(score)) - - elif conseq[i + t] == "NA": - score = 0 - seq_scores.append(float(score)) - elif o == 0: # ref sequences - for i in range(0, len(batch_data), 1): - if ( - conseq[i + t] == "inFrame" - or conseq[i + t] == "MultiMissense" - ): - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - seq_scores.append(float(score)) - - elif conseq[i + t] == "NA": - score = 999 # sollte nacher rausgeschissen werden, kein score sollte -999 sein - seq_scores.append(float(score)) - - modelScores.append(seq_scores) - ref_alt_scores.append(modelScores) - - np_array_scores = np.array(ref_alt_scores) - np_array_score_diff = np_array_scores[0] - np_array_scores[1] - - # write scores in cvf. file - - # get information from vcf file with SNVs and write them into lists (erstmal Bsp, später automatisch aus info zeile extrahieren) - - # identify positions of annotations important for esm score - header_end = False - for i in range(0, len(vcf_data), 1): - if vcf_data[i][0:6] == "#CHROM": - vcf_data[i - 1] = ( - vcf_data[i - 1] - + "##INFO=\n' + aa_seq_alt.append("NA") + conseq.append("NA") + warnings.warn( + "there is a problem with the ensembl data base and vep. The ESMframesift score of this variant will be artificially set to 0. Affected transcript is " + + str(transcript_ids[j]) ) - header_end = i - break - - for i in range(header_end + 1, len(vcf_data), 1): - j = 0 - while j < len(variant_ids): - if vcf_data[i].split("|")[0] == variant_ids[j]: - # count number of vep entires per variant that result in an esm score (i.e. with consequence "missense") - numberOfEsmScoresPerVariant = 0 - for l in range(j, len(variant_ids), 1): - if vcf_data[i].split("|")[0] == variant_ids[l]: - numberOfEsmScoresPerVariant = numberOfEsmScoresPerVariant + 1 - else: - break - # annotate vcf line with esm scores - # for k in range (0, len(modelsToUse), 1): - vcf_data[i] = ( - vcf_data[i][:-1] + ";EsmScoreInFrame" + "=" + vcf_data[i][-1:] - ) - for h in range(0, numberOfEsmScoresPerVariant, 1): - if aa_seq_ref[j + h] != "NA": - average_score = 0 - for k in range(0, len(modelsToUse), 1): - average_score = average_score + float( - np_array_score_diff[k][j + h] - ) - average_score = average_score / len(modelsToUse) - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + str(round(float(average_score), 3)) - + vcf_data[i][-1:] - ) - else: - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + "NA" - + vcf_data[i][-1:] - ) - - if h != numberOfEsmScoresPerVariant - 1: - vcf_data[i] = vcf_data[i][:-1] + "," + vcf_data[i][-1:] - - j = j + numberOfEsmScoresPerVariant - else: - j = j + 1 - - vcf_file_output = BgzfWriter(output_file, "w") - for line in vcf_data: - vcf_file_output.write(line) + data_ref, prot_pos_mod_ref = prepare_data_for_esm(aa_seq_ref, transcript_ids, prot_pos_start, stop_codons_before_mutation) + data_alt, prot_pos_mod_alt = prepare_data_for_esm(aa_seq_alt, transcript_ids, prot_pos_start, stop_codons_before_mutation) - vcf_file_output.close() + ref_scores = calculate_esm_scores(data_ref, prot_pos_mod_ref, modelsToUse, conseq, batch_size) + alt_scores = calculate_esm_scores(data_alt, prot_pos_mod_alt, modelsToUse, conseq, batch_size) + np_array_score_diff = alt_scores - ref_scores + annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref) if __name__ == "__main__": - cli() + cli() \ No newline at end of file From daf241c58733f59ade47ade3430d09b22b24ffa7 Mon Sep 17 00:00:00 2001 From: yangyxt Date: Thu, 19 Dec 2024 15:09:05 +0800 Subject: [PATCH 11/12] Refractor esmScore_frameshift_av.py to greatly improve the performance --- .../tools/esmScore/esmScore_frameshift_av.py | 681 +++++++----------- 1 file changed, 242 insertions(+), 439 deletions(-) diff --git a/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py b/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py index f6ba8bf..d7497d1 100644 --- a/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py +++ b/src/scripts/lib/tools/esmScore/esmScore_frameshift_av.py @@ -13,6 +13,8 @@ Author: Thorben Maass, Max Schubach Contact: tho.maass@uni-luebeck.de Year:2023 + +Refractored by yangyxt (using numpy array instead of list appending, greatly improved the performance when dealing with huge VCF file) """ import warnings @@ -22,6 +24,231 @@ from esm import pretrained import click +# Constants +WINDOW_SIZE = 250 +BATCH_SIZE = 20 + +def read_and_extract_vcf_data(input_file): + """Reads the VCF file and extracts relevant information.""" + vcf_data = [] + with BgzfReader(input_file, "r") as vcf_file: + for line in vcf_file: + vcf_data.append(line) + + info_pos = {} + for line in vcf_data: + if line.startswith("##INFO="): + info = line.split("|") + for i, item in enumerate(info): + if item in ("Feature", "Protein_position", "Amino_acids", "Consequence"): + info_pos[item] = i + if len(info_pos) == 4: + break + + # Preallocate NumPy arrays + num_variants = sum(1 for line in vcf_data if not line.startswith("#")) + variant_ids = np.empty(num_variants, dtype=object) + transcript_ids = np.empty(num_variants, dtype=object) + oAA = np.empty(num_variants, dtype=object) + nAA = np.empty(num_variants, dtype=object) + prot_pos_start = np.empty(num_variants, dtype=int) + prot_pos_end = np.empty(num_variants, dtype=int) + cons = np.empty(num_variants, dtype=object) + + idx = 0 + for variant in vcf_data: + if not variant.startswith("#"): + variant_entry = variant.split(",") + for i in range(len(variant_entry)): + variant_info = variant_entry[i].split("|") + consequences = variant_info[info_pos["Consequence"]].split("&") + if ("frameshift_variant" in consequences or "stop_gained" in consequences) and len(variant_info[info_pos["Amino_acids"]].split("/")) == 2: + variant_ids[idx] = variant_entry[0].split("|")[0] + transcript_ids[idx] = variant_info[info_pos["Feature"]] + cons[idx] = consequences + oAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[0] + nAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[1] + prot_pos_range = variant_info[info_pos["Protein_position"]].split("/")[0] + if "-" in prot_pos_range: + start, end = map(int, prot_pos_range.split("-")) + prot_pos_start[idx] = start + prot_pos_end[idx] = end + else: + pos = int(prot_pos_range) + prot_pos_start[idx] = pos + prot_pos_end[idx] = pos + idx += 1 + + # Trim arrays to actual size + variant_ids = variant_ids[:idx] + transcript_ids = transcript_ids[:idx] + cons = cons[:idx] + oAA = oAA[:idx] + nAA = nAA[:idx] + prot_pos_start = prot_pos_start[:idx] + prot_pos_end = prot_pos_end[:idx] + + return vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons + +def process_transcript_data(transcript_file, transcript_ids, prot_pos_start, prot_pos_end): + """Processes transcript data and creates aa_seq_ref.""" + with open(transcript_file, "r") as f: + transcript_info_entries = f.read().split(">")[1:] + + transcript_info = [] + transcript_info_id = [] + for entry in transcript_info_entries: + parts = entry.split(" ") + transcript_info.append(parts) + transcript_id_full = parts[4] + transcript_id = transcript_id_full.split(".")[0] + transcript_info_id.append(transcript_id) + + # Preallocate arrays + num_transcripts = len(transcript_ids) + aa_seq_ref = np.empty(num_transcripts, dtype=object) + total_stop_codons = np.zeros(num_transcripts, dtype=int) + stop_codons_before_mutation = np.zeros(num_transcripts, dtype=int) + stop_codons_in_indel = np.zeros(num_transcripts, dtype=int) + + for j, transcript_id in enumerate(transcript_ids): + transcript_found = False + for i, info_id in enumerate(transcript_info_id): + if info_id == transcript_id: + transcript_found = True + temp_seq = transcript_info[i][-1].replace("\n", "") + + stop_codons_before_mutation[j] = temp_seq[:prot_pos_start[j]].count("*") + stop_codons_in_indel[j] = temp_seq[prot_pos_start[j]:prot_pos_end[j]].count("*") + total_stop_codons[j] = temp_seq.count("*") + + aa_seq_ref[j] = temp_seq.replace("*", "") + break + + if not transcript_found: + aa_seq_ref[j] = "NA" + stop_codons_before_mutation[j] = 9999 + total_stop_codons[j] = 9999 + stop_codons_in_indel[j] = 9999 + + return aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel + +def prepare_data_for_esm(aa_seq, transcript_ids, prot_pos_start, stop_codons_before_mutation): + """Prepares data for the ESM model.""" + data = [] + prot_pos_mod = [] + for i in range(len(aa_seq)): + if aa_seq[i] == "NA": + continue + + adjusted_pos = prot_pos_start[i] - stop_codons_before_mutation[i] + seq_len = len(aa_seq[i]) + + if seq_len < WINDOW_SIZE: + data.append((transcript_ids[i], aa_seq[i])) + prot_pos_mod.append(adjusted_pos) + elif adjusted_pos + 1 + WINDOW_SIZE // 2 <= seq_len and adjusted_pos + 1 - WINDOW_SIZE // 2 >= 1: + start = adjusted_pos - WINDOW_SIZE // 2 + end = adjusted_pos + WINDOW_SIZE // 2 + data.append((transcript_ids[i], aa_seq[i][start:end])) + prot_pos_mod.append(WINDOW_SIZE // 2) + elif seq_len >= WINDOW_SIZE and adjusted_pos + 1 - WINDOW_SIZE // 2 < 1: + data.append((transcript_ids[i], aa_seq[i][:WINDOW_SIZE])) + prot_pos_mod.append(adjusted_pos) + else: + data.append((transcript_ids[i], aa_seq[i][-WINDOW_SIZE:])) + prot_pos_mod.append(adjusted_pos - (seq_len - WINDOW_SIZE)) + + return data, prot_pos_mod + +def calculate_esm_scores(data, prot_pos_mod, modelsToUse, conseq, batch_size=BATCH_SIZE): + """Runs the ESM model and calculates scores.""" + model_scores = [] + for model_name in modelsToUse: + torch.cuda.empty_cache() + model, alphabet = pretrained.load_model_and_alphabet(model_name) + model.eval() + batch_converter = alphabet.get_batch_converter() + + if torch.cuda.is_available(): + model = model.cuda() + + seq_scores = [] + for i in range(0, len(data), batch_size): + batch_data = data[i:i + batch_size] + batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) + + with torch.no_grad(): + if torch.cuda.is_available(): + batch_tokens = batch_tokens.cuda() + token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1).cpu() + + for j, (transcript_id, seq) in enumerate(batch_data): + idx = i + j + if conseq[idx] == "FS": + score = 0 + for y, aa in enumerate(seq): + if y < prot_pos_mod[idx]: + score += token_probs[j, y + 1, alphabet.get_idx(aa)].item() + else: + aa_scores = [token_probs[j, y + 1, k].item() for k in range(4, 24)] + aa_scores.append(token_probs[j, y + 1, 26].item()) + aa_scores.sort() + mid = len(aa_scores) // 2 + median = (aa_scores[mid] + aa_scores[~mid]) / 2 + score += median + seq_scores.append(score) + elif conseq[idx] == "NA": + seq_scores.append(0) + + model_scores.append(seq_scores) + + return np.array(model_scores) + +def annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref): + """Adds scores to the VCF file and writes the output.""" + header_end = 0 + for i, line in enumerate(vcf_data): + if line.startswith("#CHROM"): + vcf_data[i - 1] += '##INFO=\n' + header_end = i + break + + vcf_data_modified = vcf_data[:header_end + 1] + + for i in range(header_end + 1, len(vcf_data)): + line = vcf_data[i] + new_line = line + j = 0 + while j < len(variant_ids): + if line.split("|")[0] == variant_ids[j]: + num_scores = 0 + for l in range(j, len(variant_ids)): + if line.split("|")[0] == variant_ids[l]: + num_scores += 1 + else: + break + + new_line = new_line[:-1] + ";EsmScoreFrameshift" + "=" + new_line[-1:] + + for h in range(num_scores): + if aa_seq_ref[j + h] != "NA": + avg_score = np.mean(np_array_score_diff[:, j + h]) + new_line = new_line[:-1] + "{0}|{1:.3f}".format(transcript_ids[j + h][11:], avg_score) + new_line[-1:] + else: + new_line = new_line[:-1] + "{0}|NA".format(transcript_ids[j + h][11:]) + new_line[-1:] + + if h < num_scores - 1: + new_line = new_line[:-1] + "," + new_line[-1:] + + j += num_scores + else: + j += 1 + vcf_data_modified.append(new_line) + + with BgzfWriter(output_file, "w") as vcf_file_output: + for line in vcf_data_modified: + vcf_file_output.write(line) @click.command() @click.option( @@ -72,196 +299,23 @@ "--batch-size", "batch_size", type=int, - default=20, + default=BATCH_SIZE, help="Batch size for esm model, default is 20", ) def cli( input_file, transcript_file, model_directory, modelsToUse, output_file, batch_size ): + """Main CLI function.""" torch.hub.set_dir(model_directory) - # get information from vcf file with SNVs and write them into lists (erstmal Bsp, später automatisch aus info zeile extrahieren) - vcf_file_data = BgzfReader(input_file, "r") # TM_example.vcf.gz - vcf_data = [] - for line in vcf_file_data: - vcf_data.append(line) - - info_pos_Feature = False # TranscriptID - info_pos_ProteinPosition = False # resdidue in protein that is mutated - info_pos_AA = False # mutation from aa (amino acid) x to y - info_pos_consequence = False - # identify positions of annotations importnat for esm score - for line in vcf_data: - if line[0:7] == "##INFO=": - info = line.split("|") - for i in range(0, len(info), 1): - if info[i] == "Feature": - info_pos_Feature = i - if info[i] == "Protein_position": - info_pos_ProteinPosition = i - if info[i] == "Amino_acids": - info_pos_AA = i - if info[i] == "Consequence": - info_pos_consequence = i - break - - # extract annotations important for esm score, "NA" for non-coding variants - variant_ids = [] - transcript_id = [] - oAA = [] - nAA = [] - protPosStart = [] - protPosEnd = [] - protPos_mod = [] - cons = [] - - for variant in vcf_data: - if variant[0:1] != "#": - variant_entry = variant.split(",") - for i in range(0, len(variant_entry), 1): - variant_info = variant_entry[i].split("|") - consequences = variant_info[info_pos_consequence].split("&") - if ( - "frameshift_variant" in consequences - or "stop_gained" in consequences - ) and len(variant_info[info_pos_AA].split("/")) == 2: - variant_ids.append(variant_entry[0].split("|")[0]) - transcript_id.append("transcript:" + variant_info[info_pos_Feature]) - cons.append(variant_info[info_pos_consequence].split("&")) - oAA.append( - variant_info[info_pos_AA].split("/")[0] - ) # can also be "-" if there is an insertion - nAA.append(variant_info[info_pos_AA].split("/")[1]) - if ( - "-" in variant_info[info_pos_ProteinPosition].split("/")[0] - ): # in case of frameshifts, vep only gives X as the new aa - protPosStart.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[0] - ) - ) - protPosEnd.append( - int( - variant_info[info_pos_ProteinPosition] - .split("/")[0] - .split("-")[1] - ) - ) - else: - protPosStart.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPosEnd.append( - int(variant_info[info_pos_ProteinPosition].split("/")[0]) - ) - protPos_mod.append(False) - - # dissect file with all aa seqs to entries - transcript_data = open( - transcript_file, "r" - ) # - transcript_info_entries = transcript_data.read().split( - ">" - ) # evtl erstes > in file weglöschen - transcript_data.close() - transcript_info = [] - transcript_info_id = [] - - # transcript info contains aa seqs, becomes processed later - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info.append(transcript_info_entries[i].split(" ")) - - # transcript ids - for i in range(0, len(transcript_info_entries), 1): - if transcript_info_entries[i] != "": - transcript_info_tmp = transcript_info_entries[i].split(" ")[4] - pointAt = False - # remove version of ENST ID vor comparison with vep annotation - for p in range(0, len(transcript_info_tmp), 1): - if transcript_info_tmp[p] == ".": - pointAt = p - - transcript_info_tmp = transcript_info_tmp[:pointAt] - transcript_info_id.append(transcript_info_tmp) - if (len(transcript_info_id)) != len(transcript_info): - print("ERROR!!!!!!") - - # create list with aa_seq_refs of transcript_ids, mal gucken, ob man alle auf einmal uebergebenkann an esm model - aa_seq_ref = [] - totalNumberOfStopCodons = [] - numberOfStopCodons = [] - numberOfStopCodonsInIndel = [] - for j in range(0, len(transcript_id), 1): - transcript_found = False - for i in range( - 1, len(transcript_info_id), 1 - ): # start bei 1 statt 0 weil das inputfile mit ">" anfaengt und 0. element in aa_seq_ref einfach [] ist - if transcript_info_id[i] == transcript_id[j]: - transcript_found = True - # prepare Seq remove remainings of header - temp_seq = transcript_info[i][-1] - for k in range(0, len(temp_seq), 1): - if temp_seq[k] != "\n": - k = k + 1 - else: - k = k + 1 - temp_seq = temp_seq[k:] - break - - # prepare seq (remove /n) - forbidden_chars = "\n" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - # count stop codons in seq before site of mutation - numberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*" and k < protPosStart[j]: - numberOfStopCodons[j] = numberOfStopCodons[j] + 1 - - # count stop codons in Indel - numberOfStopCodonsInIndel.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if ( - temp_seq[k] == "*" - and k >= protPosStart[j] - and k < protPosEnd[j] - ): - numberOfStopCodonsInIndel[j] = ( - numberOfStopCodonsInIndel[j] + 1 - ) - - # count stop codons in seq - totalNumberOfStopCodons.append(0) - if "*" in temp_seq: - for k in range(0, len(temp_seq), 1): - if temp_seq[k] == "*": - totalNumberOfStopCodons[j] = totalNumberOfStopCodons[j] + 1 - - # remove additional stop codons (remove *) - forbidden_chars = "*" - for char in forbidden_chars: - temp_seq = temp_seq.replace(char, "") - - aa_seq_ref.append(temp_seq) - if transcript_found == False: - aa_seq_ref.append("NA") - numberOfStopCodons.append(9999) - totalNumberOfStopCodons.append(9999) - numberOfStopCodonsInIndel.append(9999) + vcf_data, variant_ids, transcript_ids, oAA, nAA, prot_pos_start, prot_pos_end, cons = read_and_extract_vcf_data(input_file) + aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel = process_transcript_data( + transcript_file, transcript_ids, prot_pos_start, prot_pos_end + ) conseq = [] aa_seq_alt = [] - for j in range(0, len(aa_seq_ref), 1): - # print(nAA[j]) - # print(oAA[j]) - # print("\n") - + for j in range(0, len(aa_seq_ref)): if aa_seq_ref[j] == "NA": aa_seq_alt.append("NA") conseq.append("NA") @@ -273,268 +327,17 @@ def cli( conseq.append("NA") warnings.warn( "there is a problem with the ensembl data base and vep. The ESMframesift score of this variant will be artificially set to 0. Affected transcript is " - + str(transcript_id[j]) - ) - - # prepare data array for esm model - - window = 250 - data_ref = [] - for i in range(0, len(transcript_id), 1): - if len(aa_seq_ref[i]) < window: - data_ref.append((transcript_id[i], aa_seq_ref[i])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - elif ( - (len(aa_seq_ref[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_ref[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_ref.append( - ( - transcript_id[i], - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - protPos_mod[i] = int( - len( - aa_seq_ref[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ] - ) - / 2 + + str(transcript_ids[j]) ) - elif ( - len(aa_seq_ref[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_ref.append((transcript_id[i], aa_seq_ref[i][:window])) - protPos_mod[i] = protPosStart[i] - numberOfStopCodons[i] - - else: - data_ref.append((transcript_id[i], aa_seq_ref[i][-window:])) - protPos_mod[i] = ( - protPosStart[i] - numberOfStopCodons[i] - (len(aa_seq_ref[i]) - window) - ) - - data_alt = [] - - for i in range(0, len(transcript_id), 1): - if len(aa_seq_alt[i]) < window: - data_alt.append((transcript_id[i], aa_seq_alt[i])) - - elif ( - (len(aa_seq_alt[i]) >= window) - and ( - protPosStart[i] - numberOfStopCodons[i] + 1 + window / 2 - <= len(aa_seq_alt[i]) - ) - and (protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 >= 1) - ): - data_alt.append( - ( - transcript_id[i], - aa_seq_alt[i][ - protPosStart[i] - - numberOfStopCodons[i] - - int(window / 2) : protPosStart[i] - - numberOfStopCodons[i] - + int(window / 2) - ], - ) - ) # esm model can only handle 1024 amino acids, so if the sequence is longer , just the sequece around the mutaion i - - elif ( - len(aa_seq_alt[i]) >= window - and protPosStart[i] - numberOfStopCodons[i] + 1 - window / 2 < 1 - ): - data_alt.append((transcript_id[i], aa_seq_alt[i][:window])) - - else: - data_alt.append((transcript_id[i], aa_seq_alt[i][-window:])) - - ref_alt_scores = [] - # load esm model(s) - for o in range(0, len([data_ref, data_alt]), 1): - data = [data_ref, data_alt][o] - modelScores = [] # scores of different models - if len(data) >= 1: - for k in range(0, len(modelsToUse), 1): - torch.cuda.empty_cache() - model, alphabet = pretrained.load_model_and_alphabet(modelsToUse[k]) - model.eval() # disables dropout for deterministic results - batch_converter = alphabet.get_batch_converter() - - if torch.cuda.is_available(): - model = model.cuda() - # print("transferred to GPU") - - # apply es model to sequence, tokenProbs hat probs von allen aa an jeder pos basierend auf der seq in "data" - seq_scores = [] - for t in range(0, len(data), batch_size): - if t + batch_size > len(data): - batch_data = data[t:] - else: - batch_data = data[t : t + batch_size] - - batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) - with torch.no_grad(): # setzt irgeineine flag auf false - if torch.cuda.is_available(): - token_probs = torch.log_softmax( - model(batch_tokens.cuda())["logits"], dim=-1 - ) - else: - token_probs = torch.log_softmax( - model(batch_tokens)["logits"], dim=-1 - ) - - # test and extract scores from tokenProbs - if o == 1: # alt seqences - for i in range(0, len(batch_data), 1): - # print (str(t+i)+" of "+ str(len(data))+ "alt seqs") - if conseq[i + t] == "FS": - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - if y < protPos_mod[i + t]: - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - else: - # calc mean of all possible aa at this position - aa_scores = [] - for k in range(4, 24, 1): - aa_scores.append( - token_probs[i, y + 1, k] - ) # for all aa (except selenocystein) - aa_scores.append( - token_probs[i, y + 1, 26] - ) # for selenocystein - aa_scores.sort() - mid = len(aa_scores) // 2 - median = (aa_scores[mid] + aa_scores[~mid]) / 2 - score = score + median - - seq_scores.append(float(score)) - elif conseq[i + t] == "NA": - score = 0 - seq_scores.append(float(score)) - elif o == 0: # ref sequences - for i in range(0, len(batch_data), 1): - if conseq[i + t] == "FS": - score = 0 - for y in range( - 0, len(batch_data[i][1]), 1 - ): # iterating over single AA in sequence - score = ( - score - + token_probs[ - i, - y + 1, - alphabet.get_idx(batch_data[i][1][y]), - ] - ) - seq_scores.append(float(score)) - elif conseq[i + t] == "NA": - score = 999 # sollte nacher rausgeschissen werden, kein score sollte -999 sein - seq_scores.append(float(score)) - - modelScores.append(seq_scores) - ref_alt_scores.append(modelScores) - - np_array_scores = np.array(ref_alt_scores) - np_array_score_diff = np_array_scores[1] - np_array_scores[0] - - # write scores in cvf. file - - # get information from vcf file with SNVs and write them into lists (erstmal Bsp, später automatisch aus info zeile extrahieren) - - # identify positions of annotations important for esm score - header_end = False - for i in range(0, len(vcf_data), 1): - if vcf_data[i][0:6] == "#CHROM": - vcf_data[i - 1] = ( - vcf_data[i - 1] - + "##INFO=\n' - ) - header_end = i - break - - for i in range(header_end + 1, len(vcf_data), 1): - j = 0 - while j < len(variant_ids): - if vcf_data[i].split("|")[0] == variant_ids[j]: - # count number of vep entires per variant that result in an esm score (i.e. with consequence "missense") - numberOfEsmScoresPerVariant = 0 - for l in range(j, len(variant_ids), 1): - if vcf_data[i].split("|")[0] == variant_ids[l]: - numberOfEsmScoresPerVariant = numberOfEsmScoresPerVariant + 1 - else: - break - - # annotate vcf line with esm scores - # for k in range (0, len(modelsToUse), 1): - vcf_data[i] = ( - vcf_data[i][:-1] + ";EsmScoreFrameshift" + "=" + vcf_data[i][-1:] - ) - for h in range(0, numberOfEsmScoresPerVariant, 1): - if aa_seq_ref[j + h] != "NA": - average_score = 0 - for k in range(0, len(modelsToUse), 1): - average_score = average_score + float( - np_array_score_diff[k][j + h] - ) - average_score = average_score / len(modelsToUse) - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + str(round(float(average_score), 3)) - + vcf_data[i][-1:] - ) - else: - vcf_data[i] = ( - vcf_data[i][:-1] - + str(transcript_id[j + h][11:]) - + "|" - + "NA" - + vcf_data[i][-1:] - ) - - if h != numberOfEsmScoresPerVariant - 1: - vcf_data[i] = vcf_data[i][:-1] + "," + vcf_data[i][-1:] - - j = j + numberOfEsmScoresPerVariant - else: - j = j + 1 - - vcf_file_output = BgzfWriter(output_file, "w") - for line in vcf_data: - vcf_file_output.write(line) + data_ref, prot_pos_mod_ref = prepare_data_for_esm(aa_seq_ref, transcript_ids, prot_pos_start, stop_codons_before_mutation) + data_alt, prot_pos_mod_alt = prepare_data_for_esm(aa_seq_alt, transcript_ids, prot_pos_start, stop_codons_before_mutation) - vcf_file_output.close() + ref_scores = calculate_esm_scores(data_ref, prot_pos_mod_ref, modelsToUse, conseq, batch_size) + alt_scores = calculate_esm_scores(data_alt, prot_pos_mod_alt, modelsToUse, conseq, batch_size) + np_array_score_diff = alt_scores - ref_scores + annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref) if __name__ == "__main__": - cli() + cli() \ No newline at end of file From 5e99d6becb45680232f12d45aa2d4cf45997c5c6 Mon Sep 17 00:00:00 2001 From: yangyxt Date: Thu, 19 Dec 2024 15:10:33 +0800 Subject: [PATCH 12/12] Refractor esmScore_inFrame_av.py to greatly improve the performance --- .../lib/tools/esmScore/esmScore_inFrame_av.py | 225 +++++++++--------- 1 file changed, 118 insertions(+), 107 deletions(-) diff --git a/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py b/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py index d7497d1..93bf320 100644 --- a/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py +++ b/src/scripts/lib/tools/esmScore/esmScore_inFrame_av.py @@ -1,23 +1,20 @@ """ Description: input is vep annotated vcf and a file containing all peptide sequences with Ensemble transcript IDs. -The script adds a score for frameshifts and stop gains to the info column of the vcf file. In brief, Scores for frameshift or stop gain variants were calculated -for variants annotated with the Ensembl VEP tools' frameshift -or stop gain consequence annotation. Amino acid sequences of reference alleles were obtained as described above, with the only difference that a window of 250 amino acids was used. -Log transformed probabilities were calculated and summed up for the entire reference amino acid sequence. -Calculation of a score for the alternative allele was carried out based on the entire reference amino acid sequence. Here, we summed up log probabilities of the reference -sequences amino acids up to the point where the frameshift or stop gain occurred as obtained from the Ensembl VEP tools' annotations. For every amino acid that is lost due -to the frameshift or stop gain, we used the median of log transformed probabilities calculated from all possible amino acids at each individual position in the remaining sequence -and added them to the sum corresponding to the alternative allele. -The average of logs odds ratios between the reference and alternative sequences from the five models was than used as a final score. - +The script adds a score for frameshifts and stop gains to the info column of the vcf file. In brief, scores for inframe InDel variants were calculated for variants annotated with the Ensembl VEP tools' +inframe insertion or inframe deletion consequence annotation. Variants with missense consequence annotations were only used if multiple amino acids are substituted. +Variants with stop gain, stop lost, and stop retained consequence annotations were explicitly excluded. Amino acid sequences of reference alleles were obtained as described +above, with the only difference that a window of 250 amino acids was used. As for InDel variants multiple amino acids can be affected, the amino acid sequence corresponding to the +alternative allele can differ in more than one position from the sequence of the reference allele. To account for this, we generated the entire amino acid sequence of the alternative +allele using the Ensembl VEP tools' annotations and the reference sequence. To calculate scores for inframe InDel variants, log transformed probabilities of the entire reference and +alternative sequences were added up, respectively, and substracted from each other, yielding log odds ratios. +The log odds ratios resulting from each of the five models were than averaged and used as final score. Author: Thorben Maass, Max Schubach Contact: tho.maass@uni-luebeck.de Year:2023 -Refractored by yangyxt (using numpy array instead of list appending, greatly improved the performance when dealing with huge VCF file) +Refractored by yangyxt, replacing list appending with numpy array storing. Improving computation speed by 20x when dealing with large amount of variants. """ -import warnings import numpy as np from Bio.bgzf import BgzfReader, BgzfWriter import torch @@ -62,9 +59,9 @@ def read_and_extract_vcf_data(input_file): for i in range(len(variant_entry)): variant_info = variant_entry[i].split("|") consequences = variant_info[info_pos["Consequence"]].split("&") - if ("frameshift_variant" in consequences or "stop_gained" in consequences) and len(variant_info[info_pos["Amino_acids"]].split("/")) == 2: + if ("inframe_insertion" in consequences or "inframe_deletion" in consequences or "missense_variant" in consequences) and len(variant_info[info_pos["Amino_acids"]].split("/")) == 2: variant_ids[idx] = variant_entry[0].split("|")[0] - transcript_ids[idx] = variant_info[info_pos["Feature"]] + transcript_ids[idx] = variant_info[info_pos["Feature"]].split(".")[0] cons[idx] = consequences oAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[0] nAA[idx] = variant_info[info_pos["Amino_acids"]].split("/")[1] @@ -133,49 +130,47 @@ def process_transcript_data(transcript_file, transcript_ids, prot_pos_start, pro return aa_seq_ref, total_stop_codons, stop_codons_before_mutation, stop_codons_in_indel -def prepare_data_for_esm(aa_seq, transcript_ids, prot_pos_start, stop_codons_before_mutation): - """Prepares data for the ESM model.""" +def prepare_data_for_esm(aa_seq, transcript_ids, prot_pos_start, stop_codons_before_mutation, window_size): + """Prepares data for ESM, handling windowing and edge cases.""" data = [] - prot_pos_mod = [] - for i in range(len(aa_seq)): - if aa_seq[i] == "NA": + prot_pos_mod = np.copy(prot_pos_start) + + for i, seq in enumerate(aa_seq): + if seq == "NA": continue - adjusted_pos = prot_pos_start[i] - stop_codons_before_mutation[i] - seq_len = len(aa_seq[i]) - - if seq_len < WINDOW_SIZE: - data.append((transcript_ids[i], aa_seq[i])) - prot_pos_mod.append(adjusted_pos) - elif adjusted_pos + 1 + WINDOW_SIZE // 2 <= seq_len and adjusted_pos + 1 - WINDOW_SIZE // 2 >= 1: - start = adjusted_pos - WINDOW_SIZE // 2 - end = adjusted_pos + WINDOW_SIZE // 2 - data.append((transcript_ids[i], aa_seq[i][start:end])) - prot_pos_mod.append(WINDOW_SIZE // 2) - elif seq_len >= WINDOW_SIZE and adjusted_pos + 1 - WINDOW_SIZE // 2 < 1: - data.append((transcript_ids[i], aa_seq[i][:WINDOW_SIZE])) - prot_pos_mod.append(adjusted_pos) + if len(seq) < window_size: + data.append((transcript_ids[i], seq)) + prot_pos_mod[i] -= stop_codons_before_mutation[i] else: - data.append((transcript_ids[i], aa_seq[i][-WINDOW_SIZE:])) - prot_pos_mod.append(adjusted_pos - (seq_len - WINDOW_SIZE)) + start = prot_pos_start[i] - stop_codons_before_mutation[i] + if start + 1 + window_size // 2 <= len(seq) and start + 1 - window_size // 2 >= 1: + seq_window = seq[start - window_size // 2 : start + window_size // 2] + data.append((transcript_ids[i], seq_window)) + prot_pos_mod[i] = len(seq_window) // 2 + elif start + 1 - window_size // 2 < 1: + data.append((transcript_ids[i], seq[:window_size])) + prot_pos_mod[i] = start + else: + data.append((transcript_ids[i], seq[-window_size:])) + prot_pos_mod[i] = start - (len(seq) - window_size) return data, prot_pos_mod -def calculate_esm_scores(data, prot_pos_mod, modelsToUse, conseq, batch_size=BATCH_SIZE): - """Runs the ESM model and calculates scores.""" +def calculate_esm_scores(data, prot_pos_mod, modelsToUse, conseq, batch_size): + """Calculates ESM scores for a given dataset.""" model_scores = [] for model_name in modelsToUse: - torch.cuda.empty_cache() model, alphabet = pretrained.load_model_and_alphabet(model_name) - model.eval() batch_converter = alphabet.get_batch_converter() - + model.eval() if torch.cuda.is_available(): - model = model.cuda() + model.cuda() seq_scores = [] for i in range(0, len(data), batch_size): - batch_data = data[i:i + batch_size] + batch_end = min(i + batch_size, len(data)) + batch_data = data[i:batch_end] batch_labels, batch_strs, batch_tokens = batch_converter(batch_data) with torch.no_grad(): @@ -183,72 +178,45 @@ def calculate_esm_scores(data, prot_pos_mod, modelsToUse, conseq, batch_size=BAT batch_tokens = batch_tokens.cuda() token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1).cpu() - for j, (transcript_id, seq) in enumerate(batch_data): - idx = i + j - if conseq[idx] == "FS": - score = 0 + for j, (_, seq) in enumerate(batch_data): + score = 0 + if conseq[i + j] in ["inFrame", "MultiMissense"]: for y, aa in enumerate(seq): - if y < prot_pos_mod[idx]: - score += token_probs[j, y + 1, alphabet.get_idx(aa)].item() - else: - aa_scores = [token_probs[j, y + 1, k].item() for k in range(4, 24)] - aa_scores.append(token_probs[j, y + 1, 26].item()) - aa_scores.sort() - mid = len(aa_scores) // 2 - median = (aa_scores[mid] + aa_scores[~mid]) / 2 - score += median - seq_scores.append(score) - elif conseq[idx] == "NA": - seq_scores.append(0) + score += token_probs[j, y + 1, alphabet.get_idx(aa)].item() + elif conseq[i + j] == "NA": + score = 999 + seq_scores.append(score) model_scores.append(seq_scores) return np.array(model_scores) -def annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref): - """Adds scores to the VCF file and writes the output.""" +def annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, score_diff, modelsToUse, output_file, aa_seq_ref): + """Annotates the VCF with ESM scores and writes the output.""" header_end = 0 for i, line in enumerate(vcf_data): if line.startswith("#CHROM"): - vcf_data[i - 1] += '##INFO=\n' + vcf_data[i - 1] += "##INFO=\n" header_end = i break - vcf_data_modified = vcf_data[:header_end + 1] + vcf_output = BgzfWriter(output_file, "w") + for line in vcf_data[:header_end + 1]: + vcf_output.write(line) - for i in range(header_end + 1, len(vcf_data)): - line = vcf_data[i] - new_line = line + for i, line in enumerate(vcf_data[header_end + 1:]): j = 0 while j < len(variant_ids): if line.split("|")[0] == variant_ids[j]: - num_scores = 0 - for l in range(j, len(variant_ids)): - if line.split("|")[0] == variant_ids[l]: - num_scores += 1 - else: - break - - new_line = new_line[:-1] + ";EsmScoreFrameshift" + "=" + new_line[-1:] - - for h in range(num_scores): - if aa_seq_ref[j + h] != "NA": - avg_score = np.mean(np_array_score_diff[:, j + h]) - new_line = new_line[:-1] + "{0}|{1:.3f}".format(transcript_ids[j + h][11:], avg_score) + new_line[-1:] - else: - new_line = new_line[:-1] + "{0}|NA".format(transcript_ids[j + h][11:]) + new_line[-1:] - - if h < num_scores - 1: - new_line = new_line[:-1] + "," + new_line[-1:] - - j += num_scores - else: - j += 1 - vcf_data_modified.append(new_line) - - with BgzfWriter(output_file, "w") as vcf_file_output: - for line in vcf_data_modified: - vcf_file_output.write(line) + if aa_seq_ref[j] != "NA": + score_string = [str(round(score_diff[m][j], 3)) for m in range(len(modelsToUse))] + line = line.strip() + ";EsmScoreInFrame=" + ",".join(score_string) + "\n" + else: + line = line.strip() + ";EsmScoreInFrame=NA\n" + break + j += 1 + vcf_output.write(line) + vcf_output.close() @click.command() @click.option( @@ -314,28 +282,71 @@ def cli( ) conseq = [] - aa_seq_alt = [] - for j in range(0, len(aa_seq_ref)): + aa_seq_alt = np.empty(len(aa_seq_ref), dtype=object) + for j in range(0, len(aa_seq_ref), 1): if aa_seq_ref[j] == "NA": - aa_seq_alt.append("NA") - conseq.append("NA") - elif "*" in nAA[j] or "X" in nAA[j]: # stop codon gained or complete frameshift - aa_seq_alt.append(aa_seq_ref[j]) # add alt seq without stop codon - conseq.append("FS") - else: - aa_seq_alt.append("NA") + aa_seq_alt[j] = "NA" conseq.append("NA") - warnings.warn( - "there is a problem with the ensembl data base and vep. The ESMframesift score of this variant will be artificially set to 0. Affected transcript is " - + str(transcript_ids[j]) + elif len(nAA[j]) == len(oAA[j]) and "-" not in oAA[j] and "-" not in nAA[j]: + nAA_mod = nAA[j].replace("*", "") + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + + nAA_mod + + aa_seq_ref[j][ + prot_pos_end[j] + - stop_codons_before_mutation[j] + - stop_codons_in_indel[j] : + ] + ) + conseq.append("MultiMissense") + elif len(nAA[j]) >= len(oAA[j]) and "-" in oAA[j]: + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + + nAA[j] + + aa_seq_ref[j][prot_pos_start[j] - stop_codons_before_mutation[j] - 1 :] + ) + conseq.append("inFrame") + elif len(nAA[j]) > len(oAA[j]): + nAA_mod = nAA[j].replace("*", "") + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + + nAA_mod + + aa_seq_ref[j][ + prot_pos_end[j] + - stop_codons_before_mutation[j] + - stop_codons_in_indel[j] : + ] + ) + conseq.append("inFrame") + elif len(nAA[j]) <= len(oAA[j]) and "-" in nAA[j]: + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + + aa_seq_ref[j][ + prot_pos_end[j] + - stop_codons_in_indel[j] + - stop_codons_before_mutation[j] : + ] + ) + conseq.append("inFrame") + elif len(nAA[j]) < len(oAA[j]): + nAA_mod = nAA[j].replace("*", "") + aa_seq_alt[j] = ( + aa_seq_ref[j][: prot_pos_start[j] - stop_codons_before_mutation[j] - 1] + + nAA_mod + + aa_seq_ref[j][ + prot_pos_end[j] + - stop_codons_before_mutation[j] + - stop_codons_in_indel[j] : + ] ) + conseq.append("inFrame") - data_ref, prot_pos_mod_ref = prepare_data_for_esm(aa_seq_ref, transcript_ids, prot_pos_start, stop_codons_before_mutation) - data_alt, prot_pos_mod_alt = prepare_data_for_esm(aa_seq_alt, transcript_ids, prot_pos_start, stop_codons_before_mutation) + data_ref, prot_pos_mod_ref = prepare_data_for_esm(aa_seq_ref, transcript_ids, prot_pos_start, stop_codons_before_mutation, WINDOW_SIZE) + data_alt, prot_pos_mod_alt = prepare_data_for_esm(aa_seq_alt, transcript_ids, prot_pos_start, stop_codons_before_mutation, WINDOW_SIZE) ref_scores = calculate_esm_scores(data_ref, prot_pos_mod_ref, modelsToUse, conseq, batch_size) alt_scores = calculate_esm_scores(data_alt, prot_pos_mod_alt, modelsToUse, conseq, batch_size) - np_array_score_diff = alt_scores - ref_scores + np_array_score_diff = ref_scores - alt_scores annotate_vcf_and_write_output(vcf_data, variant_ids, transcript_ids, np_array_score_diff, modelsToUse, output_file, aa_seq_ref)