#!/usr/bin/env python3

import argparse
import sys
import datetime
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')


def get_args():
    """Parse the command-line arguments of plotAndReport.

    Returns:
        tuple: ``(countFile, identity, version, adna, outfile, csv)`` where

        - ``countFile`` is the path given with ``-c`` (``None`` if absent),
        - ``identity`` is the ``-i`` threshold as a float (default 0.97),
        - ``version`` is ``str()`` of the ``-v`` value (the string
          ``"None"`` when the option is not given),
        - ``adna`` is the ``-adna`` value as a string (default ``'true'``),
        - ``outfile`` is the ``-o`` report path
          (default ``coproID_result.md``),
        - ``csv`` is the ``-csv`` output path (default ``coproid.csv``).
    """
    parser = argparse.ArgumentParser(
        prog='plotAndReport',
        description='plotAndReport')
    parser.add_argument(
        '-c',
        dest='countFile',
        default=None,
        help="Normalized read count file")
    parser.add_argument(
        '-i',
        dest='identity',
        # Let argparse do the conversion so that a malformed value yields a
        # usage error instead of an uncaught ValueError.
        type=float,
        default=0.97,
        help='Identity threshold to retain read alignment'
    )
    parser.add_argument(
        '-v',
        dest="version",
        default=None,
        help='coproID version'
    )
    parser.add_argument(
        "-csv",
        dest="csv",
        # The help text has always advertised this default; it used to be
        # None, which made the later open() of the csv path fail.
        default="coproid.csv",
        help="csv filename output. Default = coproid.csv"
    )
    parser.add_argument(
        "-adna",
        dest="adna",
        default='true',
        help="adna (true | false)"
    )
    parser.add_argument(
        '-o',
        dest="output",
        default="coproID_result.md",
        help="Output file basename. Default = coproID_result.md")

    args = parser.parse_args()

    return (
        args.countFile,
        float(args.identity),
        str(args.version),
        str(args.adna),
        args.output,
        args.csv,
    )


def getBasename(file_name):
    """Return the last "/"-separated component of *file_name*, cut at its
    first dot.

    For example ``"dir/sample.tar.gz"`` gives ``"sample"``; a name that
    starts with a dot gives an empty string.
    """
    # rpartition yields the whole string when there is no "/" at all.
    last_component = file_name.rpartition("/")[2]
    return last_component.partition(".")[0]


def getFile(path):
    """Return everything after the last "/" in *path*.

    A path without any "/" is returned unchanged.
    """
    return path.rpartition("/")[2]


def write_csv(adict, CSV):
    """Write the per-sample records of *adict* to the file *CSV*.

    *adict* maps a sample name to a dict holding the keys ``orga1``,
    ``orga2``, ``gs1``, ``gs2``, ``nbp1``, ``nbp2`` and ``nrr``. A fixed
    header line is written first, then one comma-separated line per sample
    in the iteration order of *adict*. No quoting or escaping is applied.
    """
    header = "Sample_name,Organism_name1,Organism_name2,Genome1_size,Genome2_size,nb_aDNA_bp_aligned_genome1,nb_aDNA_bp_aligned_genome2,NormalizedReadRatio\n"
    numeric_fields = ("gs1", "gs2", "nbp1", "nbp2", "nrr")
    with open(CSV, "w") as out:
        out.write(header)
        for sample, record in adict.items():
            # Sample and organism names are written as-is; the remaining
            # fields go through str().
            cells = [sample, record["orga1"], record["orga2"]]
            cells.extend(str(record[field]) for field in numeric_fields)
            out.write(",".join(cells) + "\n")


def read_count_file(path):
    """Parse the normalized read count file at *path*.

    Every non-blank line must hold eight comma-separated fields:
    Sample_name,Organism_name1,Organism_name2,Genome1_size,Genome2_size,
    nb_bp_aligned_genome1,nb_bp_aligned_genome2,NormalizedReadRatio

    Returns:
        tuple: ``(records, last)``. ``records`` maps each sample name to a
        dict with the keys ``orga1``, ``orga2``, ``gs1``, ``gs2``, ``nbp1``,
        ``nbp2`` and ``nrr``; a sample listed several times keeps the values
        of its last line. ``last`` is the record built from the last parsed
        line, or ``None`` when the file holds no data line.

    Raises:
        IndexError, ValueError: on a line with missing or non-numeric fields.
    """
    records = {}
    last = None
    with open(path, "r") as handle:
        for raw in handle:
            fields = raw.rstrip().split(",")
            if fields == [""]:
                # Blank line (e.g. a trailing newline at the end of the file).
                continue
            last = {
                "orga1": fields[1],
                "orga2": fields[2],
                "gs1": int(fields[3]),
                "gs2": int(fields[4]),
                "nbp1": int(fields[5]),
                "nbp2": int(fields[6]),
                "nrr": float(fields[7]),
            }
            records[fields[0]] = last
    return records, last


