import math
import os
import sys

from snakemake.utils import min_version

# Refuse to run with a Snakemake release older than 6.0.
# NOTE(review): several rules below call workflow.source_path(); confirm that
# helper exists in 6.0, otherwise this minimum should be raised.
min_version("6.0")

# ANSI escape sequences used to colour the citation banner.
bold = "\033[1m"
green = "\033[92m"
cyan = "\033[36m"
end = "\033[0m"
# Citation reminder shown on every invocation (tool name typo fixed).
msg = f"""{green}{bold}Thanks for using StainedGlass and please remember to cite the tool!{end}{end}
    {cyan}Vollger, Mitchell R., Peter Kerpedjiev, Adam M. Phillippy, and Evan E. Eichler. 2021.
        “StainedGlass: Interactive Visualization of Massive Tandem Repeat Structures with Identity Heatmaps.” 
        BioRxiv. https://doi.org/10.1101/2021.08.19.457003. {end}
"""
# The banner goes to stderr so that anything piped from stdout stays clean.
sys.stderr.write(msg)

# Resolved directory of the path that srcdir() reports for "Snakefile".
SDIR = os.path.realpath(
    os.path.dirname(srcdir("Snakefile")),
)

# Prepended to every shell command: stop on the first error (-e) and treat a
# failure anywhere in a pipe as a failure of the whole pipe (-o pipefail).
shell.prefix("set -eo pipefail; ")


# Default configuration file for the workflow.
configfile: "config/config.yaml"


#
# global options
#
# Each option is popped from the config with a fallback default.
N = config.pop("nbatch", 1)  # requested number of query batches
NUM_DUPS = config.pop("num_dups", 100)  # passed to `bwa samse -n`
SM = config.pop("sample", "sample")  # sample label used in every file name
W = config.pop("window", 2000)  # window width given to `bedtools makewindows -w`
CW = config.pop("cooler_window", int(W))  # bin size for the density cooler
ALN_T = config.pop("alnthreads", 4)  # threads for the alignment rules
F = config.pop("mm_f", 10000)  # minimap2 -f
# "mm_s" is popped exactly once.  Popping it a second time (as was done for the
# slide branch below) always yields the fallback, which silently discarded a
# user-supplied value.
_MM_S = config.pop("mm_s", None)
S = int(W / 5) if _MM_S is None else _MM_S  # minimap2 -s
TEMP_DIR = config.pop("tempdir", "temp")
SAMTOOLS_MEM = config.pop("samtools_mem", 1)  # mem resource of the final sort
SLIDE = config.pop("slide", 0)  # step for sliding windows; 0 disables sliding

# Validate with an explicit exit rather than `assert`, which is skipped when
# Python runs with -O.
if W <= SLIDE:
    sys.exit(f"window ({W}) must be larger than slide ({SLIDE}).")

# short vs long read mapping
MAP_PARAMS = "-ax ava-ont "
if W < 500:
    MAP_PARAMS = " -a --no-pairing -k21 --sr -A2 -B8 -O12,32 -E2,1 -r100 -g100 -2K50m --heap-sort=yes -X -w 5 -e0 "
elif SLIDE > 0:
    # Slide mode uses a different default for -s, but an explicit mm_s wins.
    if _MM_S is None:
        S = 100
    MAP_PARAMS = "-ax map-ont --secondary=no"

# When a custom temp directory is configured, expose it as ./temp because every
# rule writes its intermediates below "temp/".
if TEMP_DIR != "temp":
    # Make sure the link target exists so the link is never left dangling.
    os.makedirs(TEMP_DIR, exist_ok=True)
    # lexists() also detects a symlink whose target is missing, which
    # exists() would report as absent and then fail on when linking.
    if not os.path.lexists("temp"):
        # os.symlink avoids shell quoting problems with unusual paths.
        os.symlink(TEMP_DIR, "temp")
    elif os.path.realpath("temp") == os.path.realpath(TEMP_DIR):
        # Informational only; written to stderr to keep stdout clean.
        sys.stderr.write("The temp dir has already been linked.\n")
    else:
        sys.exit("temp/ already in use, please move it before running.")

