#!/usr/bin/python3

import glob
import os
import subprocess
import re
import string
from operator import itemgetter
from tabulate import tabulate


# Maps the attribute names that ``smartctl -A`` may print onto the three
# canonical counter names used by the report tables.  Several printed names
# fold onto the same canonical counter (e.g. "Current_Pending_Sector" is
# reported as "Pending_Sector_Count").
smart_attr_map = {
    alias: canonical
    for canonical, aliases in (
        ("Reallocated_Sector_Ct", ("Reallocated_Sector_Ct",)),
        ("Pending_Sector_Count", ("Pending_Sector_Count",
                                  "Current_Pending_Sector")),
        ("Uncorrectable_Error_Cnt", ("Offline_Uncorrectable",
                                     "Uncorrectable_Error_Cnt")),
    )
    for alias in aliases
}


def get_smart(dev_path):
    """Collect the tracked SMART counters for *dev_path*.

    Runs ``smartctl -A <dev_path>`` and slices every output line at fixed
    columns: characters [4:27] hold the attribute name and everything from
    column 87 onward holds the raw value.  Names listed in
    ``smart_attr_map`` are stored under their mapped (canonical) name; all
    other lines are ignored.  smartctl's exit status is not checked.

    Returns:
        dict of canonical attribute name -> raw value string.  It is empty
        when no tracked attribute name appears at the expected columns.
    """
    completed = subprocess.run(["smartctl", "-A", dev_path],
                               stdout=subprocess.PIPE)
    counters = {}
    for raw in completed.stdout.splitlines():
        text = raw.decode("ascii")
        canonical = smart_attr_map.get(text[4:27].strip())
        if canonical is not None:
            counters[canonical] = text[87:].strip()
    return counters


def get_phy(vpdpath):
    """Return the port name a SCSI device is attached through, from its sysfs path.

    * If *vpdpath* contains ``.../port-*/``, return the name of the single
      ``phy-*`` entry inside the last such port directory.
    * Otherwise, if it contains ``.../host[0-9]+/``, return that ``hostN``
      name, provided the host directory has at most one ``target*`` entry.

    Raises:
        RuntimeError: the port directory holds zero or several ``phy-*``
            entries, the host directory holds several ``target*`` entries,
            or the path matches neither layout.
    """
    # Greedy ``.*`` selects the last "/port-" component on the path.
    upto_port = re.search(r"^(.*/port-.*?/)", vpdpath)
    if upto_port:
        port_path = upto_port.group(1)
        phys = glob.glob(port_path + "phy-*")
        # Zero phys used to crash with IndexError on phys[0]; report both
        # "none" and "ambiguous" explicitly.
        if len(phys) != 1:
            raise RuntimeError(
                f"expected exactly one phy-* under {port_path}, "
                f"found {len(phys)}")
        # Take the entry's own name rather than searching the whole path,
        # which could match an earlier "phy-" component.
        return os.path.basename(phys[0])
    upto_host = re.search(r"^(.*/(host[0-9]+))/", vpdpath)
    if upto_host:
        host_path = upto_host.group(1)
        targets = glob.glob(host_path + "/target*")
        if len(targets) > 1:
            raise RuntimeError(
                f"{host_path} has {len(targets)} targets; "
                f"cannot attribute a single port")
        return upto_host.group(2)
    raise RuntimeError(f"cannot determine port for {vpdpath}")


def get_model(vpdpath):
    """Return the model string from the sysfs ``model`` file next to *vpdpath*.

    The file is read as bytes, decoded as ASCII (undecodable bytes are
    replaced) and stripped of surrounding whitespace.

    NOTE: this module later defines a second ``get_model(scsi_dev)`` that
    rebinds this name, so module-level callers of ``get_model`` get that
    udevadm-based version at runtime, not this one.
    """
    modelpath = os.path.join(os.path.dirname(vpdpath), "model")
    # Context manager closes the file handle instead of leaking it.
    with open(modelpath, "rb") as model_file:
        return model_file.read().decode("ascii", errors="replace").strip()


def get_scsi_dev(vpdpath):
    """List the entries of the ``block`` directory beside *vpdpath*.

    For a SCSI disk these are the kernel block-device names (such as
    ``sda``).  Raises FileNotFoundError (from os.listdir) when the device
    has no ``block`` subdirectory.
    """
    device_dir, _ = os.path.split(vpdpath)
    block_dir = os.path.join(device_dir, "block")
    return os.listdir(block_dir)


def _read_sysfs_model(vpdpath):
    """Return the stripped contents of the sysfs ``model`` file beside *vpdpath*.

    get_map uses this private helper rather than ``get_model(vpdpath)``
    because this module defines ``get_model`` twice, and the later
    ``get_model(scsi_dev)`` (which runs udevadm) is the one bound at runtime.
    """
    modelpath = os.path.join(os.path.dirname(vpdpath), "model")
    with open(modelpath, "rb") as model_file:
        return model_file.read().decode("ascii", errors="replace").strip()