def compute_axis_bounds(ratios):
    """Return ``(low, high, step)`` for the y axis of the ratio plot.

    The range always reaches across 0: when every ratio is positive the
    lower bound becomes -1.2, when every ratio is negative the upper bound
    becomes 1.2. ``step`` is a tenth of the range rounded to one decimal,
    with 0.1 as a floor so that it can never be zero.
    """
    low = min(ratios)
    high = max(ratios)
    if low > 0:
        low = -1.2
    elif high < 0:
        high = 1.2
    step = round((high - low) / 10, 1)
    if step <= 0:
        # A range narrower than 0.5 rounds to 0.0, which np.arange rejects.
        step = 0.1
    return low, high, step


def build_yticks(low, high, step, bottom_label, top_label):
    """Return ``(positions, labels)`` for the y axis ticks.

    Positions run from ``low - step`` up to (excluding) ``high + 2 * step``.
    The lowest tick is labelled *bottom_label*, the highest *top_label*, and
    every tick in between is labelled with its own value rounded to one
    decimal. Labels are derived from the positions so both sequences always
    have the same length.
    """
    positions = np.arange(low - step, high + 2 * step, step)
    inner = [str(round(float(pos), 1)) for pos in positions[1:-1]]
    return positions, [bottom_label] + inner + [top_label]


def plot_ratios(samples, ratios, orga1_clean, orga2_clean):
    """Draw the per-sample ratio bar plot and save it as ``./ratio.png``."""
    low, high, step = compute_axis_bounds(ratios)
    positions, labels = build_yticks(low, high, step, orga2_clean, orga1_clean)

    plt.bar(samples, ratios)
    for name, ratio in zip(samples, ratios):
        # Put the value below negative bars and just above the others.
        offset = -1 if ratio < 0 else 0.05
        plt.text(name, ratio + offset, str(round(ratio, 1)),
                 color="red", ha="center", rotation=45)

    plt.axhline(y=0, color="black")
    plt.axhline(y=1, color="orange")
    plt.axhline(y=-1, color="orange")

    plt.yticks(positions, labels)
    # Rotation is passed as a number: recent matplotlib releases reject the
    # string '45'.
    plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
    plt.yticks(rotation=45)
    plt.subplots_adjust(bottom=0.4, left=0.19)

    plt.savefig("./ratio.png")


def build_conclusion(nrr, orga1_clean, orga2_clean):
    """Return the sentence summarising one sample's normalized read ratio.

    A ratio above 1 points to organism 1, a ratio below -1 to organism 2,
    and anything in between is reported as inconclusive.
    """
    if nrr > 1:
        return 'There are more *' + orga1_clean + '* aDNA reads than *' + orga2_clean + \
            '* by > 2 folds. Therefore this coprolite is most likely originating from *' + \
            orga1_clean + '*'
    if nrr < -1:
        return 'There are more *' + orga2_clean + '* aDNA reads than *' + orga1_clean + \
            '* by > 2 folds. Therefore this coprolite is most likely originating from *' + \
            orga2_clean + '*'
    return 'There is a similar amount of aDNA reads originating from *' + \
        orga1_clean + '* and *' + orga2_clean + \
        '*. coproID can not reliably estimate which of these two organisms this coprolite is coming from.'


