import numpy as np
import matplotlib.pyplot as plt
from Bio import SeqIO
import math
import argparse
from itertools import zip_longest, islice

# Masking alphabet -> colour value.  The 54 symbols are the ten digits plus
# every upper- and lower-case letter except A/C/G/T (a/c/g/t).  bykmer()
# overwrites repeated k-mers with these symbols and get_coordinate() turns a
# symbol back into its number.  Values run from 54 for '0' down to 1 for 'z';
# 0 is never assigned.
# NOTE(review): the alphabet contains N and other IUPAC ambiguity letters, so
# a sequence that already contains them may be read as masked -- confirm inputs.
CMAP = {
    symbol: 54 - position
    for position, symbol in enumerate(
        '0123456789BDEFHIJKLMNOPQRSUVWXYZbdefhijklmnopqrsuvwxyz')
}

def parse():
    """
    Build the command-line interface and parse sys.argv.

    returns: argparse.Namespace with the attributes mode, fasta_hor,
             fasta_vert, k, do_log_transformation, cmap, alpha and ms
    """
    parser = argparse.ArgumentParser(
        description="Generate colored dotplots, one per pair of sequences in the fasta files.")
    parser.add_argument(
        'mode', choices=['nocol', 'byfreq', 'bykmer'],
        help=("nocol: generate gray scale dotplots. "
              "byfreq: color dotplots by motif frequency. "
              "Different motifs with the same frequency are assigned with identical color. "
              "bykmer: color dotplots by motif. "
              "Best used for small regions and set the k-mer size to the motif size."))
    # fasta_hor supplies the x coordinates and the x-axis label, fasta_vert
    # the y ones, so the help texts must say horizontal / vertical in that order.
    parser.add_argument('fasta_hor', type=str, help='Sequences on the horizontal line')
    parser.add_argument('fasta_vert', type=str, help='Sequences on the vertical line')
    parser.add_argument('-k', dest='k', type=int, default=10, help='k-mer size')
    parser.add_argument(
        '-t', dest='do_log_transformation', default=False, action='store_true',
        help=("Only for plotting in byfreq mode. Perform log transformation on the motif "
              "frequency. Some times can better separate different motifs"))
    parser.add_argument('--cmap', dest='cmap', type=str, default='jet', help='Matplotlib cmap string')
    # Transparency is a fraction between 0 and 1, so it has to be parsed as a
    # float; an int type rejects values such as 0.5.
    parser.add_argument('--alpha', dest='alpha', type=float, default=1.0, help='Transparency of each dot')
    parser.add_argument('--markersize', dest='ms', type=int, default=10, help='Dot size')
    return parser.parse_args()

# Dotplot functions

def to_int_keys_best(l):
    """
    Replace every key by its rank among the distinct keys.

    l: collection of hashable, mutually comparable keys (it is iterated twice)
    returns: a list as long as l in which equal keys share one integer and
             the integers 0, 1, 2, ... follow the sorted order of the keys
    """
    rank_of = {key: rank for rank, key in enumerate(sorted(set(l)))}
    return [rank_of[key] for key in l]