#
# required arguments
#
# "fasta" has no default; a missing key raises KeyError here.
FASTA = os.path.abspath(config["fasta"])
FAI = FASTA + ".fai"

# Sequence names and lengths are read from the index, so it must exist.
if not os.path.exists(FAI):
    sys.exit("Input fasta must be indexed, try:\nsamtools faidx " + FASTA)

#
# setup split and gather
#
# Read the index once inside a context manager so the handle is closed, and
# skip blank lines instead of failing on them.
with open(FAI) as fai_handle:
    fai_records = [line.split() for line in fai_handle if line.strip()]
# Column 1 of the index is the sequence name, column 2 its length.
names = [record[0] for record in fai_records]
# Number of W-sized windows needed to cover each sequence.
lengths = [math.ceil(int(record[1]) / W) for record in fai_records]
# One reference per sequence in slide mode, otherwise a single shared one.
REF_IDS = [f"ref_{i}" for i in range(len(names))] if SLIDE > 0 else ["ref_0"]
N_WINDOWS = int(sum(lengths))
# Never request more batches than there are windows.
N = min(N, N_WINDOWS)
IDS = list(range(N))
sys.stderr.write(f"[INFO] The sequence will be split into {N} batches.\n")


# Rules that are always executed on the submitting host instead of being
# dispatched as cluster/cloud jobs.  Each rule is listed exactly once
# (identity and pair_end_bed used to appear twice).
localrules:
    all,
    bwa_index,
    bwa_aln,
    merge_aln,
    identity,
    pair_end_bed,
    window_fa,
    split_windows,
    merge_list,
    sort_aln,
    make_R_figures,
    make_figures,
    cooler,
    cooler_identity,
    cooler_strand,
    cooler_zoomify_i,
    cooler_zoomify_s,


# Restrict each wildcard to the values configured for this run (presumably so
# the dot-separated file names cannot be split in more than one way).
wildcard_constraints:
    SM=SM,
    W=W,
    F=F,
    # Alternation over all batch numbers / reference labels.
    ID="|".join(str(batch_id) for batch_id in IDS),
    REF_ID="|".join(REF_IDS),


# Default target: the sorted BAM plus the BED and full table for the configured
# sample (SM), window size (W) and mm_f value (F).
rule all:
    input:
        sort=expand("results/{SM}.{W}.{F}.sorted.bam", SM=SM, W=W, F=F),
        beds=expand("results/{SM}.{W}.{F}.bed.gz", SM=SM, W=W, F=F),
        fulls=expand("results/{SM}.{W}.{F}.full.tbl.gz", SM=SM, W=W, F=F),


# Target rule: request the strand and identity coolers, both as single
# resolution (.cool) and multi-resolution (.mcool) files.
rule cooler:
    input:
        cool_s=expand("results/{SM}.{W}.{F}.strand.cool", SM=SM, W=W, F=F),
        cool_i=expand("results/{SM}.{W}.{F}.identity.cool", SM=SM, W=W, F=F),
        cool_sm=expand("results/{SM}.{W}.{F}.strand.mcool", SM=SM, W=W, F=F),
        cool_im=expand("results/{SM}.{W}.{F}.identity.mcool", SM=SM, W=W, F=F),


# Target rule: request the multi-resolution density cooler, binned at the
# cooler window size (CW).
rule cooler_density:
    input:
        cool_d=expand("results/{SM}.{W}.{CW}.density.mcool", SM=SM, W=W, CW=CW),


# Tile every sequence listed in the FASTA index into windows of width W with
# `bedtools makewindows`; in slide mode the step size is passed with -s.
# stderr is redirected so the declared log file is actually populated.
rule make_windows:
    input:
        fai=FAI,
    output:
        bed=temp("temp/{SM}.{W}.bed"),
    log:
        "logs/make_windows.{SM}.{W}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=1,
    params:
        # Empty unless sliding windows were requested.
        slide=f"-s {SLIDE}" if SLIDE > 0 else "",
    shell:
        """
        bedtools makewindows -g {input.fai} -w {wildcards.W} {params.slide} \
            > {output.bed} 2> {log}
        """


