#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
""" 
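Get and set the anycast configuration (BGP/OSPF settings and the shared
anycast addresses) through PsmClient.

Run without arguments to print the current config as JSON. "-set [JSON]"
applies a config given as an argument (or read from stdin), and "-f FILE"
applies a config read from a file; in every case the resulting config is
printed afterwards.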
"""
import re
import os, sys, json
import subprocess
import copy

# Global debug flag: when True, PsmClientSet() echoes every command line it
# sends, and the final config is printed as indented (pretty) JSON.
debug = False


def PsmClient(args):
    """
    Runs the command "PsmClient" along with arguments args.

    :param args: A list (or tuple) of strings containing the arguments. The
                 sequence is not modified.
    :return: The stripped stdout of the PsmClient command.
    :raises TypeError: if args is not a list or tuple.
    :raises Exception: if the output does not contain "retcode=ok". Note:
                       PsmClient always returns with exit code 0, so the output
                       is the only reliable error indicator.
    """
    # Explicit check instead of 'assert', which is stripped under 'python -O'.
    if not isinstance(args, (list, tuple)):
        raise TypeError(
            "PsmClient args must be a list, got {}".format(type(args).__name__)
        )
    # Build a new argv rather than inserting into 'args', so the caller's
    # list is left untouched and can safely be reused.
    cmd = ["/usr/local/bluecat/PsmClient"] + list(args)
    result = subprocess.run(cmd, capture_output=True, text=True)
    out = result.stdout.strip()
    if "retcode=ok" not in out:
        # Fall back to stderr so the raised error is never an empty message.
        raise Exception(out or result.stderr.strip())
    return out


def PsmClientSet(argline):
    """
    Sends one command line to PsmClient.

    The whole argline is passed as a single argument. When the global 'debug'
    flag is set, the command is echoed to stdout first.

    :param argline: The command line, e.g. "node set anycast-enable=0".
    :return: The result of PsmClient() for that argument.
    """
    command = [argline]
    if debug:
        print(f"PsmClient {argline}")
    return PsmClient(command)


# Dotted-quad IPv4 address with an optional /prefix-length suffix. Only the
# shape is checked (1-3 digits per octet), not the numeric octet ranges.
pIpv4 = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}(?:\/\d{1,2})?")


def isIpv4(addr) -> bool:
    """
    Returns True if 'addr' is an IPv4 address or IPv4 network (CIDR).

    The whole string (ignoring surrounding whitespace) must match, so values
    such as "10.0.0.1abc" are rejected. IPv6 strings and non-string values
    (e.g. None) return False.
    """
    if not isinstance(addr, str):
        return False
    # fullmatch (not match) so trailing garbage does not count as IPv4, and a
    # real bool is returned instead of a Match object.
    return pIpv4.fullmatch(addr.strip()) is not None


# Lines from "anycast get" start with "anycast get-notify " followed by
# space-separated key[=value] tokens, where the value is optional. A value may
# itself contain '=' (e.g. base64 padding). It may also be wrapped in double
# quotes (PsmClient quotes values such as error="..."), in which case it may
# contain spaces.
pKeyValue = re.compile(r'([^\s=]+)(?:=("[^"]*"(?!\S)|\S+))?')


def getValues(line) -> dict:
    """
    Parses every key[=value] token in 'line' into a dictionary.

    Keys without a value (or with an empty value) map to None. Surrounding
    double quotes are removed from quoted values.

    :param line: The text to parse.
    :return: A dict mapping each key to its value string, or None.
    """
    values = {}
    for key, value in pKeyValue.findall(line):
        if len(value) >= 2 and value.startswith('"') and value.endswith('"'):
            value = value[1:-1]
        values[key] = value or None
    return values


def getAnycastValues(line) -> dict:
    """
    Returns the key=value pairs of an "anycast get-notify" line, or an empty
    dict for any other line.
    """
    prefix = "anycast get-notify "
    if not line.startswith(prefix):
        return {}
    return getValues(line[len(prefix):])


def getServices() -> dict:
    """
    Returns a dictionary of all services managed by PsmClient.

    Each "node <name> <key>[=<value>]" line of "PsmClient node get" becomes
    one entry. A key without a value maps to None; other lines are ignored.
    """
    pattern = re.compile(r"^node \S+ ([^\s=]+)(?:=([^\s=]+))?\s*$")
    services = {}
    for line in str(PsmClient(["node", "get"])).splitlines():
        found = pattern.match(line)
        if found is None:
            continue
        key, value = found.groups()
        services[key] = value
    return services


# Template for the default config. getConfig() takes a deep copy of it before
# filling in values, so the template itself is not modified there.
anycast_template = dict(
    override=False,
    enabled=False,
    bgp=dict(
        enabled=False,
        authenticate=False,
        password="",
        addresses=[],
        neighbors=[],
        prefixLists=[],
    ),
    ospf=dict(
        enabled=False,
        authenticate=False,
        password="",
        addresses=[],
    ),
)


def getConfig() -> dict:
    """
    Reads the current anycast configuration from PsmClient.

    Combines the "node get" service flags (anycast-enable, manual-override)
    with the per-component lines of "anycast get". The result is shaped like
    anycast_template, plus optional keys (e.g. ospf "area", bgp "asn",
    "routerID") whenever PsmClient reports them.

    :return: The configuration dictionary.
    :raises Exception: if a PsmClient call fails (see PsmClient()).
    :raises ValueError: if a numeric field holds a non-numeric value.
    """
    anycast = copy.deepcopy(anycast_template)
    services = getServices()
    anycast["enabled"] = services.get("anycast-enable", "0") == "1"
    if services.get("manual-override") and "anycast" in services["manual-override"]:
        anycast["override"] = True

    addresses = []
    # Process each line
    out = PsmClient(["anycast", "get"])
    for line in out.splitlines():
        values = getAnycastValues(line)

        # Remove the trailing 'd' of the component so we get bgp, ospf, rip etc.
        # A bare "component" token (no value) parses as None, hence the "or".
        comp = (values.get("component") or "").rstrip("d")
        main_line = False  # a component may have several lines; one is the main line

        if comp == "bgp" or comp == "ospf":
            # The line carrying "authenticate" is the component's main line.
            main_line = "authenticate" in values
            if "enabled" in values:
                anycast[comp]["enabled"] = values.get("enabled") == "1"
            if "authenticate" in values:
                anycast[comp]["authenticate"] = values.get("authenticate") == "1"
            if main_line and "password" in values:
                anycast[comp]["password"] = values["password"] or ""

        if comp == "common":
            # The anycast addresses are shared by all protocols; they are
            # assigned to ospf and bgp once all lines have been read.
            if values.get("anycast-ipv4"):
                addresses.extend(values["anycast-ipv4"].split(","))
            if values.get("anycast-ipv6"):
                addresses.extend(values["anycast-ipv6"].split(","))

        if comp == "ospf":
            if "area" in values:
                anycast[comp]["area"] = values.get("area", "")
            # Numeric fields are only converted when a value is present;
            # int(None) would otherwise abort the whole read.
            if values.get("dead-interval"):
                anycast[comp]["deadInterval"] = int(values["dead-interval"])
            if values.get("hello-interval"):
                anycast[comp]["helloInterval"] = int(values["hello-interval"])
            if "stub" in values:
                anycast[comp]["stub"] = values.get("stub") == "1"
            if "message-digest-key" in values:
                anycast[comp]["messageDigestKey"] = values["message-digest-key"] or ""

        if comp == "bgp":
            if values.get("holdtime"):
                anycast[comp]["holdTime"] = int(values["holdtime"])
            if values.get("keepalive"):
                anycast[comp]["keepAlive"] = int(values["keepalive"])
            if main_line and values.get("asn"):
                anycast[comp]["asn"] = int(values["asn"])
            if "router-id" in values:
                anycast[comp]["routerID"] = values["router-id"] or ""
            if not main_line and (
                "neighbor-ipv4" in values or "neighbor-ipv6" in values
            ):
                item = {}
                item["asn"] = int(values["asn"])
                item["hopLimit"] = int(values["ebgp-multihop"])
                item["nextHopSelf"] = values.get("next-hop-self") == "1"
                item["password"] = values.get("password", "") or ""
                # "or" so an empty neighbor-ipv4 token does not hide an IPv6
                # address on the same line.
                item["address"] = values.get("neighbor-ipv4") or values.get(
                    "neighbor-ipv6"
                )
                anycast[comp]["neighbors"].append(item)
            if not main_line and "network" in values:
                item = {}
                item["name"] = values.get("prefix-list", "")
                item["action"] = values.get("action", "")
                item["network"] = values.get("network", "")
                item["seq"] = int(values.get("seq") or 0)
                anycast[comp]["prefixLists"].append(item)

    # The prefix-list entries appear to be output sorted by network, so sort
    # them by 'seq' to restore their real order.
    if anycast["bgp"]["prefixLists"]:
        anycast["bgp"]["prefixLists"].sort(key=lambda x: x["seq"])
        # "seq" is regenerated on set, so it is not part of the returned config.
        for prefix in anycast["bgp"]["prefixLists"]:
            prefix.pop("seq", None)

    # Addresses are common, so they are the same for ospf and bgp. Each gets
    # its own list so that modifying one does not silently change the other.
    anycast["ospf"]["addresses"] = list(addresses)
    anycast["bgp"]["addresses"] = list(addresses)
    return anycast


def setConfig(config):
    """
    Applies an anycast configuration (shaped like getConfig()'s output) via PsmClient.

    If anycast is currently enabled, it is turned off first. The OSPF, BGP and
    common settings are then written. Anycast is re-enabled at the end only if
    config["enabled"] is true. The common anycast addresses are taken from
    OSPF when OSPF is enabled and lists addresses; otherwise they come from BGP.

    :param config: The configuration dictionary (e.g. parsed from JSON).
    :raises KeyError: if a required key is missing (e.g. "enabled", ospf
                      "area", bgp "asn"/"routerID"/"neighbors"/"prefixLists").
    :raises Exception: if a PsmClient call fails (see PsmClient()).
    """

    def addOptionalString(comp, option, arg_name):
        """
        Returns " arg_name=value" for comp[option], or "" if the option is
        absent. If the value is empty/None, the equal sign cannot be added, so
        just " arg_name" is returned. Non-string values are formatted with str().
        """
        if option not in comp:
            return ""
        value = comp[option]
        if value:
            return " {}={}".format(arg_name, value)
        return " " + arg_name

    def addOptionalValue(comp, option, arg_name):
        """
        Returns " arg_name=<integer>" for comp[option], or "" if the option is
        absent or None. Booleans are written as 0/1.
        """
        if comp.get(option) is None:
            return ""
        return " {}={:d}".format(arg_name, comp[option])

    # NOTE(review): every value is concatenated into a single command line; a
    # value containing whitespace would presumably be split into separate
    # arguments by PsmClient. Validate the input if it can be untrusted.

    curr_config = getConfig()

    # If anycast is enabled, turn it off while making changes.
    # NOTE(review): if a later PsmClient call fails, anycast is left disabled.
    if curr_config.get("enabled"):
        PsmClientSet("node set anycast-enable=0")

    common_addr = []

    if "ospf" in config:
        comp = config["ospf"]
        # A missing or null password is treated as empty.
        password = comp.get("password") or ""
        authenticate = bool(comp.get("authenticate", True)) and bool(password)
        argline = "anycast set component=ospfd enabled={:d} authenticate={:d} area={}".format(
            comp["enabled"], authenticate, comp["area"]
        )

        # password, dead-interval, hello-interval, stub and message-digest-key are optional
        argline += addOptionalString(comp, "password", "password")
        argline += addOptionalValue(comp, "deadInterval", "dead-interval")
        argline += addOptionalValue(comp, "helloInterval", "hello-interval")
        argline += addOptionalValue(comp, "stub", "stub")
        argline += addOptionalString(comp, "messageDigestKey", "message-digest-key")

        PsmClientSet(argline)
        if comp["enabled"] and "addresses" in comp:
            common_addr = comp["addresses"]

    if "bgp" in config:
        comp = config["bgp"]
        password = comp.get("password") or ""
        authenticate = bool(comp.get("authenticate", True)) and bool(password)
        argline = "anycast set component=bgpd enabled={:d} authenticate={:d} asn={} router-id={}".format(
            comp["enabled"],
            authenticate,
            comp["asn"],
            comp["routerID"],
        )
        argline += addOptionalString(comp, "password", "password")
        argline += addOptionalValue(comp, "holdTime", "holdtime")
        argline += addOptionalValue(comp, "keepAlive", "keepalive")

        # Split the neighbor addresses by address family.
        ipv4 = []
        ipv6 = []
        for item in comp["neighbors"]:
            addr = item["address"]
            if isIpv4(addr):
                ipv4.append(addr)
            else:
                ipv6.append(addr)
        # Unique prefix-list names, in first-seen order.
        prefixLists = []
        for item in comp["prefixLists"]:
            if item["name"] not in prefixLists:
                prefixLists.append(item["name"])

        # Note: it's an error to set neighbors-ipvx= with an empty value. If
        # the keyword is skipped the setting isn't cleared, but the bare
        # keyword (no '=' and no value) clears it.
        argline += " neighbors-ipv4"
        if ipv4:
            argline += "=" + ",".join(ipv4)
        argline += " neighbors-ipv6"
        if ipv6:
            argline += "=" + ",".join(ipv6)
        argline += " prefix-lists"
        if prefixLists:
            argline += "=" + ",".join(prefixLists)

        # If the prefix-list set is non-empty, clear it first so that we
        # start from a clean slate before the entries are re-added below.
        if prefixLists:
            PsmClientSet("anycast set component=bgpd prefix-lists")

        PsmClientSet(argline)

        # then configure each neighbor
        for item in comp["neighbors"]:
            argline = "anycast set component=bgpd asn={} ebgp-multihop={:d} next-hop-self={:d}".format(
                item["asn"], item["hopLimit"], item["nextHopSelf"]
            )
            argline += addOptionalString(item, "password", "password")

            if isIpv4(item["address"]):
                argline += " neighbor-ipv4=" + item["address"]
            else:
                argline += " neighbor-ipv6=" + item["address"]

            PsmClientSet(argline)

        # then each prefix-list entry; seq numbers are regenerated from 5
        # upwards so the list order is preserved
        # e.g. PsmClient anycast set component=bgpd action=permit network=fd00:1111::/32 prefix-list=INPUTv6 seq=5
        for seq, item in enumerate(comp["prefixLists"], start=5):
            argline = "anycast set component=bgpd action={} prefix-list={} network={} seq={:d}".format(
                item["action"], item["name"], item["network"], seq
            )
            PsmClientSet(argline)

        if not common_addr and "addresses" in comp:
            common_addr = comp["addresses"]

    # common: split the anycast addresses by address family
    ipv4 = []
    ipv6 = []
    for addr in common_addr or []:
        if isIpv4(addr):
            ipv4.append(addr)
        else:
            ipv6.append(addr)

    # NOTE: should state=ACTIVE be set? It is an unrecognized command.
    vtysh_enabled = True

    argline = "anycast set component=common vtysh-enable={:d}".format(vtysh_enabled)
    argline += " anycast-ipv4"
    if ipv4:
        argline += "=" + ",".join(ipv4)
    argline += " anycast-ipv6"
    if ipv6:
        argline += "=" + ",".join(ipv6)

    PsmClientSet(argline)

    # Finally, turn anycast on if the new config asks for it.
    if config.get("enabled"):
        PsmClientSet("node set anycast-enable=1")


if __name__ == "__main__":
    # "-set [JSON]" applies JSON given as an argument or on stdin, "-f FILE"
    # applies JSON read from FILE; anything else is treated as a plain get.
    # In all cases the resulting config is printed as JSON.
    try:
        if len(sys.argv) > 1:
            if sys.argv[1] == "-set":
                # input can either be an argument or come through stdin
                text = sys.argv[2] if len(sys.argv) > 2 else sys.stdin.read()
                config = json.loads(text)
                setConfig(config)
            elif len(sys.argv) > 2 and sys.argv[1] == "-f":
                # config as a file; JSON is UTF-8 encoded, so don't rely on
                # the locale's default encoding
                with open(sys.argv[2], "r", encoding="utf-8") as file:
                    text = file.read()
                config = json.loads(text)
                setConfig(config)

        # If it is not -set or -f we assume this is -get,
        # and in all cases we return the current config
        result = getConfig()
        if debug:
            print(json.dumps(result, indent=4))
        else:
            print(json.dumps(result))
    except Exception as e:
        # Top-level boundary: report the error on stderr and exit non-zero.
        error_str = str(e)
        # PsmClient errors may contain error="..."; prefer that message
        match = re.search(r"error=\"([^\"]*)\"", error_str)
        if match:
            error_str = match.group(1)
        sys.stderr.write(error_str + "\n")
        sys.exit(1)
