package cluster

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/util/wait"
	"k8s.io/client-go/kubernetes"
	runtimeclient "sigs.k8s.io/controller-runtime/pkg/client"

	sriovv1 "github.com/k8snetworkplumbingwg/sriov-network-operator/api/v1"
	testclient "github.com/k8snetworkplumbingwg/sriov-network-operator/test/util/client"
	"github.com/k8snetworkplumbingwg/sriov-network-operator/test/util/nodes"
	"github.com/k8snetworkplumbingwg/sriov-network-operator/test/util/pod"
)

// EnabledNodes provides info on sriov enabled nodes of the cluster.
// It is built by DiscoverSriov; all maps are keyed by node name.
type EnabledNodes struct {
	// Nodes lists the nodes that expose at least one interface with a
	// supported PF driver and a supported device ID.
	Nodes               []string
	// States holds the SriovNetworkNodeState observed for each node in Nodes.
	States              map[string]sriovv1.SriovNetworkNodeState
	// IsSecureBootEnabled records, per node, whether kernel lockdown mode
	// (used by this package as the secure boot signal) was detected.
	IsSecureBootEnabled map[string]bool
}

var (
	// supportedPFDrivers are the PF driver names accepted by IsPFDriverSupported.
	// NOTE(review): matching is by substring, so e.g. "i40evf" and "ixgbevf"
	// also match "i40e" / "ixgbe" — confirm this is intended.
	supportedPFDrivers = []string{"mlx5_core", "i40e", "ixgbe", "ice"}
	// supportedVFDrivers are the VF driver names accepted by IsVFDriverSupported
	// (also matched by substring).
	supportedVFDrivers = []string{"iavf", "vfio-pci", "mlx5_core"}
	// mlxVendorID is the PCI vendor ID of Mellanox devices.
	mlxVendorID        = "15b3"
	// intelVendorID is the PCI vendor ID of Intel devices.
	intelVendorID      = "8086"
)

// DiscoverSriov retrieves Sriov related information of a given cluster.
//
// It performs these steps:
//   - Lists the SriovNetworkNodeStates in operatorNamespace.
//   - Narrows them with nodes.MatchingOptionalSelectorState (presumably an
//     optional node-selector filter — confirm in the nodes package).
//   - Initialises the NIC ID map via sriovv1.InitNicIdMap.
//   - Records every node with at least one interface whose PF driver and
//     device ID are both supported.
//   - Probes each recorded node for secure boot with GetNodeSecureBootState.
//
// An error is returned if any selected node state has not finished syncing,
// if any API call fails, or if no SR-IOV enabled node is found. Underlying
// errors are wrapped with %w so callers can inspect them with errors.Is/As.
func DiscoverSriov(clients *testclient.ClientSet, operatorNamespace string) (*EnabledNodes, error) {
	nodeStates, err := clients.SriovNetworkNodeStates(operatorNamespace).List(context.Background(), metav1.ListOptions{})
	if err != nil {
		return nil, fmt.Errorf("Failed to retrieve node states %w", err)
	}

	res := &EnabledNodes{
		States:              make(map[string]sriovv1.SriovNetworkNodeState),
		Nodes:               make([]string, 0),
		IsSecureBootEnabled: make(map[string]bool),
	}

	ss, err := nodes.MatchingOptionalSelectorState(clients, nodeStates.Items)
	if err != nil {
		return nil, fmt.Errorf("Failed to find matching node states %w", err)
	}

	// Presumably consulted by sriovv1.IsSupportedDevice below — TODO confirm.
	err = sriovv1.InitNicIdMap(kubernetes.NewForConfigOrDie(clients.Config), operatorNamespace)
	if err != nil {
		return nil, fmt.Errorf("Failed to InitNicIdMap %w", err)
	}

	for _, state := range ss {
		isStable, err := stateStable(state)
		if err != nil {
			return nil, err
		}
		if !isStable {
			return nil, fmt.Errorf("Sync status still in progress")
		}

		node := state.Name
		for _, itf := range state.Status.Interfaces {
			if IsPFDriverSupported(itf.Driver) && sriovv1.IsSupportedDevice(itf.DeviceID) {
				res.Nodes = append(res.Nodes, node)
				res.States[node] = state
				break
			}
		}
	}

	if len(res.Nodes) == 0 {
		return nil, fmt.Errorf("No sriov enabled node found")
	}

	// Secure boot probing schedules a pod per node, so only do it for nodes
	// that actually carry a supported SR-IOV device.
	for _, node := range res.Nodes {
		isSecureBootEnabled, err := GetNodeSecureBootState(clients, node, operatorNamespace)
		if err != nil {
			return nil, err
		}

		res.IsSecureBootEnabled[node] = isSecureBootEnabled
	}

	return res, nil
}

