#!/usr/bin/env python3

import argparse
import sys
import datetime
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
plt.switch_backend('agg')


def get_args():
    '''Parse the command-line arguments of plotAndReport.

    Returns:
        A tuple ``(countFile, identity, version, adna, outfile, csv)``:
        countFile is the ``-c`` value (None when not given), identity is the
        ``-i`` value as a float, version and adna are strings (a missing
        ``-v`` yields the string "None"), outfile is the ``-o`` value and
        csv is the ``-csv`` value.
    '''
    parser = argparse.ArgumentParser(
        prog='plotAndReport',
        description='plotAndReport')
    parser.add_argument(
        '-c',
        dest='countFile',
        default=None,
        help="Normalized read count file")
    parser.add_argument(
        '-i',
        dest='identity',
        # Let argparse validate the number so a bad value yields a usage
        # error instead of a ValueError traceback.
        type=float,
        default=0.97,
        help='Identity threshold to retain read alignment'
    )
    parser.add_argument(
        '-v',
        dest="version",
        default=None,
        help='coproID version'
    )
    parser.add_argument(
        "-csv",
        dest="csv",
        # The help text promises this default; with None the later
        # open(csv, "w") call fails with a TypeError.
        default="coproid.csv",
        help="csv filename output. Default = coproid.csv"
    )
    parser.add_argument(
        "-adna",
        dest="adna",
        default='true',
        help="adna (true | false)"
    )
    parser.add_argument(
        '-o',
        dest="output",
        default="coproID_result.md",
        help="Output file basename. Default = coproID_result.md")

    args = parser.parse_args()

    return (
        args.countFile,
        float(args.identity),
        str(args.version),
        str(args.adna),
        args.output,
        args.csv,
    )


def getBasename(file_name):
    '''Return the last "/"-separated component of file_name, cut at its first dot.'''
    # rsplit leaves the string untouched when it contains no "/".
    last_component = file_name.rsplit("/", 1)[-1]
    stem, _, _ = last_component.partition(".")
    return stem


def getFile(path):
    '''Return everything after the last "/" in path (the whole path if there is none).'''
    _, _, tail = path.rpartition("/")
    return tail


def write_csv(adict, CSV, n_samp):
    '''Write the per-sample records to a comma-separated file.

    Args:
        adict: mapping of sample name to a record dict holding the keys
            orgaN, gsN, sourcepredict_gN, sourcepredict_unknown, nbpN,
            nnbpN and nrrN for each genome index N.
        CSV: path of the file to create (overwritten if it exists).
        n_samp: 2 selects the two-genome layout; any other value selects
            the three-genome layout.

    Raises:
        KeyError: if a record lacks one of the keys of the chosen layout.
    '''
    genomes = range(1, (2 if n_samp == 2 else 3) + 1)

    # Header names and record keys are generated from the same layout so a
    # row can never drift out of step with the header.
    header = (
        ["Sample_name"]
        + [f"Organism_name{i}" for i in genomes]
        + [f"Genome{i}_size" for i in genomes]
        + [f"sourcepredict_genome{i}" for i in genomes]
        + ["sourcepredict_unknown"]
        + [f"nb_bp_aligned_genome{i}" for i in genomes]
        + [f"normalized_nb_bp_aligned_genome{i}" for i in genomes]
        + [f"NormalizedReadRatio_{i}" for i in genomes]
    )
    keys = (
        [f"orga{i}" for i in genomes]
        + [f"gs{i}" for i in genomes]
        + [f"sourcepredict_g{i}" for i in genomes]
        + ["sourcepredict_unknown"]
        + [f"nbp{i}" for i in genomes]
        + [f"nnbp{i}" for i in genomes]
        + [f"nrr{i}" for i in genomes]
    )

    with open(CSV, "w") as f:
        f.write(",".join(header) + "\n")
        for akey, record in adict.items():
            values = [str(record[key]) for key in keys]
            f.write(",".join([str(akey)] + values) + "\n")


def sample_text(sample_dict):
    '''Return the display names of the organisms recorded for one sample.

    Underscores in the stored organism names are replaced by spaces. A
    third organism is included only when the record carries an "nrr3" key.

    Args:
        sample_dict: record of one sample; must hold "orga1" and "orga2",
            plus "orga3" when "nrr3" is present.

    Returns:
        List of cleaned organism names, in genome order.

    NOTE(review): this function was an unfinished stub (it read undefined
    names and discarded its result list); returning the cleaned names is
    the minimal completion - confirm it matches the intended use.
    '''
    organism_keys = ["orga1", "orga2"]
    if "nrr3" in sample_dict:
        organism_keys.append("orga3")
    res = [sample_dict[key].replace("_", " ") for key in organism_keys]
    return res