# Hand the window BED to scripts/batch_bed_files.py together with one output
# path per batch ID; the script is expected to distribute the windows over
# those files.  stderr is redirected so the declared log file is populated.
rule split_windows:
    input:
        bed=rules.make_windows.output.bed,
    output:
        # One BED per batch; SM and W stay wildcards (allow_missing).
        bed=temp(expand("temp/{SM}.{W}.{ID}.bed", ID=IDS, allow_missing=True)),
    log:
        "logs/split_windows.{SM}.{W}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=1,
    params:
        script=workflow.source_path("scripts/batch_bed_files.py"),
    shell:
        """
        python {params.script} {input.bed} --outputs {output.bed} 2> {log}
        """


# Extract the sequence of every window from the input FASTA with
# `bedtools getfasta`.  stderr is redirected so the declared log file is
# populated.
rule window_fa:
    input:
        ref=FASTA,
        bed=rules.make_windows.output.bed,
    output:
        fasta="results/{SM}.{W}.fasta",
    log:
        "logs/window_fa.{SM}.{W}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=4,
    shell:
        """
        bedtools getfasta -fi {input.ref} -bed {input.bed} \
            > {output.fasta} 2> {log}
        """


# Extract one reference sequence from the input FASTA with `samtools faidx`.
# REF_ID has the form "ref_<i>", where <i> indexes the list of sequence names
# read from the FASTA index.  stderr is redirected so the declared log file is
# populated.
rule split_fa:
    input:
        ref=FASTA,
    output:
        fasta=temp("temp/{SM}.{W}.{REF_ID}.split.fasta"),
    log:
        "logs/split_fa.{SM}.{W}.{REF_ID}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=4,
    params:
        # Slice the literal prefix off; str.strip("ref_") removes a *set of
        # characters* from both ends and only worked by coincidence.
        name=lambda wc: names[int(wc.REF_ID[len("ref_"):])],
    shell:
        """
        samtools faidx {input.ref} {params.name} > {output.fasta} 2> {log}
        """


# Build the minimap2 index (.mmi) that the query batches are aligned against.
rule aln_prep:
    input:
        # Slide mode indexes a single reference sequence; otherwise the
        # windowed FASTA is indexed.
        ref=rules.split_fa.output.fasta if SLIDE > 0 else rules.window_fa.output.fasta,
    output:
        split_ref_index=temp("temp/{SM}.{W}.{F}.{REF_ID}.fasta.mmi"),
    log:
        "logs/aln_prep.{SM}.{W}.{F}.{REF_ID}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=4,
    params:
        MAP_PARAMS=MAP_PARAMS,
        S=S,
    shell:
        """
        minimap2 \
            -f {wildcards.F} -s {params.S} \
            {params.MAP_PARAMS} \
             -d {output.split_ref_index} \
            {input.ref} \
        2> {log}
        """


# Extract the sequences of one batch of windows from the input FASTA with
# `bedtools getfasta`.  stderr is redirected so the declared log file is
# populated.
rule query_prep:
    input:
        ref=FASTA,
        bed="temp/{SM}.{W}.{ID}.bed",
    output:
        query_fasta=temp("temp/{SM}.{W}.{ID}.query.fasta"),
    log:
        "logs/query_prep.{SM}.{W}.{ID}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=4,
    shell:
        """
        bedtools getfasta -fi {input.ref} -bed {input.bed} \
            > {output.query_fasta} 2> {log}
        """


# Align one query batch against one reference index with minimap2 and write the
# result as a sorted BAM.  stderr of the whole pipe goes to the log.
rule aln:
    input:
        split_ref=rules.aln_prep.output.split_ref_index,
        query=rules.query_prep.output.query_fasta,
    output:
        aln=temp("temp/{SM}.{W}.{F}.{ID}.{REF_ID}.bam"),
    log:
        "logs/aln.{SM}.{W}.{F}.{ID}.{REF_ID}.log",
    conda:
        "envs/env.yaml"
    threads: ALN_T
    resources:
        # Also used as the -m value of samtools sort below.
        mem=4,
    params:
        MAP_PARAMS=MAP_PARAMS,
        S=S,
    shell:
        """
        ( minimap2 \
            -t {threads} \
            -f {wildcards.F} -s {params.S} \
            {params.MAP_PARAMS} \
            --dual=yes --eqx \
            {input.split_ref} {input.query} \
                | samtools sort -m {resources.mem}G \
                    -o {output.aln} \
        ) 2> {log}
        """


