#!/bin/env python

import argparse
import pandas as pd
from pybedtools import BedTool
import gzip
import numpy as np

# Command-line interface. The frequency table must be tab-separated with a
# header row; AF is read for every matched record, REF/ALT only for SNVs.
parser = argparse.ArgumentParser(description='Annotate VCF records with population allele frequencies (INFO/FREQ) taken from a frequency table.')

parser.add_argument("--input", "-i", type=str, required=True, help="Input VCF file (gzip-compressed if the name ends in 'gz')")
parser.add_argument("--output", "-o", type=str, required=True, help="Output VCF file (gzip-compressed if the name ends in 'gz')")
parser.add_argument("--freq", "-f", type=str, required=True, help='Tab-separated SV frequency table with columns ["#CHROM", "POS", "END", "SVTYPE", "SVLEN", "ID", "AF"] in any order; "REF" and "ALT" are also required when the input contains SNVs')

args = parser.parse_args()


def svtype_classify(svtype):
	"""Classify a signed REF-minus-ALT length difference.

	Despite its name, *svtype* is the length difference: a negative value means
	ALT is longer (an insertion, 'INS'), a positive value means REF is longer (a
	deletion, 'DEL'). Anything else, including zero, is labelled 'SNV'.
	"""
	if svtype < 0:
		return 'INS'
	if svtype > 0:
		return 'DEL'
	# Zero, and values that compare false both ways (e.g. NaN), fall through.
	return 'SNV'


def _open_vcf_text(path):
	"""Open *path* for reading as text, gunzipping it when the name ends in 'gz'."""
	if path.endswith('gz'):
		return gzip.open(path, 'rt')
	return open(path, 'r')


# Collect the '##' meta lines verbatim (they are written back out unchanged)
# and the column names from the '#CHROM' line. Every header line is counted so
# the records can be located by position below.
header = []
columns = None
n_header_lines = 0

with _open_vcf_text(args.input) as infile:
	for line in infile:
		if not line.startswith('#'):
			break
		n_header_lines += 1
		if line.startswith('##'):
			header.append(line)
		else:
			columns = line.rstrip().split('\t')

if columns is None:
	raise SystemExit(f'{args.input}: no #CHROM header line found')

header.append('##INFO=<ID=FREQ,Number=.,Type=Float,Description="Population Allele Frequency">\n')

# Skip the header by line count rather than comment='#': a comment character
# would also truncate any record containing '#' (e.g. in ID or INFO). Reading
# through the same opener keeps compression handling identical to the header
# pass instead of relying on pandas' own extension-based inference.
with _open_vcf_text(args.input) as infile:
	vcf = pd.read_csv(infile, sep='\t', header=None, names=columns, skiprows=n_header_lines)

# Signed length change REF - ALT: > 0 is classified DEL, < 0 INS, 0 SNV.
# NOTE(review): a missing ALT ('.') is forced to -1, i.e. an INS of length 1 -
# confirm this is intended.
ref_len = vcf['REF'].str.len()
alt_len = vcf['ALT'].str.len()
vcf['SVLEN'] = np.where(vcf['ALT'] != '.', ref_len - alt_len, -1)
vcf['SVTYPE'] = vcf['SVLEN'].map(svtype_classify)
vcf['SVLEN'] = np.abs(vcf['SVLEN'])

print('Reading in SV database')
sv_df = pd.read_csv(args.freq, sep='\t')

# Fail early with a clear message instead of a KeyError deep in the pipeline:
# these columns feed the BED intersect, and AF becomes the FREQ value.
_REQUIRED_FREQ_COLUMNS = ['#CHROM', 'POS', 'END', 'SVTYPE', 'SVLEN', 'ID', 'AF']
_missing_freq_columns = [c for c in _REQUIRED_FREQ_COLUMNS if c not in sv_df.columns]
if _missing_freq_columns:
	raise SystemExit(f'{args.freq} is missing required column(s): {", ".join(_missing_freq_columns)}')

# Reference span of each record: a deletion covers POS + SVLEN, anything else
# is given a 1 bp span.
vcf['END'] = np.where(vcf['SVTYPE'] == 'DEL', vcf['POS'] + vcf['SVLEN'], vcf['POS'] + 1)

vcf_int_check = vcf.copy()

print('Prepping input files')
# Pad both call sets by 200 bp on each side so nearby breakpoints still
# intersect; coordinates are clamped at 0 (BED coordinates are non-negative).
vcf_int_check['BED_POS'] = (vcf_int_check['POS'] - 200).clip(lower=0)
sv_df['BED_POS'] = (sv_df['POS'] - 200).clip(lower=0)
vcf_int_check['BED_END'] = (vcf_int_check['END'] + 200).clip(lower=0)
sv_df['BED_END'] = (sv_df['END'] + 200).clip(lower=0)

