import sys
import argparse
from lxml import etree
import pandas as pd
import numpy as np
import sqlalchemy
import time
import re

def clean_mem_string(mem_string):
    """
    Prefix a memory string that starts with '.' with a '0' (".5G" -> "0.5G").

    Any other string, including the empty string, is returned unchanged.
    """
    # startswith() is safe on an empty string, unlike indexing with [0].
    if mem_string.startswith('.'):
        return '0' + mem_string

    return mem_string


def get_mem_from_string(mem_string):
    """
    Get memory in GiB from a string. Permits suffixes M, G, and T (not case sensitive).

    The numeric part may be an integer, a decimal, or a decimal with no
    leading digit (".5G"). Surrounding whitespace is ignored. A value with no
    suffix is treated as a plain byte count (the Grid Engine default unit for
    memory values) and converted to GiB.

    Returns the amount of memory in GiB as a float.

    Raises RuntimeError if the string (or None) is not a recognised memory value.
    """

    # Multipliers that convert a value in the given unit to GiB.
    gib_per_unit = {
        '': 1.0 / 1024 ** 3,
        'M': 1.0 / 1024,
        'G': 1.0,
        'T': 1024.0,
    }

    # None and '' both fall through to the "unrecognized" error below.
    text = '' if mem_string is None else mem_string.strip().upper()

    # Get mem and suffix
    match_obj = re.fullmatch(r'(\d+(?:\.\d+)?|\.\d+)([MGT]?)', text)

    if match_obj is None:
        raise RuntimeError('Unrecognized memory string: {} (expect a number with an optional M, G, or T suffix)'.format(mem_string))

    mem = float(match_obj.group(1))
    suffix = match_obj.group(2)

    return mem * gib_per_unit[suffix]

def get_user_data(input_file):
    """
    Summarise the running jobs in an XML file per user.

    Every "job_list" element whose "state" attribute is "running" counts as
    one session. Its slot count is read from the "slots" child and its owner
    from "JB_owner". Memory is taken from the first descendant element whose
    "name" attribute is "m_mem_free", multiplied by the slot count. A running
    job with no such entry is counted as requesting 0 memory rather than
    aborting the whole report.

    input_file: anything accepted by lxml's iterparse.
    NOTE(review): presumably Grid Engine `qstat -xml` output -- confirm.

    Returns a DataFrame with columns "User", "Sessions", "Slots" and
    "Memory (GB)" (truncated to int), one row per user, sorted by slots and
    then memory, largest first.
    """
    owner_list, slots_list, mem_list = [], [], []

    for _, job in etree.iterparse(input_file, tag="job_list"):
        if job.get("state") != "running":
            continue

        nslots = int(job.findtext("slots"))

        # tag=etree.Element restricts the walk to real elements (no comments
        # or processing instructions). Only the first match is parsed.
        mem = next(
            (get_mem_from_string(entry.text)
             for entry in job.iter(tag=etree.Element)
             if entry.get("name") == "m_mem_free"),
            0.0,
        )

        owner_list.append(job.findtext("JB_owner"))
        slots_list.append(nslots)
        mem_list.append(mem * nslots)

    dat = pd.DataFrame(data={"User": owner_list, "Slots": slots_list, "Memory (GB)": mem_list})
    # One row per job, so summing this column yields the session count.
    dat["Sessions"] = 1
    dat = dat[["User", "Sessions", "Slots", "Memory (GB)"]]

    outfile = dat.groupby("User").agg("sum").reset_index()
    outfile.sort_values(by=["Slots", "Memory (GB)"], ascending=False, inplace=True)
    outfile["Memory (GB)"] = outfile["Memory (GB)"].astype(int)
    return outfile

def get_mem_total_from_joblist(etree_element):
    """
    Sum the memory requested by all running jobs found under an element.

    Walks etree_element and its descendants. For each element whose "state"
    attribute is "running", the value of the first nested element whose
    "name" attribute is "m_mem_free" is multiplied by the job's "slots" count
    and added to the total. A running job with no such entry contributes 0
    instead of raising IndexError.

    Returns the total (0 if there are no running jobs).
    """
    mem_sum = 0

    # tag=etree.Element restricts iteration to real elements (no comments or
    # processing instructions).
    for elem in etree_element.iter(tag=etree.Element):
        if elem.get("state") != "running":
            continue

        nslots = int(elem.findtext("slots"))

        # Only the first matching entry is parsed; default to 0 when absent.
        mem = next(
            (get_mem_from_string(entry.text)
             for entry in elem.iter(tag=etree.Element)
             if entry.get("name") == "m_mem_free"),
            0.0,
        )
        mem_sum += mem * nslots

    return mem_sum