# Write the paths of all per-batch, per-reference BAMs to a text file, one path
# per line, for use with `samtools merge -b`.
rule merge_list:
    input:
        aln=expand(
            "temp/{SM}.{W}.{F}.{ID}.{REF_ID}.bam",
            ID=IDS,
            REF_ID=REF_IDS,
            allow_missing=True,
        ),
    output:
        alns=temp("temp/{SM}.{W}.{F}.list"),
    threads: 1
    log:
        "logs/merge_list.{SM}.{W}.{F}.log",
    resources:
        mem=8,
    run:
        # Context manager guarantees the list is flushed and closed before
        # downstream rules read it.
        with open(output.alns, "w") as list_handle:
            list_handle.write("\n".join(input.aln) + "\n")


# Merge all per-batch BAMs into a single BAM with `samtools merge`, reading the
# input paths from the list file.
rule merge_aln:
    input:
        alns=rules.merge_list.output.alns,
        # The BAMs themselves are declared as inputs too, so this rule depends
        # on the files and not only on the list naming them.
        aln=expand(
            "temp/{SM}.{W}.{F}.{ID}.{REF_ID}.bam",
            REF_ID=REF_IDS,
            ID=IDS,
            allow_missing=True,
        ),
    output:
        aln=temp("temp/{SM}.{W}.{F}.bam"),
    log:
        "logs/merge_aln.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 4
    resources:
        mem=4,
    shell:
        # -@ hands the reserved threads to samtools (they were previously
        # requested but unused); stderr goes to the declared log.
        """
        samtools merge -@ {threads} -b {input.alns} {output.aln} 2> {log}
        """


# Sort the merged BAM with `samtools sort` and write an index alongside it
# (--write-index).  stderr is redirected so the declared log file is populated.
rule sort_aln:
    input:
        aln=rules.merge_aln.output.aln,
    output:
        aln="results/{SM}.{W}.{F}.sorted.bam",
    log:
        "logs/sort_aln.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 8
    resources:
        # Passed to samtools as -m, i.e. memory per sorting thread.
        mem=SAMTOOLS_MEM,
    shell:
        """
        samtools sort -m {resources.mem}G -@ {threads} --write-index \
            -o {output.aln} {input.aln} \
            2> {log}
        """


# Run scripts/samIdentity.py on the sorted BAM and compress its output with
# pigz.  `threads` is declared once (it used to be repeated) and stderr of the
# whole pipe is redirected so the declared log file is populated.
rule identity:
    input:
        aln=rules.sort_aln.output.aln,
    output:
        tbl=temp("temp/{SM}.{W}.{F}.tbl.gz"),
    log:
        "logs/identity.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 8
    resources:
        mem=8,
    params:
        script=workflow.source_path("scripts/samIdentity.py"),
        # Same value that is given to minimap2 -s.
        S=S,
    shell:
        """
        ( python {params.script} --threads {threads} \
            --matches  {params.S} --header \
            {input.aln} \
            | pigz -p {threads} > {output.tbl} \
        ) 2> {log}
        """


# Run scripts/refmt.py on the identity table to produce the final BED and the
# "full" table.  stderr is redirected so the declared log file is populated.
rule pair_end_bed:
    input:
        tbl=rules.identity.output.tbl,
        fai=FAI,
    output:
        bed="results/{SM}.{W}.{F}.bed.gz",
        full="results/{SM}.{W}.{F}.full.tbl.gz",
    log:
        "logs/pair_end_bed.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=64,
    params:
        script=workflow.source_path("scripts/refmt.py"),
        # Extra flag handed to the script only in slide mode.
        one="--one" if SLIDE > 0 else "",
    shell:
        """
        python {params.script} \
            --window {wildcards.W} --fai {input.fai} \
            --full {output.full} \
            {params.one} \
            {input.tbl} {output.bed} \
            2> {log}
        """