# Intersect matches are keyed by ID. When the VCF IDs are not unique, key by
# row index instead so every record gets its own entry.
index_change = not vcf_int_check['ID'].is_unique
if index_change:
	vcf_int_check['ID'] = vcf_int_check.index

# Columns handed to BedTool.from_dataframe for both call sets.
_BED_SOURCE_COLUMNS = ['#CHROM', 'BED_POS', 'BED_END', 'ID', 'SVTYPE', 'SVLEN']
# Names for a `intersect(wa=True, wb=True)` result: the six VCF-side fields
# followed by the six frequency-table-side fields.
_BED_FIELDS = ['#CHROM', 'POS', 'END', 'ID', 'SVTYPE', 'SVLEN']
_INTERSECT_NAMES = [f + '_VCF' for f in _BED_FIELDS] + [f + '_FREQ' for f in _BED_FIELDS]


def _matched_freq_ids(int_df):
	"""Map each ID_VCF in *int_df* to the array of ID_FREQ values it matched.

	A single groupby pass replaces a per-ID boolean scan of the whole table.
	"""
	if int_df.empty:
		return {}
	return {vcf_id: freq_ids.values for vcf_id, freq_ids in int_df.groupby('ID_VCF', sort=False)['ID_FREQ']}


merge_dict = {}
for svtype in vcf_int_check['SVTYPE'].unique():
	print(f'Intersecting {svtype}')
	if svtype != 'SNV':
		vcf_bed = BedTool.from_dataframe(vcf_int_check.loc[vcf_int_check['SVTYPE'] == svtype, _BED_SOURCE_COLUMNS])
		sv_bed = BedTool.from_dataframe(sv_df.loc[sv_df['SVTYPE'] == svtype, _BED_SOURCE_COLUMNS])
		if svtype == 'INS':
			# Any overlap of the padded windows, then keep pairs whose lengths
			# are within a factor of two of each other (bounds inclusive).
			int_df = vcf_bed.intersect(sv_bed, wa=True, wb=True).to_dataframe(names=_INTERSECT_NAMES)
			if not int_df.empty:
				len_ratio = int_df['SVLEN_VCF'] / int_df['SVLEN_FREQ']
				int_df = int_df.loc[len_ratio.between(0.5, 2)]
		else:
			# Require 50% reciprocal overlap of the padded intervals.
			int_df = vcf_bed.intersect(sv_bed, f=0.5, r=True, wa=True, wb=True).to_dataframe(names=_INTERSECT_NAMES)
	else:
		# Exact allele match. Upper-case both sides so that case differences
		# (e.g. soft-masked bases) in either file do not prevent a match.
		vcf_snv = vcf_int_check.assign(REF=vcf_int_check['REF'].str.upper(), ALT=vcf_int_check['ALT'].str.upper())
		freq_snv = sv_df.loc[sv_df['SVTYPE'] == 'SNV'].copy()
		freq_snv['REF'] = freq_snv['REF'].str.upper()
		freq_snv['ALT'] = freq_snv['ALT'].str.upper()
		int_df = vcf_snv.merge(freq_snv, on=['#CHROM', 'POS', 'REF', 'ALT'], suffixes=('_VCF', '_FREQ'))
	merge_dict.update(_matched_freq_ids(int_df))


sv_df = sv_df.set_index('ID', drop=True)


def _freq_annotation(info, freq_ids):
	"""Return the INFO string *info* with a FREQ entry appended.

	freq_ids: frequency-table IDs matched to this record; their AF values are
	looked up in the module-level ``sv_df``. Pass None when nothing matched,
	in which case FREQ=0.0000 is written.
	"""
	if freq_ids is None:
		freq = '0.0000'
	else:
		# Deduplicate the formatted AFs and sort them. A bare set's iteration
		# order depends on string-hash randomisation, so the output would differ
		# from run to run.
		freq = ','.join(sorted({'%.4g' % sv_df.at[x, 'AF'] for x in freq_ids}))
	# A lone '.' is the VCF missing value for an empty INFO column; replace it
	# rather than producing the malformed '.;FREQ=...'.
	if info == '.':
		return 'FREQ=' + freq
	return info + ';FREQ=' + freq


print('Updating INFO Field')
# merge_dict is keyed by VCF ID, or by row index when the IDs were not unique.
if index_change:
	match_keys = [int(i) for i in vcf.index]
else:
	match_keys = list(vcf['ID'])
vcf['INFO'] = [_freq_annotation(info, merge_dict.get(key)) for info, key in zip(vcf['INFO'], match_keys)]


print('Writing output')
# gzip when the name ends in 'gz', matching how the input is opened.
opener = gzip.open if args.output.endswith('gz') else open

with opener(args.output, 'wt') as outfile:
	outfile.writelines(header)
	# Write the records through the same handle, so a compressed output is one
	# continuous gzip stream rather than a second stream appended by pandas'
	# own extension-based compression. to_csv's header row is the
	# '#CHROM ...' column line.
	vcf[columns].to_csv(outfile, index=False, sep='\t')