def write_report(outfile, records, last, version, identity, adna):
    """Write the Markdown report to *outfile*.

    Args:
        outfile: path of the report file to create.
        records: per-sample records as returned by ``read_count_file``.
        last: record whose organism names and genome sizes are shown in the
            report header.
        version: coproID version string printed in the header.
        identity: identity threshold, already expressed in percent.
        adna: when False no mapDamage plot is referenced.
    """
    orga1_clean = last["orga1"].replace("_", " ")
    orga2_clean = last["orga2"].replace("_", " ")
    no_plot = "\n\nNo plot is displayed because the number of bases aligned is too low for mapDamage, or it is Modern DNA\n\n"

    now = datetime.datetime.now()
    with open(outfile, "w") as fw:
        fw.write("---\ntitle: coproID - Coprolite Identification\n---\n  ")
        fw.write(
            '<img src="https://raw.githubusercontent.com/maxibor/coproID/master/img/logo.png" height="150">\n\n')
        fw.write(
            '**Homepage/Documentation**: [github.com/maxibor/coproID](https://github.com/maxibor/coproID)  \n**Author:** Maxime Borry [borry@shh.mpg.de](mailto:borry@shh.mpg.de)\n\n  ')
        fw.write("Report automatically generated by coproID version " + version + " on " + str(now.year) + "-" + str(now.month) + "-" +
                 str(now.day) + "  " + str(now.hour) + ":" + str(now.minute) + ":" + str(now.second) + "  \n")
        fw.write("\n\n")
        fw.write("***\n")
        fw.write("- **Organisms:**   \n\n    - *" + orga1_clean + "* (genome size = " + str(last["gs1"]) +
                 " bp )  \n\n    - *" + orga2_clean + "* (genome size = " + str(last["gs2"]) + " bp)  \n")
        # The genome1 count is normalized by the size of genome1 (the
        # formula used to print size_{genome2} in the numerator).
        fw.write(
            "- **The formula used to compute the read ratio is the following:** \n\n  $NormalizedRead_{Ratio} = \\log2\\left(\\frac{\\frac{N_{\\ aDNA\\ bp \\ aligned \\ genome1}}{size_{genome1} }}{\\frac{N_{ \\ aDNA \\ bp \\ aligned \\ genome2}}{size_{genome2}}}\\right)$\n\n  ")
        fw.write("***\n")
        fw.write("## coproID read ratio plot\n\n")
        fw.write("![Normalized read ratio (*" + orga1_clean +
                 "*/*" + orga2_clean + "*). The black line at 0 is an equal proportion of both genomes, while orange lines at |1| represent the uncertainty zone.](./ratio.png)\n")
        fw.write("\n\n***\n\n")
        fw.write("***\n\n")
        for sample, rec in records.items():
            name1 = rec["orga1"].replace("_", " ")
            name2 = rec["orga2"].replace("_", " ")
            fw.write("### **Sample:** " + sample + "  \n")
            fw.write("- **Number of base pairs (bp) in reads aligned to *" +
                     name1 + "* (genome 1) at " + str(identity) + "\\% identity:** " + str(rec["nbp1"]) + "  \n")
            fw.write("- **Number of base pairs (bp) in reads aligned to *" +
                     name2 + "* (genome 2) at " + str(identity) + "\\% identity:** " + str(rec["nbp2"]) + "  \n\n")
            fw.write("- $NormalizedRead_{ratio} = \\log2\\left(\\frac{" + rec["orga1"].replace("_", "\\;") + "}{" +
                     rec["orga2"].replace("_", "\\;") + "}\\right) = " + str(round(rec["nrr"], 3)) + "$   \n\n")
            fw.write(build_conclusion(rec["nrr"], name1, name2) + '\n\n')
            fw.write("- **MapDamage plots**\n\n")
            # A plot is only referenced for aDNA runs with more than 20
            # aligned bases.
            if rec["nbp1"] > 20 and adna:
                fw.write("![MapDamage Fragmisincorporation plot for " +
                         name1 + ". Very jagged and irregular substitutions are likely indicating spurious damages caused by a lack of reads. ](./" + sample + "." + rec["orga1"] + ".fragmisincorporation_plot.png)\n\n")
            else:
                fw.write(no_plot)
            if rec["nbp2"] > 20 and adna:
                fw.write("![MapDamage Fragmisincorporation plot for " +
                         name2 + ". Very jagged and irregular substitutions are likely indicating spurious damages caused by a lack of reads. If there is not plot displayed, the aligned read count was too low for mapDamage](./" + sample + "." + rec["orga2"] + ".fragmisincorporation_plot.png)\n\n")
            else:
                fw.write(no_plot)
            fw.write("\n\n***\n")


def main():
    """Read the count file, then write the csv, the plot and the report."""
    count_file, identity_fraction, version, adna_flag, outfile, csv_path = get_args()

    if count_file is None:
        print("Count file is missing")
        sys.exit(1)
    if csv_path is None:
        # Default advertised by the -csv help text.
        csv_path = "coproid.csv"

    # Only the exact string 'true' enables the mapDamage plots.
    adna = adna_flag == 'true'
    identity = identity_fraction * 100

    records, last = read_count_file(count_file)
    write_csv(records, csv_path)

    if not records:
        # Nothing to plot or report on.
        print("Count file is empty")
        sys.exit(1)

    samples = list(records)
    ratios = [records[sample]["nrr"] for sample in samples]
    plot_ratios(samples, ratios,
                last["orga1"].replace("_", " "),
                last["orga2"].replace("_", " "))

    write_report(outfile, records, last, version, identity, adna)


if __name__ == "__main__":
    main()