#
# for figures ~ 3Mbp or less
#
# Run scripts/aln_plot.R on the final BED.  The script receives only a prefix,
# so it is expected to derive the declared PDF/PNG paths from it.  stderr is
# redirected so the declared log file is populated.
rule make_R_figures:
    input:
        bed=rules.pair_end_bed.output.bed,
    output:
        pdf="results/{SM}.{W}.{F}_figures/pdfs/{SM}.{W}.{F}.tri.TRUE__onecolorscale.FALSE__all.pdf",
        # All four TRUE/FALSE combinations of the two plot switches.
        png=report(
            expand(
                "results/{SM}.{W}.{F}_figures/pngs/{SM}.{W}.{F}.tri.{first}__onecolorscale.{second}__all.png",
                allow_missing=True,
                first=["TRUE", "FALSE"],
                second=["TRUE", "FALSE"],
            ),
            category="figures",
        ),
        facet=report(
            "results/{SM}.{W}.{F}_figures/pngs/{SM}.{W}.{F}.facet.all.png",
            category="figures",
        ),
    log:
        "logs/make_R_figures.{SM}.{W}.{F}.log",
    conda:
        "envs/R.yaml"
    threads: 8
    resources:
        mem=64,
    params:
        script=workflow.source_path("scripts/aln_plot.R"),
    shell:
        """
        Rscript {params.script} \
            --bed {input.bed} \
            --threads {threads} \
            --prefix {wildcards.SM}.{wildcards.W}.{wildcards.F} \
            2> {log}
        """


# Target rule: request every output of make_R_figures for the configured
# sample, window size and mm_f value.
rule make_figures:
    input:
        expand(rules.make_R_figures.output, SM=SM, W=W, F=F),


#
# for large scale visualization
#
# Load the final BED into a cooler whose value is the strand: the header line
# is dropped and a trailing "+" / "-" is rewritten to 100 / 50 so that it can
# be read as the numeric field in column 8 (presumably the strand column).
# stderr of the whole pipe is redirected so the declared log file is populated.
rule cooler_strand:
    input:
        bed=rules.pair_end_bed.output.bed,
        fai=FAI,
    output:
        cool="results/{SM}.{W}.{F}.strand.cool",
    log:
        "logs/cooler_strand.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=64,
    shell:
        """
        ( gunzip -c {input.bed} | tail -n +2 \
            | sed  's/+$/100/g' \
            | sed  's/-$/50/g' \
            | cooler cload pairs \
                -c1 1 -p1 2 -c2 4 -p2 5 \
                --field count=8:agg=mean,dtype=float \
                --chunksize 50000000000 \
                {input.fai}:{wildcards.W} \
                --zero-based \
                - {output.cool} \
        ) 2> {log}
        """


# Load the final BED (header line dropped) into a cooler whose value is taken
# from column 7, averaged per bin.  stderr of the whole pipe is redirected so
# the declared log file is populated.
rule cooler_identity:
    input:
        bed=rules.pair_end_bed.output.bed,
        fai=FAI,
    output:
        cool="results/{SM}.{W}.{F}.identity.cool",
    log:
        "logs/cooler_identity.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=64,
    shell:
        """
        ( gunzip -c {input.bed} | tail -n +2 \
            | cooler cload pairs \
                -c1 1 -p1 2 -c2 4 -p2 5 \
                --field count=7:agg=mean,dtype=float \
                --chunksize 50000000000 \
                {input.fai}:{wildcards.W} \
                --zero-based \
                - {output.cool} \
        ) 2> {log}
        """


# Build the multi-resolution identity cooler (.mcool) from the single
# resolution one, averaging the count field when coarsening.  The log is named
# after this rule (it used to carry the name "cooler_identity_i") and stderr is
# redirected into it.
rule cooler_zoomify_i:
    input:
        i=rules.cooler_identity.output.cool,
    output:
        i="results/{SM}.{W}.{F}.identity.mcool",
    log:
        "logs/cooler_zoomify_i.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 8
    resources:
        mem=8,
    shell:
        """
        cooler zoomify --field count:agg=mean,dtype=float {input.i} \
                -n {threads} \
             -o {output.i} \
             2> {log}
        """