// FindOneSriovDevice returns the first interface of the given node whose PF
// driver and device ID are both supported. Mellanox interfaces are passed
// over when secure boot is enabled on the node. The returned pointer refers
// to a copy of the interface, not to the stored node state. An error is
// returned when the node is unknown or has no eligible interface.
func (n *EnabledNodes) FindOneSriovDevice(node string) (*sriovv1.InterfaceExt, error) {
	state, known := n.States[node]
	if !known {
		return nil, fmt.Errorf("Node %s not found", node)
	}

	secureBoot := n.IsSecureBootEnabled[node]
	for _, candidate := range state.Status.Interfaces {
		if !IsPFDriverSupported(candidate.Driver) || !sriovv1.IsSupportedDevice(candidate.DeviceID) {
			continue
		}
		// Mellanox cards cannot be used while secure boot/lockdown is active.
		// TODO: remove this when mlx support secure boot/lockdown mode
		if candidate.Vendor == mlxVendorID && secureBoot {
			continue
		}
		return &candidate, nil
	}
	return nil, fmt.Errorf("Unable to find sriov devices in node %s", node)
}

// FindSriovDevices returns every interface of the given node whose PF driver
// and device ID are supported. Mellanox interfaces are skipped when secure
// boot is enabled on that node. The returned pointers address elements of the
// Status.Interfaces slice held in n.States. When nothing matches, the result
// is an empty, non-nil slice. An error is returned only when the node is unknown.
func (n *EnabledNodes) FindSriovDevices(node string) ([]*sriovv1.InterfaceExt, error) {
	state, known := n.States[node]
	if !known {
		return nil, fmt.Errorf("Node %s not found", node)
	}

	ifaces := state.Status.Interfaces
	secureBoot := n.IsSecureBootEnabled[node]
	result := []*sriovv1.InterfaceExt{}
	for idx := range ifaces {
		current := &ifaces[idx]
		if !IsPFDriverSupported(current.Driver) || !sriovv1.IsSupportedDevice(current.DeviceID) {
			continue
		}
		// Mellanox cards cannot be used while secure boot/lockdown is active.
		// TODO: remove this when mlx support secure boot/lockdown mode
		if current.Vendor == mlxVendorID && secureBoot {
			continue
		}
		result = append(result, current)
	}
	return result, nil
}

// FindOneVfioSriovDevice walks the enabled nodes in order and returns the
// first node that has an Intel interface, together with a copy of that
// interface. Only the vendor ID is checked here; driver and device support
// are not. When no such interface exists, it returns an empty node name and a
// zero-valued InterfaceExt.
func (n *EnabledNodes) FindOneVfioSriovDevice() (string, sriovv1.InterfaceExt) {
	for _, nodeName := range n.Nodes {
		ifaces := n.States[nodeName].Status.Interfaces
		for i := range ifaces {
			if ifaces[i].Vendor != intelVendorID {
				continue
			}
			return nodeName, ifaces[i]
		}
	}
	var none sriovv1.InterfaceExt
	return "", none
}