def get_map():
    """Print a bay map of disks discovered through sysfs.

    Walks ``/sys/devices`` for devices exposing ``vpd_pg80`` and resolves
    each to a port name with get_phy. Only ports in the hard-coded layout
    printed at the top are kept. Each disk gets one row:

    * layout index and port name;
    * model;
    * serial (VPD page payload);
    * block device name(s);
    * three SMART counters, shown as "-" when smartctl did not report one.

    Rows are ordered by layout position.
    """
    print("Rear:  host0 host1")
    print("Front: host2 host3  host4   host8")
    print("       host9 host10 host11 host12")
    print()
    phy_list = ["host0", "host1", "host2", "host3", "host4", "host8",
                "host9", "host10", "host11", "host12"]
    # Position of each port in the layout; used as the row sort key.
    phy_rank = {phy: rank for rank, phy in enumerate(phy_list)}
    smart_keys = ("Reallocated_Sector_Ct", "Pending_Sector_Count",
                  "Uncorrectable_Error_Cnt")
    vpdname = "vpd_pg80"
    tabs = []
    for dirpath, _dirnames, filenames in os.walk("/sys/devices"):
        if vpdname not in filenames:
            continue
        vpdpath = os.path.join(dirpath, vpdname)
        try:
            sdev = get_scsi_dev(vpdpath)
        except FileNotFoundError:
            # No block/ subdirectory: nothing to run smartctl against.
            continue
        if not sdev:
            continue
        phy = get_phy(vpdpath)
        if phy not in phy_rank:
            continue
        # Query SMART only for disks that will actually appear in the table.
        smart = get_smart(os.path.join("/dev", sdev[0]))
        with open(vpdpath, "rb") as vpd_file:
            # Drop the first 4 bytes (page header); the rest is shown as the
            # serial number.
            vpd = vpd_file.read()[4:].decode("ascii", errors="replace")
        did = "".join(c for c in vpd if c in string.printable).strip()
        tabs.append([phy_rank[phy], phy, _read_sysfs_model(vpdpath), did,
                     ",".join(sdev)]
                    + [smart.get(key, "-") for key in smart_keys])
    tabs.sort(key=itemgetter(0))
    print(tabulate(tabs, ["", "Port", "Model", "SerialNum", "Dev",
                          "Real", "Pend", "Unco"], showindex=False))


def get_model(scsi_dev):
    """Return udev's ``ID_SERIAL`` property for *scsi_dev*.

    Runs ``udevadm info --query=all --name=<scsi_dev>`` and returns the
    value of its ``ID_SERIAL`` property line (``E: ID_SERIAL=...``).

    Raises:
        subprocess.CalledProcessError: udevadm exits with a non-zero status.
        ValueError: the output contains no ID_SERIAL property.
    """
    udevadm = subprocess.check_output(["udevadm", "info", "--query=all",
                                       "--name=" + scsi_dev])
    # Decode leniently: a single non-ASCII property anywhere in the output
    # used to raise UnicodeDecodeError before ID_SERIAL was reached.
    for line in udevadm.decode("utf-8", errors="replace").splitlines():
        # Anchor on the property name so only the ID_SERIAL key itself
        # matches, not some other key or line that merely contains it.
        match = re.match(r"(?:E: )?ID_SERIAL=(.*)$", line)
        if match:
            return match.group(1)
    raise ValueError(f"no ID_SERIAL property reported for {scsi_dev}")


def get_lshw(hwpath_order=None):
    """Print model/serial, device and SMART counters for disks in bay order.

    Disks are discovered with ``lshw -short -c disk``. Each row whose third
    whitespace-separated field is ``disk`` is read as
    ``<H/W path> <device> disk ...``.

    Rows are printed in the order of *hwpath_order*. A path is omitted when:

    * lshw reported no disk there;
    * smartctl did not report all three tracked counters for the disk.

    Args:
        hwpath_order: optional sequence of lshw hardware paths giving the
            row order; ``None`` selects this machine's built-in layout.
    """
    if hwpath_order is None:
        hwpath_order = [
            "/0/100/17/0", "/0/100/17/3",
            "/0/100/17/4", "/0/100/17/5", "/0/100/17/6", "/0/100/17/7",
            "/0/100/17/1", "/0/100/17/2", "/0/100/1c.2/0/0", "/0/100/1c.2/0/1",
        ]
    smart_keys = ("Reallocated_Sector_Ct", "Pending_Sector_Count",
                  "Uncorrectable_Error_Cnt")

    lshw = subprocess.check_output(["lshw", "-short", "-c", "disk"],
                                   stderr=subprocess.DEVNULL)
    hwpath_to_dev = {}
    for line in lshw.decode("utf-8", errors="replace").splitlines():
        fields = line.split()
        if len(fields) < 3 or fields[2] != "disk":
            continue
        hwpath_to_dev[fields[0]] = fields[1]

    tabs = []
    for hwpath in hwpath_order:
        device = hwpath_to_dev.get(hwpath)
        if device is None:
            # Not reported by lshw (presumably an empty bay); previously a
            # KeyError that aborted the whole report.
            continue
        smart = get_smart(device)
        if not all(key in smart for key in smart_keys):
            continue
        # Look up model+serial only for disks that are actually shown.
        tabs.append([hwpath, get_model(device), device]
                    + [smart[key] for key in smart_keys])
    print(tabulate(tabs, ["Port", "Model+Serial", "Dev",
                          "Re", "Pe", "Un"], showindex=False))
        

if __name__ == "__main__":
    # Script entry point: print the lshw-based table. get_map (the
    # sysfs-based view) is not invoked from here.
    get_lshw()