def _parse_count_line(fields):
    '''Convert one comma-split row of the count file into a sample record.

    Rows with at most 14 fields are read as two-organism rows, longer rows
    as three-organism rows. After the sample name the columns are, in
    order: organism names, genome sizes, sourcepredict proportion per
    genome, sourcepredict unknown proportion, aligned bp per genome,
    normalized aligned bp per genome and normalized read ratio per genome.

    Args:
        fields: list of column strings of one row.

    Returns:
        Tuple ``(nsamp, sample, record)`` where nsamp is 2 or 3, sample is
        the first column and record maps orgaN / gsN / sourcepredict_gN /
        sourcepredict_unknown / nbpN / nnbpN / nrrN to the parsed values.

    Raises:
        ValueError: if the row is too short for its layout or a numeric
            column cannot be converted.
    '''
    n_orga = 2 if len(fields) <= 14 else 3
    expected = 6 * n_orga + 2  # sample + six per-genome groups + unknown
    if len(fields) < expected:
        raise ValueError(
            f"Expected {expected} columns for a {n_orga}-organism row, "
            f"got {len(fields)}: {','.join(fields)}")

    # (record key prefix, converter, one column per genome?)
    layout = (
        ("orga", str, True),
        ("gs", int, True),
        ("sourcepredict_g", float, True),
        ("sourcepredict_unknown", float, False),
        ("nbp", int, True),
        ("nnbp", float, True),
        ("nrr", float, True),
    )
    record = {}
    col = 1
    for prefix, cast, per_genome in layout:
        if per_genome:
            for k in range(1, n_orga + 1):
                record[f"{prefix}{k}"] = cast(fields[col])
                col += 1
        else:
            record[prefix] = cast(fields[col])
            col += 1
    return n_orga, fields[0], record


if __name__ == "__main__":
    CF, ID, VERSION, ADNA, OUTFILE, CSV = get_args()

    # ADNA and identity are not used by the report currently written below.
    if ADNA == 'true':
        ADNA = True
    elif ADNA == 'false':
        ADNA = False

    identity = ID * 100

    if CF is None:
        print("Count file is missing")
        sys.exit(1)

    d = {}
    nsamp = None
    record = None
    with open(CF, "r") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            nsamp, sample, record = _parse_count_line(line.rstrip().split(","))
            if nsamp == 2:
                # Two-organism rows: a repeated sample name overwrites.
                d.setdefault(sample, {}).update(record)
            elif sample not in d:
                # Three-organism rows: the first occurrence is kept.
                d[sample] = record

    if nsamp is None:
        print("Count file is empty")
        sys.exit(1)

    write_csv(d, CSV, n_samp=nsamp)

    # Legend labels come from the last row that was read.
    orga1_clean = record["orga1"].replace("_", " ")
    orga2_clean = record["orga2"].replace("_", " ")
    if nsamp == 3:
        orga3_clean = record["orga3"].replace("_", " ")

    names = list(d.keys())
    r = list(range(len(names)))
    barWidth = 0.85

    # (column name, record key, bar colour, legend label), bottom to top.
    bar_specs = [
        ('greenBars', 'nrr1', '#b5ffb9', orga1_clean),
        ('orangeBars', 'nrr2', '#f9bc86', orga2_clean),
    ]
    if nsamp == 3:
        bar_specs.append(('blueBars', 'nrr3', '#a3acff', orga3_clean))

    raw_data = {column: [d[name][key] for name in names]
                for column, key, _, _ in bar_specs}
    df = pd.DataFrame(raw_data)

    # Express each ratio as a percentage of the per-sample total and stack
    # the bars on top of each other.
    totals = [sum(values)
              for values in zip(*(df[column] for column, _, _, _ in bar_specs))]
    bottom = [0.0] * len(names)
    for column, _, color, label in bar_specs:
        heights = [i / j * 100 for i, j in zip(df[column], totals)]
        plt.bar(r, heights, bottom=bottom, color=color, edgecolor='white',
                width=barWidth, label=label)
        bottom = [b + h for b, h in zip(bottom, heights)]

    # rotation must be a number: current matplotlib rejects the string '45'.
    plt.xticks(r, names, rotation=45, ha="right", rotation_mode="anchor")
    plt.xlabel("Sample")
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1), ncol=1)
    plt.subplots_adjust(bottom=0.4, right=0.7)

    plt.savefig("./ratio.png")

    now = datetime.datetime.now()
    with open(OUTFILE, "w") as fw:
        fw.write("---\ntitle: coproID - Coprolite Identification\n---\n  ")
        fw.write(
            '<img src="https://raw.githubusercontent.com/maxibor/coproID/master/img/logo.png" height="150">\n\n')
        fw.write(
            '**Homepage/Documentation**: [github.com/maxibor/coproID](https://github.com/maxibor/coproID)  \n**Author:** Maxime Borry [borry@shh.mpg.de](mailto:borry@shh.mpg.de)\n\n  ')
        fw.write(
            f"Report automatically generated by coproID version {VERSION} on "
            f"{now.year}-{now.month}-{now.day}  "
            f"{now.hour}:{now.minute}:{now.second}  \n")
        fw.write("\n***\n")
        fw.write("## coproID read ratio plot\n\n")
        fw.write("![Normalized read ratio (insert legend here)](./ratio.png)\n")
        fw.write("\n\n***\n\n")
        fw.write("***\n\n")
        # NOTE: the detailed per-sample section is disabled; only one
        # separator per sample is emitted.
        for asample in d:
            fw.write("\n\n***\n")