// FindOneMellanoxSriovDevice returns the first Mellanox interface listed in
// the given node's state. The returned pointer refers to a copy.
//
// It fails in three cases:
//   - the node is unknown;
//   - secure boot is enabled on the node, where Mellanox cards are not
//     supported;
//   - the node has no Mellanox interface.
//
// Unlike FindOneSriovDevice, the PF driver and device ID are not checked.
func (n *EnabledNodes) FindOneMellanoxSriovDevice(node string) (*sriovv1.InterfaceExt, error) {
	state, known := n.States[node]
	if !known {
		return nil, fmt.Errorf("Node %s not found", node)
	}

	// Mellanox interfaces are unusable under secure boot, so bail out early.
	// TODO: remove this when mlx support secure boot/lockdown mode
	if n.IsSecureBootEnabled[node] {
		return nil, fmt.Errorf("Secure boot is enabled on the node mellanox cards are not supported")
	}

	ifaces := state.Status.Interfaces
	for i := range ifaces {
		if ifaces[i].Vendor == mlxVendorID {
			match := ifaces[i]
			return &match, nil
		}
	}

	return nil, fmt.Errorf("Unable to find a mellanox sriov devices in node %s", node)
}

// SriovStable tells if all the node states are in sync (and the cluster is ready for another round of tests).
//
// It returns (false, nil) when no SriovNetworkNodeState exists yet or when any
// of them has not reached the "Succeeded" sync status. An io.ErrUnexpectedEOF
// from listing the states is returned unwrapped, even when it arrives wrapped
// inside another error. Any other listing error is wrapped with %w.
func SriovStable(operatorNamespace string, clients *testclient.ClientSet) (bool, error) {
	nodeStates, err := clients.SriovNetworkNodeStates(operatorNamespace).List(context.Background(), metav1.ListOptions{})
	switch {
	case errors.Is(err, io.ErrUnexpectedEOF):
		// errors.Is also matches an EOF wrapped by the client transport, which
		// a plain == comparison would miss.
		return false, err
	case err != nil:
		return false, fmt.Errorf("Failed to fetch nodes state %w", err)
	}

	if len(nodeStates.Items) == 0 {
		return false, nil
	}
	for _, state := range nodeStates.Items {
		nodeReady, err := stateStable(state)
		if err != nil {
			return false, err
		}
		if !nodeReady {
			return false, nil
		}
	}
	return true, nil
}

// stateStable reports whether a node state has finished syncing, i.e. its
// SyncStatus is exactly "Succeeded". Every other value counts as not stable.
// That includes the empty status seen after the config daemon restarts, which
// does not mean the config was applied. The error result is always nil today.
func stateStable(state sriovv1.SriovNetworkNodeState) (bool, error) {
	synced := state.Status.SyncStatus == "Succeeded"
	return synced, nil
}

// IsPFDriverSupported reports whether driver contains the name of any entry
// in supportedPFDrivers. The match is by substring, not equality.
func IsPFDriverSupported(driver string) bool {
	matched := false
	for _, name := range supportedPFDrivers {
		if matched = strings.Contains(driver, name); matched {
			break
		}
	}
	return matched
}

// IsVFDriverSupported reports whether driver contains the name of any entry
// in supportedVFDrivers. The match is by substring, not equality.
func IsVFDriverSupported(driver string) bool {
	for i := 0; i < len(supportedVFDrivers); i++ {
		if !strings.Contains(driver, supportedVFDrivers[i]) {
			continue
		}
		return true
	}
	return false
}

// IsClusterStable reports whether no node in the cluster has
// Spec.Unschedulable set (for example because it was cordoned). Node readiness
// is not inspected. Errors from listing the nodes are returned unchanged.
func IsClusterStable(clients *testclient.ClientSet) (bool, error) {
	// Named nodeList to avoid shadowing the imported nodes package.
	nodeList, err := clients.Nodes().List(context.Background(), metav1.ListOptions{})
	if err != nil {
		return false, err
	}

	for i := range nodeList.Items {
		if nodeList.Items[i].Spec.Unschedulable {
			return false, nil
		}
	}
	return true, nil
}

// IsSingleNode reports whether the cluster consists of exactly one node.
// It decides purely from the node count; an environment variable could
// replace this check later if needed.
func IsSingleNode(clients *testclient.ClientSet) (bool, error) {
	// Named nodeList to avoid shadowing the imported nodes package.
	nodeList, err := clients.Nodes().List(context.Background(), metav1.ListOptions{})
	if err != nil {
		return false, err
	}
	single := len(nodeList.Items) == 1
	return single, nil
}

