#!/usr/libexec/platform-python
# /* vim: set filetype=python : */
import collections
import itertools
import socket
import struct
import subprocess
import sys
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Type, TypeVar


class SubnetNotFoundException(Exception):
    """
    Raised when none of the system's interface subnets contains the VIP.
    """


class AddressNotFoundException(Exception):
    """
    Signals that no address on the system's interfaces is on the VIP subnet.
    """
    # NOTE(review): nothing in the visible code raises this exception.


TA = TypeVar('TA', bound='Address')


class Address:
    """An address assigned to a network interface.

    Instances are either parsed from one line of ``ip -o addr show`` output
    (see :meth:`from_line`) or built directly from keyword arguments. The
    values are stored as given; no validation is performed.
    """

    def __init__(self,
                 cidr: str,
                 name: str,
                 family: str,
                 index: int = -1,
                 scope: str = '',
                 flags: Iterable[str] = tuple(),
                 label: Optional[str] = None) -> None:
        """Store the address fields.

        Args:
            cidr: Address with prefix length, e.g. ``192.0.2.1/24``.
            name: Name of the interface that carries the address.
            family: Address family keyword, ``inet`` or ``inet6``.
            index: Interface index, ``-1`` when unknown.
            scope: Address scope keyword, empty when unknown.
            flags: Flag keywords attached to the address.
            label: Address label, ``None`` when there is none.
        """
        self.index = index
        self.name = name
        self.family = family
        self.cidr = cidr
        self.scope = scope
        self.flags = flags
        self.label = label

    @classmethod
    def from_line(cls: Type[TA], line: str) -> TA:
        """Parse one line of one-line-mode ``ip addr`` output.

        The expected token layout is: interface index followed by a colon,
        interface name, family, CIDR, optional attributes, the keyword
        ``scope`` and its value, optional flags, and finally either a label
        immediately followed by a backslash or a lone backslash. Everything
        after the backslash-terminated token is ignored.

        Raises:
            IndexError: if the line has fewer than four tokens.
            ValueError: if the interface index is not an integer.
        """
        tokens = collections.deque(line.split())
        index = int(tokens.popleft()[:-1])  # drop the trailing ':'
        name = tokens.popleft()
        family = tokens.popleft()
        cidr = tokens.popleft()

        # Optional "<keyword> <value>" attributes (for example a broadcast
        # address) may sit between the CIDR and the scope, so search for the
        # "scope" keyword instead of assuming it comes next. Never search
        # past the backslash that ends the first output line.
        while tokens and tokens[0] != 'scope' and not tokens[0].endswith('\\'):
            tokens.popleft()
        scope = ''
        if tokens and tokens[0] == 'scope':
            tokens.popleft()
            if tokens:
                scope = tokens.popleft()

        flags = []
        label = None
        # Flags run up to the backslash-terminated token; a line without such
        # a token simply ends the flag list instead of raising.
        while tokens:
            token = tokens.popleft()
            if token[-1] == '\\':
                if len(token) > 1:
                    label = token[:-1]
                break
            flags.append(token)
        return cls(cidr, name, family, index, scope, flags, label)

    def __str__(self) -> str:
        return f'{self.__class__.__name__}({self.cidr}, dev={self.name})'


TR = TypeVar('TR', bound='V6Route')


class V6Route:
    """One IPv6 route, as parsed from a line of ``ip -o -6 route show`` output."""

    # Route-line keywords that map onto constructor arguments; every other
    # token on the line is ignored by from_line().
    _PARSED_ATTRIBUTES = ('via', 'dev', 'proto', 'metric', 'pref')

    def __init__(self,
                 destination: str,
                 dev: Optional[str] = None,
                 proto: Optional[str] = None,
                 metric: Optional[int] = None,
                 pref: Optional[str] = None,
                 via: Optional[str] = None) -> None:
        """Store the route fields; every field except *destination* is optional."""
        self.destination: str = destination
        self.via: Optional[str] = via
        self.dev: Optional[str] = dev
        self.proto: Optional[str] = proto
        self.metric: Optional[int] = metric
        self.pref: Optional[str] = pref

    @classmethod
    def from_line(cls: Type[TR], line: str) -> TR:
        """Parse one route line.

        The first token is the destination; ``default`` is normalised to
        ``::/0``. The remaining tokens are scanned for the keywords ``via``,
        ``dev``, ``proto``, ``metric`` and ``pref``, each of which takes the
        following token as its value. Unknown keywords (for example
        ``expires``) and value-less flags are skipped rather than being
        passed to the constructor. ``metric`` is converted to ``int``; a
        non-numeric metric is stored as ``None``.

        Raises:
            IndexError: if the line contains no tokens.
        """
        items = line.split()
        dest = items[0]
        if dest == 'default':
            dest = '::/0'

        attrs = {}
        tokens = iter(items[1:])
        for token in tokens:
            if token in cls._PARSED_ATTRIBUTES:
                # Consume the value so it is never mistaken for a keyword.
                attrs[token] = next(tokens, None)

        metric = attrs.get('metric')
        if metric is not None:
            attrs['metric'] = int(metric) if metric.isdigit() else None
        return cls(dest, **attrs)

    def __str__(self) -> str:
        return f'{self.__class__.__name__}({self.destination}, dev={self.dev})'