def get_node_data(input_file):
    """
    Build a per-node summary table from the "Queue-List" elements of an XML file.

    For each "Queue-List" element this records the node name, the used and
    total slot counts, the "np_load_avg" value (0 when the element is absent)
    and the memory total computed by get_mem_total_from_joblist. Rows sharing
    the same node, load and total slots are summed together.

    Returns a DataFrame with columns "Node", "Slots Used", "Total Slots",
    "Load" and "Memory Reserved (GB)" (truncated to int), sorted by reserved
    memory and then used slots, largest first.
    """
    node_col, used_col, total_col = "Node", "Slots Used", "Total Slots"
    load_col, mem_col = "Load", "Memory Reserved (GB)"
    column_order = [node_col, used_col, total_col, load_col, mem_col]

    # One list per output column, filled in lockstep while parsing.
    columns = {label: [] for label in column_order}

    for _, queue in etree.iterparse(input_file, tag="Queue-List"):
        # Keep the part of the name after "@", up to the first ".".
        host = queue.findtext("name").split("@")[1].split(".")[0]
        total_slots = int(queue.findtext("slots_total"))
        used_slots = int(queue.findtext("slots_used"))
        load_text = queue.findtext("np_load_avg")

        columns[node_col].append(host)
        columns[used_col].append(used_slots)
        columns[total_col].append(total_slots)
        columns[load_col].append(0 if load_text is None else float(load_text))
        columns[mem_col].append(get_mem_total_from_joblist(queue))

    table = pd.DataFrame(data=columns)
    table = table[column_order]

    per_node = table.groupby([node_col, load_col, total_col]).sum().reset_index()
    per_node = per_node[column_order]
    per_node.sort_values(by=[mem_col, used_col], ascending=False, inplace=True)
    per_node[mem_col] = per_node[mem_col].astype(int)

    return per_node

def main(argv=None):
    """
    Command-line entry point.

    Parses the input file in the requested modes ("user", "node" or both),
    stamps every row with the current local time, and then either appends the
    tables to the SQLite database given by --sql ("user_data" / "node_data"
    tables) or writes a tab-separated table to stdout.

    argv: argument list to parse; None means sys.argv[1:].
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("infile")
    parser.add_argument("--modes", nargs="+", choices=["user", "node"], default=["user", "node"])
    parser.add_argument("--sql", type=str, default=None, help="Append table to specified SQL database")
    args = parser.parse_args(argv)

    current_time = time.strftime("%Y-%m-%d %H:%M:%S")

    infile = args.infile

    # Decide by membership, not by counting: "--modes user user" has two
    # entries but only ever builds the user table.
    want_user = "user" in args.modes
    want_node = "node" in args.modes

    if want_user:
        usertable = get_user_data(infile)
        usertable["Datetime"] = current_time
        usercolumns = ["Datetime", "User"]

    if want_node:
        nodetable = get_node_data(infile)
        nodetable["Datetime"] = current_time
        nodecolumns = ["Datetime", "Node"]

    if args.sql is not None:
        disk_engine = sqlalchemy.create_engine("sqlite:///" + args.sql)
        if want_user:
            usertable.to_sql("user_data", disk_engine, if_exists="append", index=False, index_label=usercolumns)
        if want_node:
            nodetable.to_sql("node_data", disk_engine, if_exists="append", index=False, index_label=nodecolumns)

    else:
        if want_user and want_node:
            # NOTE(review): every row carries the same Datetime value, so this
            # outer merge pairs each user row with each node row -- confirm
            # that this combined layout is the intended output.
            outtable = usertable.merge(nodetable, on="Datetime", how="outer")
        elif want_user:
            outtable = usertable
        else:
            outtable = nodetable

        outtable.to_csv(sys.stdout, sep="\t", index=False)


if __name__ == "__main__":
    main()