def suffix_array_percise(s, k):
    """
    Label every position of s by the k-mer that starts there.

    s: sequence of comparable symbols (a string in this script)
    k: k-mer size, a positive integer
    returns: a list of integers, one per position of s; two positions carry
             the same integer when the k-mers starting at them are identical.
             Windows that run past the end of s are padded with -1.
    raises: ValueError if k < 1
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    line = to_int_keys_best(s)
    # Prefix doubling: pairing each label with the label `span` positions
    # further on yields labels for windows of length 2 * span.
    span = 1
    while span <= k >> 1:
        line = to_int_keys_best(
            list(zip_longest(line, islice(line, span, None),
                             fillvalue=-1)))
        span <<= 1
    # span is now exactly the largest power of two not exceeding k, so no
    # floating-point logarithm is needed; the rest is added one symbol at a time.
    return sort_matrix_helper(line, k - span)

def sort_matrix_helper(line, k):
    """
    Apply sort_matrix to line up to k times.

    line: list of integer labels
    k: maximum number of sort_matrix passes
    returns: the labels after the passes; stops early (returning the current
             list) as soon as every label is distinct
    """
    passes_left = k
    while passes_left >= 1 and len(set(line)) != len(line):
        line = sort_matrix(line)
        passes_left -= 1
    return line

def sort_matrix(line):
    """
    Combine every label with the label that follows it and re-rank.

    line: list of integer labels
    returns: a list of the same length, re-ranked with to_int_keys_best on
             the pair (label, next label); the last position is paired
             with -1
    """
    base = len(line) + 1
    combined = []
    for here, following in zip_longest(line, islice(line, 1, None),
                                       fillvalue=-1):
        # fold the pair into a single integer key
        combined.append(here * base + following + 1)
    return to_int_keys_best(combined)

def sort_by_index(arr):
    """
    Group the positions of arr by the label stored at each position.

    arr: sequence of non-negative integer labels
    returns: a list of max(arr) + 1 lists; entry n holds, in increasing
             order, the positions i with arr[i] == n.  An empty arr gives
             an empty list.
    """
    if len(arr) == 0:
        # max() of an empty sequence raises, and there is nothing to group
        return []
    buckets = [[] for _ in range(max(arr) + 1)]
    for position, label in enumerate(arr):
        buckets[label].append(position)
    return buckets



# Coloring functions
def nocol(s1, s2, k):
    """
    Dot coordinates for an uncoloured dotplot.

    s1: sequence plotted along x
    s2: sequence plotted along y
    k: k-mer size
    returns: (x, y, None) where each (x[n], y[n]) is a pair of start
             positions, one in s1 and one in s2, whose k-mers received the
             same label from suffix_array_percise
    """
    boundary = len(s1)
    groups = sort_by_index(suffix_array_percise(s1 + s2, k))
    x, y = [], []
    for positions in groups:
        # split the occurrences by the sequence they fall in
        in_first = [p for p in positions if p < boundary]
        in_second = [p - boundary for p in positions if p >= boundary]
        if not in_first or not in_second:
            continue
        for col in in_first:
            for row in in_second:
                x.append(col)
                y.append(row)
    return x, y, None

def bykmer(s1, s2, k, cmap):
    """
    Dot coordinates coloured by k-mer.

    Every k-mer that occurs at least three times in s1 + s2 is given one
    masking symbol from cmap, most frequent k-mer first, until the symbols
    run out.  Its occurrences are overwritten with that symbol and the dots
    are then computed on the masked string.

    s1: sequence plotted along x
    s2: sequence plotted along y
    k: k-mer size
    cmap: dict mapping single-character masking symbols to colour values;
          symbols are handed out in the dict's iteration order
    returns: ((x, y, c), labels).  (x, y, c) comes from get_coordinate.
             labels always has len(cmap) entries: the coloured k-mers in
             reverse order of assignment, preceded by empty strings for
             symbols that were not needed.
    """
    sc = s1 + s2
    kmers = sort_by_index(suffix_array_percise(sc, k))
    masked_arr = list(sc)

    # kmer must repeated at least once in either seqs to be colored
    repeated = [group
                for group in sorted(kmers, key=len, reverse=True)
                if len(group) >= 3]

    kmers_rank = []
    # A single pass; zip() stops when either the symbols or the repeated
    # k-mers are exhausted.  Restarting the pass (as an endless outer loop
    # would) never terminates when nothing repeats and otherwise only
    # re-labels symbols that colour no dot.
    for color, group in zip(cmap, repeated):
        word = sc[group[0]:group[0] + k]
        kmers_rank.append(word)
        for i in range(len(masked_arr)):
            if ''.join(masked_arr[i:i + k]) == word:
                masked_arr[i:i + k] = [color] * k

    # Keep one label per symbol so that labels stay aligned with the colour
    # values even when fewer k-mers than symbols were coloured.
    kmers_rank.extend([''] * (len(cmap) - len(kmers_rank)))

    masked_str = ''.join(masked_arr)
    masked_kmers = sort_by_index(suffix_array_percise(masked_str, k))
    return get_coordinate(masked_kmers, masked_str, len(s1), cmap), list(reversed(kmers_rank))

def get_coordinate(kmers, string, l1, cmap):
    """
    Turn groups of positions into coloured dot coordinates.

    kmers: list of groups, each a list of positions in string
    string: the concatenated (possibly masked) sequence
    l1: length of the first sequence; positions below l1 become x values,
        the others become y values after subtracting l1
    cmap: dict mapping masking symbols to colour values
    returns: (x, y, c), three parallel lists with one entry per pair of an
             x position and a y position taken from the same group; c is
             the cmap value of the symbol at the group's first position,
             or 1 when that symbol is not a key of cmap
    """
    xs, ys, colors = [], [], []
    for positions in kmers:
        on_x = [p for p in positions if p < l1]
        on_y = [p - l1 for p in positions if p >= l1]
        # a group confined to one sequence produces no dot
        if not (on_x and on_y):
            continue
        shade = cmap.get(string[positions[0]], 1)
        for col in on_x:
            for row in on_y:
                xs.append(col)
                ys.append(row)
                colors.append(shade)
    return xs, ys, colors

def byfreq(s1, s2, k, do_log_trans = False):
    """
    Dot coordinates coloured by k-mer frequency.

    s1: sequence plotted along x
    s2: sequence plotted along y
    k: k-mer size
    do_log_trans: when true, return log10 of the frequency as the colour
    returns: (x, y, c), three parallel lists; (x[n], y[n]) is a pair of
             start positions (one in s1, one in s2) sharing a k-mer and
             c[n] is the number of occurrences of that k-mer in s1 + s2.
             Dots of the most frequent k-mers come first.
    """
    sc = s1 + s2
    kmers = sort_by_index(suffix_array_percise(sc, k))
    l1 = len(s1)
    x, y, c = [], [], []
    for kmer in sorted(kmers, key=len, reverse=True):
        kmer_x = [j for j in kmer if j < l1]
        kmer_y = [j - l1 for j in kmer if j >= l1]
        if not kmer_x or not kmer_y:
            continue
        freq = len(kmer)
        for i in kmer_x:
            for j in kmer_y:
                x.append(i)
                y.append(j)
                c.append(freq)
    if do_log_trans:
        # log10() is exact for powers of ten, whereas log(n, 10) is not
        # (log(1000, 10) is 2.9999999999999996); plot() undoes the
        # transform with 10 ** max(c).
        c = [math.log10(freq) for freq in c]
    return x, y, c

# Plotting functions
def plot(
    x, y, c, rec_hor, rec_vert, ticks = None, 
    ticklabel = None, do_log_trans = False, 
    alpha = 1, marker = '.', markersize = 10,
    cmap = "jet"):
    """
    Draw one dotplot, with a colorbar when colour values are given.

    x, y: dot coordinates
    c: one colour value per dot, or None for an uncoloured plot
    rec_hor, rec_vert: records providing .seq (axis length) and .id (axis
                       label) for the horizontal and vertical axis
    ticks, ticklabel: colorbar tick positions and their labels; give both
                      or neither.  When given, do_log_trans is ignored.
    do_log_trans: c holds log10 values; the colorbar is then labelled with
                  the untransformed values on a logarithmic scale
    alpha, marker, markersize: forwarded to scatter
    cmap: matplotlib colormap name
    returns: the matplotlib figure
    raises: ValueError if only one of ticks / ticklabel is set
    """
    from matplotlib.colors import LogNorm

    cmap = plt.get_cmap(cmap)

    # Don't perform log_trans when ticks and ticklabel are set
    has_ticks = bool(ticks) and bool(ticklabel)
    if has_ticks:
        do_log_trans = False
    elif ticks or ticklabel:
        raise ValueError("Must set both ticks and ticklabel")

    # plot dotplot
    f, ax = plt.subplots(figsize=(10, 10))
    sct = ax.scatter(
        x, y, c=c, s=markersize,
        marker=marker,
        edgecolors='none',   # documented way to draw markers without an outline
        alpha=alpha, rasterized=False,
        # a colormap is meaningless (and warned about) without colour data
        cmap=cmap if c is not None else None)
    ax.set_xlim([0, len(rec_hor.seq)])
    ax.set_ylim([0, len(rec_vert.seq)])
    ax.set_xlabel(rec_hor.id)
    ax.set_ylabel(rec_vert.id)

    if c is None:
        # nothing is colour-mapped, so there is nothing for a colorbar to show
        return f

    # Plot colorbar separately, in hand-placed axes to the right of the plot
    pos = ax.get_position()
    ax_cb = f.add_axes([pos.x0 + 0.8, pos.y0, 0.05, pos.y1 - 0.125])
    if do_log_trans and len(c) > 0:
        # The dots are coloured linearly in log10(value).  A logarithmic norm
        # over the untransformed range, with the same colormap, gives a
        # colorbar whose colours match the dots.
        sm = plt.cm.ScalarMappable(
            cmap=cmap,
            norm=LogNorm(vmin=10 ** min(c), vmax=10 ** max(c)))
        sm.set_array([])
    else:
        sm = sct
    cb = f.colorbar(sm, cax=ax_cb)

    # set ticks
    if has_ticks:
        cb.set_ticks(list(ticks))
        cb.set_ticklabels(list(ticklabel))

    return f

def main():
    """
    Command-line entry point.

    Reads the two fasta files, pairs their records in file order and writes
    one dotplot per pair into the current directory:
    <mode>.<pair index>.png for every mode, plus <mode>.<pair index>.pdf for
    the nocol and byfreq modes.

    raises: ValueError if -t is combined with a mode other than byfreq
    """
    # imported here because the module itself never imports warnings
    import warnings

    args = parse()

    # Check user input error
    if args.mode != "byfreq" and args.do_log_transformation:
        raise ValueError("-t option is exclusive to byfreq mode.")

    # Parse fasta files
    rec_list_hor = list(SeqIO.parse(args.fasta_hor, 'fasta'))
    rec_list_vert = list(SeqIO.parse(args.fasta_vert, 'fasta'))

    # Warn if two fasta files have different number of sequences
    if len(rec_list_vert) != len(rec_list_hor):
        warnings.warn("Two fasta files have different number of sequences, additional sequences will be ignored")

    # zip() stops at the shorter list, which is what drops the extra records
    for count, (rec_hor, rec_vert) in enumerate(zip(rec_list_hor, rec_list_vert)):
        # The k-mer code concatenates, slices and compares the sequences and
        # uses slices as tick labels, so hand it plain strings.
        seq_hor = str(rec_hor.seq)
        seq_vert = str(rec_vert.seq)

        if args.mode == "nocol":
            x, y, c = nocol(seq_hor, seq_vert, args.k)
            f = plot(x, y, c, rec_hor, rec_vert, alpha = args.alpha, markersize = args.ms, cmap = args.cmap)
            extensions = ('png', 'pdf')
        elif args.mode == "byfreq":
            x, y, c = byfreq(seq_hor, seq_vert, args.k, do_log_trans = args.do_log_transformation)
            f = plot(x, y, c, rec_hor, rec_vert, do_log_trans = args.do_log_transformation, alpha = args.alpha, markersize = args.ms, cmap = args.cmap)
            extensions = ('png', 'pdf')
        elif args.mode == "bykmer":
            (x, y, c), tick_label = bykmer(seq_hor, seq_vert, args.k, cmap = CMAP)
            # one tick per label, starting at colour value 1, so the number
            # of tick positions always equals the number of labels
            ticks = range(1, len(tick_label) + 1)
            f = plot(x, y, c, rec_hor, rec_vert, ticks=ticks, ticklabel=tick_label, alpha = args.alpha, markersize = args.ms, cmap = args.cmap)
            extensions = ('png',)
        else:
            # unreachable: argparse restricts mode to the three choices above
            continue

        for ext in extensions:
            extra = {'dpi': 300} if ext == 'png' else {}
            f.savefig('{}.{}.{}'.format(args.mode, count, ext), facecolor='w', bbox_inches='tight', pad_inches = 0, **extra)
        # release the figure; pyplot otherwise keeps every figure alive
        plt.close(f)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()