# Number of bits in a full-length netmask for each address family keyword
# (32 for 'inet', 128 for 'inet6').
SUBNET_MASK_LEN = {
    'inet': 32,
    'inet6': 128
}


def ntoa(family: str, num: int) -> str:
    """Convert an integer address into its textual form.

    The family ``'inet'`` produces a dotted-quad string from a 32-bit value.
    Any other family is handled as IPv6 and formatted from a 128-bit value.
    """
    if family != 'inet':
        # struct has no 128-bit code, so pack two unsigned 64-bit words,
        # most significant word first.
        packed = struct.pack(">QQ", num >> 64, num & 0xFFFFFFFFFFFFFFFF)
        return socket.inet_ntop(socket.AF_INET6, packed)
    return socket.inet_ntoa(struct.pack("!I", num))


def aton(family: str, rep: str) -> int:
    """Convert a textual address into an integer.

    The family ``'inet'`` parses *rep* with ``socket.inet_aton``. Any other
    family parses it as IPv6 with ``socket.inet_pton``. The packed bytes are
    read as one big-endian unsigned integer.
    """
    if family == 'inet':
        packed = socket.inet_aton(rep)
    else:
        packed = socket.inet_pton(socket.AF_INET6, rep)
    return int.from_bytes(packed, 'big')


def addr_subnet_int_min_max(addr: Address) -> Tuple[int, int]:
    """Return the lowest and highest integer addresses of *addr*'s subnet.

    The subnet is the one described by ``addr.cidr``: the lowest address has
    all host bits cleared and the highest has all host bits set.
    """
    host_text, prefix_text = addr.cidr.split('/')
    host_int = aton(addr.family, host_text)
    prefix_len = int(prefix_text)

    width = SUBNET_MASK_LEN[addr.family]
    # Clamp both counts at zero so that an out-of-range prefix length yields
    # an empty run of bits rather than a negative shift.
    network_bits = max(prefix_len, 0)
    host_bits = max(width - prefix_len, 0)

    netmask = ((1 << network_bits) - 1) << host_bits
    hostmask = (1 << host_bits) - 1

    lowest = host_int & netmask
    highest = lowest | hostmask
    return lowest, highest


def vip_subnet_and_addrs_in_it(vip: str, addrs: List[Address]) -> Tuple[Address, List[Address]]:
    """Find the subnet that contains *vip* and the addresses that live in it.

    Each address in *addrs* is checked to see whether *vip* lies strictly
    between the lowest and highest address of that address's subnet. A VIP
    equal to either bound does not match. Only addresses of the same family
    as the VIP can match.

    Args:
        vip: Virtual IP as text, IPv4 or IPv6.
        addrs: Candidate interface addresses.

    Returns:
        A tuple ``(subnet, candidates)``. ``subnet`` is an ``Address`` named
        ``"subnet"`` whose CIDR is the network address and prefix length of
        the last matching address. ``candidates`` lists every matching
        address in input order.

    Raises:
        SubnetNotFoundException: if no address's subnet contains the VIP.
        OSError: if *vip* is neither a valid IPv4 nor a valid IPv6 address.
    """
    try:
        vip_int = aton('inet', vip)
    except OSError:
        # Not an IPv4 address; an invalid IPv6 string propagates its error.
        vip_int = aton('inet6', vip)
        vip_family = 'inet6'
    else:
        vip_family = 'inet'

    subnet = None
    candidates = []
    for addr in addrs:
        subnet_ip_int_min, subnet_ip_int_max = addr_subnet_int_min_max(addr)
        subnet_ip = ntoa(addr.family, subnet_ip_int_min)
        subnet_ip_max = ntoa(addr.family, subnet_ip_int_max)

        sys.stderr.write('Is %s between %s and %s\n' %
                         (vip, subnet_ip, subnet_ip_max))
        # Integer values of different families are not comparable: a small
        # IPv4 integer could otherwise fall inside a low IPv6 range.
        if addr.family != vip_family:
            continue
        if subnet_ip_int_min < vip_int < subnet_ip_int_max:
            subnet = Address(name="subnet",
                             cidr='%s/%s' % (subnet_ip, addr.cidr.split('/')[1]),
                             family=addr.family,
                             scope='')
            candidates.append(addr)
    if subnet is None:
        raise SubnetNotFoundException()
    return subnet, candidates