// GetNodeDrainState returns Spec.DisableDrain of the "default"
// SriovOperatorConfig in operatorNamespace. Note that true means node
// draining is DISABLED. If the Get fails, it returns false together with the
// unmodified error, rather than whatever the config object happened to hold.
func GetNodeDrainState(clients *testclient.ClientSet, operatorNamespace string) (bool, error) {
	config := &sriovv1.SriovOperatorConfig{}
	key := runtimeclient.ObjectKey{Name: "default", Namespace: operatorNamespace}
	if err := clients.Get(context.TODO(), key, config); err != nil {
		return false, err
	}
	return config.Spec.DisableDrain, nil
}

// SetDisableNodeDrainState fetches the "default" SriovOperatorConfig in
// operatorNamespace, sets its Spec.DisableDrain to state, and writes it back
// with Update. Errors from the Get or the Update are returned unchanged.
func SetDisableNodeDrainState(clients *testclient.ClientSet, operatorNamespace string, state bool) error {
	config := &sriovv1.SriovOperatorConfig{}
	key := runtimeclient.ObjectKey{Name: "default", Namespace: operatorNamespace}
	if err := clients.Get(context.TODO(), key, config); err != nil {
		return err
	}

	config.Spec.DisableDrain = state
	return clients.Update(context.TODO(), config)
}

// GetNodeSecureBootState reports whether kernel lockdown mode is active on
// nodeName. This package uses lockdown as the secure boot signal.
//
// It works as follows:
//   - Creates a privileged pod in namespace, pinned to the node, with the
//     host root filesystem mounted at /host.
//   - Waits up to three minutes for the pod to be Running.
//   - Reads /host/sys/kernel/security/lockdown.
//
// Lockdown counts as enabled when the output contains "[integrity]" or
// "[confidentiality]". A missing lockdown file is treated as "not enabled".
// The probe pod is deleted (best effort) before the function returns.
func GetNodeSecureBootState(clients *testclient.ClientSet, nodeName, namespace string) (bool, error) {
	podDefinition := pod.GetDefinition()
	podDefinition = pod.RedefineWithNodeSelector(podDefinition, nodeName)
	podDefinition = pod.RedefineAsPrivileged(podDefinition)
	podDefinition.Namespace = namespace

	volume := corev1.Volume{Name: "host", VolumeSource: corev1.VolumeSource{HostPath: &corev1.HostPathVolumeSource{Path: "/"}}}
	mount := corev1.VolumeMount{Name: "host", MountPath: "/host"}
	podDefinition = pod.RedefineWithMount(podDefinition, volume, mount)
	created, err := clients.Pods(namespace).Create(context.Background(), podDefinition, metav1.CreateOptions{})
	if err != nil {
		return false, err
	}
	// The probe pod is only needed for this check. Remove it so repeated
	// discoveries do not accumulate privileged pods in the namespace. This is
	// best effort: a failed delete must not mask the lockdown result.
	defer func() {
		_ = clients.Pods(namespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{})
	}()

	var runningPod *corev1.Pod
	err = wait.PollImmediate(time.Second, 3*time.Minute, func() (bool, error) {
		current, getErr := clients.Pods(namespace).Get(context.Background(), created.Name, metav1.GetOptions{})
		if getErr != nil {
			return false, getErr
		}

		switch current.Status.Phase {
		case corev1.PodRunning:
			runningPod = current
			return true, nil
		case corev1.PodFailed, corev1.PodSucceeded:
			// A terminated pod never becomes Running; fail now instead of
			// waiting for the poll timeout.
			return false, fmt.Errorf("pod %s/%s reached phase %s before the lockdown check could run", namespace, current.Name, current.Status.Phase)
		default:
			return false, nil
		}
	})
	if err != nil {
		return false, err
	}

	stdout, stderr, err := pod.ExecCommand(clients, runningPod, "cat", "/host/sys/kernel/security/lockdown")

	// cat reports a missing file on its error stream, which may be merged into
	// stdout when the exec uses a TTY, so look at both streams.
	const missingFile = "No such file or directory"
	if strings.Contains(stdout, missingFile) || strings.Contains(stderr, missingFile) {
		return false, nil
	}
	if err != nil {
		return false, err
	}

	return strings.Contains(stdout, "[integrity]") || strings.Contains(stdout, "[confidentiality]"), nil
}