# Build the multi-resolution strand cooler (.mcool) from the single resolution
# one, averaging the count field when coarsening.  The log is named after this
# rule (it used to be labelled "cooler_identity_s", a copy-paste slip) and
# stderr is redirected into it.
rule cooler_zoomify_s:
    input:
        s=rules.cooler_strand.output.cool,
    output:
        s="results/{SM}.{W}.{F}.strand.mcool",
    log:
        "logs/cooler_zoomify_s.{SM}.{W}.{F}.log",
    conda:
        "envs/env.yaml"
    threads: 8
    resources:
        mem=8,
    shell:
        """
        cooler zoomify --field count:agg=mean,dtype=float {input.s} \
                -n {threads} \
             -o {output.s} \
             2> {log}
        """


# Build a BWA index of the input FASTA under the prefix temp/ref_<SM>.  The
# touched flag file is the rule's tracked output, standing in for the index
# files.  stderr is redirected so the declared log file is populated.
rule bwa_index:
    input:
        fasta=FASTA,
    output:
        done=touch("temp/bwa_ref_{SM}"),
    log:
        "logs/bwa_index_{SM}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    shell:
        """
        bwa index -p temp/ref_{wildcards.SM} {input.fasta} 2> {log}
        """


# Align the window sequences back to the indexed FASTA with `bwa aln` /
# `bwa samse` and gzip the resulting SAM.  stderr of the whole pipe is
# redirected so the declared log file is populated.
rule bwa_aln:
    input:
        # Flag file written by bwa_index; the index itself lives under the
        # temp/ref_<SM> prefix used in the command below.
        bwa_index_done="temp/bwa_ref_{SM}",
        reads=rules.window_fa.output.fasta,
    output:
        sam="results/output_{SM}_{W}.sam.gz",
    log:
        "logs/bwa_aln_{SM}_{W}.log",
    conda:
        "envs/env.yaml"
    threads: ALN_T
    params:
        # Value passed to `bwa samse -n`.
        num_dups=NUM_DUPS,
    shell:
        """
        ( bwa aln -t {threads} temp/ref_{wildcards.SM} {input.reads} \
            | bwa samse -n {params.num_dups} temp/ref_{wildcards.SM} - {input.reads} \
            | gzip > {output.sam} \
        ) 2> {log}
        """


# Stream the gzipped SAM through scripts/ntn_bam_to_contacts.py and gzip what
# the script prints.  stderr of the whole pipe is redirected so the declared
# log file is populated.
rule create_contacts:
    input:
        sam=rules.bwa_aln.output.sam,
    output:
        contacts="results/contacts_{SM}_{W}.gz",
    log:
        "logs/create_contacts_{SM}_{W}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    params:
        script=workflow.source_path("scripts/ntn_bam_to_contacts.py"),
    shell:
        """
        ( gunzip -c {input.sam} | python {params.script} - \
            | gzip > {output.contacts} \
        ) 2> {log}
        """


# Load the contacts file into a cooler binned at the cooler window size (CW).
# stderr is redirected so the declared log file is populated.
rule cooler_density_d:
    input:
        contacts=rules.create_contacts.output.contacts,
        fai=FAI,
    output:
        cool="results/{SM}.{W}.{CW}.density.cool",
    log:
        "logs/cooler_density.{SM}.{W}.{CW}.log",
    conda:
        "envs/env.yaml"
    threads: 1
    resources:
        mem=64,
    shell:
        """
        cooler cload pairs \
                -c1 1 -p1 2 -c2 4 -p2 5 \
                --chunksize 50000000000 \
                {input.fai}:{wildcards.CW} \
                --zero-based \
                {input.contacts} {output.cool} \
                2> {log}
        """


# Build the multi-resolution density cooler (.mcool) from the single
# resolution one.  The log is named after this rule (its former name,
# "cooler_density_d", is the name of a different rule) and stderr is
# redirected into it.
rule cooler_zoomify_d:
    input:
        d=rules.cooler_density_d.output.cool,
    output:
        d="results/{SM}.{W}.{CW}.density.mcool",
    log:
        "logs/cooler_zoomify_d.{SM}.{W}.{CW}.log",
    conda:
        "envs/env.yaml"
    threads: 8
    resources:
        mem=8,
    shell:
        """
        cooler zoomify {input.d} \
                -n {threads} \
             -o {output.d} \
             2> {log}
        """