def interface_addrs(filters: Optional[Iterable[Callable[[Address], bool]]] = None) -> Iterator[Address]:
    """Yield the addresses reported by ``ip -o addr show``.

    An address is yielded only if every predicate in *filters* accepts it.
    With no filters, every address is yielded.

    An IPv6 address with a full-length (/128) prefix carries no subnet
    information of its own. For such an address the IPv6 routing table is
    searched for non-default routes with proto ``ra`` on the same device
    whose destination strictly contains the address. For each such route an
    extra ``Address`` is yielded first, combining the address with the
    route's prefix length. It has only ``cidr``, ``family`` and ``name``
    set. The original address is then yielded as well.

    Args:
        filters: Predicates that must all return true for an address to be
            kept. ``None`` or an empty iterable disables filtering.

    Raises:
        subprocess.CalledProcessError: if an ``ip`` invocation fails.
    """
    # Materialise the predicates so that a one-shot iterator is applied to
    # every address, not just the first one.
    checks = tuple(filters or ())
    out = subprocess.check_output(["ip", "-o", "addr", "show"], encoding=sys.stdout.encoding)
    routes = None  # IPv6 routes, fetched at most once and only when needed

    for line in out.splitlines():
        addr = Address.from_line(line)
        if not all(check(addr) for check in checks):
            continue

        if (addr.family == 'inet6' and
                int(addr.cidr.split('/')[1]) == SUBNET_MASK_LEN[addr.family]):
            if routes is None:
                route_out = subprocess.check_output(["ip", "-o", "-6", "route", "show"],
                                                    encoding=sys.stdout.encoding)
                routes = [V6Route.from_line(rline) for rline in route_out.splitlines()]
            for route in routes:
                if (route.dev == addr.name and route.proto == 'ra' and
                        route.destination != '::/0'):
                    sys.stderr.write('Checking %s for %s\n' % (route, addr))
                    route_net = Address(name=route.dev, cidr=route.destination, family='inet6')
                    route_filter = in_subnet(route_net)
                    if route_filter(addr):
                        ip_addr = addr.cidr.split('/')[0]
                        route_prefix = route_net.cidr.split('/')[1]
                        cidr = '%s/%s' % (ip_addr, route_prefix)
                        yield Address(cidr=cidr,
                                      family=addr.family,
                                      name=addr.name)
        yield addr


def non_host_scope(addr: Address) -> bool:
    """Filter predicate: accept *addr* unless its scope is ``host``.

    A rejected address is reported on stderr.
    """
    if addr.scope != 'host':
        return True
    sys.stderr.write(f'Filtering out {addr} due to it having host scope\n')
    return False


def non_deprecated(addr: Address) -> bool:
    """Filter predicate: accept *addr* unless it carries the ``deprecated`` flag.

    A rejected address is reported on stderr.
    """
    if 'deprecated' not in addr.flags:
        return True
    sys.stderr.write(f'Filtering out {addr} due to it being deprecated\n')
    return False


def non_secondary(addr: Address) -> bool:
    """Filter predicate: accept *addr* unless it carries the ``secondary`` flag.

    A rejected address is reported on stderr.
    """
    if 'secondary' not in addr.flags:
        return True
    sys.stderr.write(f'Filtering out {addr} due to it being secondary\n')
    return False


def in_subnet(subnet: Address) -> Callable[[Address], bool]:
    """Build a predicate that tests whether an address lies inside *subnet*.

    The subnet's bounds are computed once, when the predicate is built. The
    comparison is strict, so an address equal to the lowest or the highest
    address of the subnet does not match. The tested address is converted
    using its own family, not the subnet's.
    """
    lower, upper = addr_subnet_int_min_max(subnet)

    def filt(addr: Address) -> bool:
        host_text, _prefix = addr.cidr.split('/')
        value = aton(addr.family, host_text)
        return lower < value < upper

    return filt


def main() -> None:
    """Print the first local non-VIP address on the VIP's subnet.

    Command line: the first argument is the VIP used to locate the subnet.
    That argument and up to two further arguments are treated as VIPs and
    are never printed.

    Interface addresses with host scope, or flagged deprecated or secondary,
    are ignored.

    Exit status:
        0 after printing an address.
        1 if no argument was given, if no interface subnet contains the VIP,
          or if every address in the VIP subnet is itself a VIP.
    """
    if len(sys.argv) < 2:
        # sys.exit() with a string writes it to stderr and exits with status 1.
        sys.exit('At least one VIP argument is required')

    api_vip = sys.argv[1]
    vips = set(sys.argv[1:4])
    filters = (non_host_scope, non_deprecated, non_secondary)
    iface_addrs = list(interface_addrs(filters))
    try:
        subnet, candidates = vip_subnet_and_addrs_in_it(api_vip, iface_addrs)
    except SubnetNotFoundException:
        sys.exit(1)
    sys.stderr.write('VIP Subnet %s\n' % subnet.cidr)

    for addr in candidates:
        ip_addr, _ = addr.cidr.split('/')
        if ip_addr not in vips:
            print(ip_addr)
            sys.exit(0)

    # Every candidate was a VIP: report failure instead of exiting with
    # status 0 and no output.
    sys.stderr.write('No non-VIP address found in VIP subnet %s\n' % subnet.cidr)
    sys.exit(1)


# Run main() only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
