/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2022 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jboss.installer.postinstall.task;

import javax.xml.namespace.QName;
import javax.xml.stream.XMLEventFactory;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.Attribute;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;

import com.google.auto.service.AutoService;
import org.jboss.as.controller.OperationFailedException;
import org.jboss.as.controller.PathAddress;
import org.jboss.dmr.ModelNode;
import org.jboss.installer.auto.AutomaticInstallationParsingException;
import org.jboss.installer.auto.InstallationDataSerializer;
import org.jboss.installer.auto.ListXMLEventReader;
import org.jboss.installer.core.InstallationData;
import org.jboss.installer.core.LoggerUtils;
import org.jboss.installer.postinstall.CliPostInstallTask;
import org.jboss.installer.postinstall.PostInstallTask;
import org.jboss.installer.postinstall.TaskPrinter;
import org.jboss.installer.postinstall.server.DomainServer;
import org.jboss.installer.postinstall.server.EmbeddedServer;
import org.jboss.installer.postinstall.server.StandaloneServer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import static org.jboss.as.controller.operations.common.Util.createEmptyOperation;

/**
 * Post-install task that applies the port settings held in {@link Config} to a server
 * configuration through management operations.
 *
 * <p>When a non-zero offset is configured, the offset is applied; otherwise the individually
 * configured socket bindings (and, for domain configurations, the management interface port)
 * are written.
 */
@AutoService(PostInstallTask.class)
public class PortConfigurationTask implements CliPostInstallTask {

    /** Value returned by {@link #getName()}; presumably a message-bundle key (confirm against the caller). */
    public static final String TASK_NAME_KEY = "post_install.task.port_config.name";
    /** Property name for the domain management HTTPS port; not referenced inside this class. */
    public static final String DOMAIN_MANAGEMENT_HTTPS_PORT_PROPERTY = "jboss.management.https.port";
    /** Property name wrapped around the host {@code http-interface} port when an offset is applied. */
    private static final String DOMAIN_MANAGEMENT_PORT_PROPERTY = "jboss.management.http.port";
    /** Default domain management HTTPS port; not referenced inside this class. */
    public static final int DOMAIN_MANAGEMENT_HTTPS_PORT_DEFAULT = 9993;
    /** Base management port to which the configured offset is added for domain hosts. */
    public static final int DOMAIN_MANAGEMENT_PORT_DEFAULT = 9990;
    /** Format used by {@code asExpression}: yields {@code ${key:value}}. */
    private static final String JBOSS_PROPERTY_SUB = "${%s:%s}";
    /** Property name wrapped around the static-discovery port of a secondary host. */
    private static final String DOMAIN_PRIMARY_PORT_PROPERTY = "jboss.domain.primary.port";
    /** Property name wrapped around the standalone socket-binding-group port offset. */
    private static final String SOCKET_OFFSET_PROPERTY = "jboss.socket.binding.port-offset";
    /** Name of the socket-binding-group addressed for standalone configurations. */
    private static final String STANDARD_SOCKETS_BINDING_GROUP = "standard-sockets";
    /**
     * Binding group prefixes processed for domain configurations. Each prefix is combined with
     * {@link #DOMAIN_SCREEN_PREFIX} to look up bindings in {@link Config} and with the suffix
     * {@code -sockets} to form the socket-binding-group name.
     */
    private static final String[] BINDING_GROUPS = new String[]{"standard", "ha", "full", "full-ha"};
    /** Prefix of the {@link Config} keys under which domain socket bindings are stored. */
    private static final String DOMAIN_SCREEN_PREFIX = "domain-";
    /** Host name used when addressing the management {@code http-interface}. */
    public static final String PRIMARY_HOST_NAME = "primary";

    /**
     * Applies the port configuration to the current standalone server configuration.
     *
     * <p>A non-zero offset takes precedence: only the port offset is written. With a zero offset,
     * every socket binding stored for the current configuration name is written to the
     * {@code standard-sockets} group; a configuration without stored bindings is left untouched.
     *
     * @param idata   installation data holding the task's {@link Config}
     * @param server  standalone server the operations are executed against
     * @param printer sink for progress and failure messages
     * @return {@code true} on success, {@code false} if a management operation failed
     */
    @Override
    public boolean applyToStandalone(InstallationData idata, StandaloneServer server, TaskPrinter printer) {
        printer.print("tasks.ports.started", server.currentConfiguration());
        final Config portConfig = idata.getConfig(Config.class);
        try {
            final int offset = portConfig.getOffset();
            if (offset == 0) {
                final List<SocketBinding> bindings = portConfig.getPorts(server.currentConfiguration());
                if (bindings != null) {
                    for (SocketBinding binding : bindings) {
                        printer.print("tasks.ports.binding", binding.getKey());
                        applySocketBinding(server, STANDARD_SOCKETS_BINDING_GROUP, binding);
                    }
                }
            } else {
                printer.print("tasks.ports.offset", Integer.toString(offset));
                applyStandaloneOffset(server, portConfig);
            }
        } catch (OperationFailedException failure) {
            LoggerUtils.taskLog.error("Failed to perform operation", failure);
            printer.print("tasks.ports.failed");
            printer.print(failure);
            return false;
        }

        printer.print("tasks.ports.finished");
        return true;
    }

    /**
     * Applies the port configuration to the current domain server configuration.
     *
     * <p>A non-zero offset is applied to the domain configuration as a whole; with a zero offset
     * the individually configured ports are written instead.
     *
     * @param idata   installation data holding the task's {@link Config}
     * @param server  domain server the operations are executed against
     * @param printer sink for progress and failure messages
     * @return {@code true} on success, {@code false} if a management operation failed
     */
    @Override
    public boolean applyToDomain(InstallationData idata, DomainServer server, TaskPrinter printer) {
        printer.print("tasks.ports.started", server.currentConfiguration());
        final Config portConfig = idata.getConfig(Config.class);

        try {
            final int offset = portConfig.getOffset();
            if (offset == 0) {
                applyIndividualPortsToDomain(server, portConfig, printer);
            } else {
                printer.print("tasks.ports.offset", Integer.toString(offset));
                applyOffsetToDomainConfigs(server, portConfig);
            }
        } catch (OperationFailedException failure) {
            LoggerUtils.taskLog.error("Failed to perform operation", failure);
            printer.print("tasks.ports.failed");
            printer.print(failure);
            return false;
        }

        printer.print("tasks.ports.finished");
        return true;
    }

    /**
     * Writes individually configured ports to a domain configuration.
     *
     * <p>Socket bindings are written for every known binding group unless the current
     * configuration is {@code DomainServer.HOST_SECONDARY_XML}. If a management interface binding
     * is configured, its port is written either to the host {@code http-interface} or, for the
     * secondary host configuration, to the static discovery option.
     *
     * @param server  domain server the operations are executed against
     * @param config  port configuration to apply
     * @param printer sink for progress messages
     * @throws OperationFailedException if a management operation fails
     */
    private void applyIndividualPortsToDomain(DomainServer server, Config config, TaskPrinter printer) throws OperationFailedException {
        if (!server.currentConfiguration().equals(DomainServer.HOST_SECONDARY_XML)) {
            for (String groupPrefix : BINDING_GROUPS) {
                final List<SocketBinding> bindings = config.getPorts(DOMAIN_SCREEN_PREFIX + groupPrefix);
                if (bindings == null) {
                    // nothing was configured for this binding group
                    continue;
                }
                for (SocketBinding binding : bindings) {
                    printer.print("tasks.ports.binding", binding.getKey());
                    applySocketBinding(server, groupPrefix + "-sockets", binding);
                }
            }
        }

        final SocketBinding management = config.getManagementInterface();
        if (management == null) {
            return;
        }

        printer.print("tasks.ports.management_interface", String.valueOf(management.portValue));
        if (server.currentConfiguration().equals(DomainServer.HOST_SECONDARY_XML)) {
            final String hostGroupName = server.readHostGroupName();
            setStaticDiscoveryPort(server, hostGroupName, management.portValue);
        } else {
            setHostInterfacePort(server, management);
        }
    }

    /**
     * Applies the configured port offset to a domain configuration.
     *
     * <p>The offset is added to the port offset of every server-config of the host, and then to
     * the default management port: written to the static discovery option for the secondary host
     * configuration, or to the host {@code http-interface} otherwise.
     *
     * @param server domain server the operations are executed against
     * @param config port configuration supplying the offset
     * @throws OperationFailedException if a management operation fails
     */
    private void applyOffsetToDomainConfigs(DomainServer server, Config config) throws OperationFailedException {
        final int portOffset = config.getOffset();
        final String host = server.readHostGroupName();

        final List<String> serverNames = readServerConfigs(server, host);
        for (String serverName : serverNames) {
            updatePortOffsetForServer(server, portOffset, serverName, host);
        }

        if (server.currentConfiguration().equals(DomainServer.HOST_SECONDARY_XML)) {
            setStaticDiscoveryPort(server, host, DOMAIN_MANAGEMENT_PORT_DEFAULT + portOffset);
        } else {
            setHostInterfacePort(server, portOffset);
        }
    }

    /**
     * Writes the configured offset as the {@code port-offset} attribute of the
     * {@code standard-sockets} socket-binding-group, wrapped in a property expression.
     *
     * @param server standalone server the operation is executed against
     * @param config port configuration supplying the offset
     * @throws OperationFailedException if the management operation fails
     */
    private void applyStandaloneOffset(StandaloneServer server, Config config) throws OperationFailedException {
        final PathAddress socketGroupAddress = PathAddress.pathAddress("socket-binding-group", STANDARD_SOCKETS_BINDING_GROUP);
        final ModelNode writeOffsetOp = createEmptyOperation("write-attribute", socketGroupAddress);
        final String offsetExpression = asExpression(SOCKET_OFFSET_PROPERTY, config.getOffset());

        writeOffsetOp.get("name").set("port-offset");
        writeOffsetOp.get("value").set(offsetExpression);
        server.execute(writeOffsetOp, "Set default port offset");
    }

    /**
     * Writes a single socket binding to the given socket-binding-group.
     *
     * <p>For a plain binding (no multicast address) the port is written as a property expression
     * built from the binding's property name. For a multicast binding the port is written as a
     * raw number, followed by the multicast address (as a property expression) and the multicast
     * port.
     *
     * <p>A binding without a port value leaves the {@code port} attribute untouched in both
     * cases. Previously a plain binding with a {@code null} port failed with a
     * {@link NullPointerException} while unboxing, which callers do not handle.
     *
     * @param server      server the operations are executed against
     * @param socketGroup name of the socket-binding-group containing the binding
     * @param port        binding to write
     * @throws OperationFailedException if a management operation fails
     */
    private void applySocketBinding(EmbeddedServer server, String socketGroup, SocketBinding port) throws OperationFailedException {
        // The same operation node is reused for every attribute; only name/value change.
        final ModelNode socketBindingOp = createEmptyOperation("write-attribute",
                PathAddress.pathAddress("socket-binding-group", socketGroup).append("socket-binding", port.key));
        final boolean multicast = port.mcastAddress != null;

        // NOTE(review): the description passed to execute() looks copy-pasted from the offset
        // operation; it is kept unchanged here because its consumers are not visible.
        if (port.portValue != null) {
            socketBindingOp.get("name").set("port");
            if (multicast) {
                socketBindingOp.get("value").set(port.portValue);
            } else {
                socketBindingOp.get("value").set(asExpression(port.property, port.portValue));
            }
            server.execute(socketBindingOp, "Set default port offset");
        }

        if (multicast) {
            socketBindingOp.get("name").set("multicast-address");
            socketBindingOp.get("value").set(asExpression(port.property, port.mcastAddress));
            server.execute(socketBindingOp, "Set default port offset");

            socketBindingOp.get("name").set("multicast-port");
            socketBindingOp.get("value").set(port.mcastPort);
            server.execute(socketBindingOp, "Set default port offset");
        }
    }

    /**
     * Writes the {@code port} attribute of the {@code primary} static-discovery option of the
     * given host, wrapped in a property expression.
     *
     * @param server        domain server the operation is executed against
     * @param hostGroupName name of the host whose discovery option is updated
     * @param portValue     port used as the expression's default value
     * @throws OperationFailedException if the management operation fails
     */
    private void setStaticDiscoveryPort(DomainServer server, String hostGroupName, int portValue) throws OperationFailedException {
        final PathAddress discoveryAddress = PathAddress.pathAddress("host", hostGroupName)
                .append("core-service", "discovery-options")
                .append("static-discovery", "primary");
        final ModelNode writePortOp = createEmptyOperation("write-attribute", discoveryAddress);
        final String portExpression = asExpression(DOMAIN_PRIMARY_PORT_PROPERTY, portValue);

        writePortOp.get("name").set("port");
        writePortOp.get("value").set(portExpression);
        server.execute(writePortOp, "Set default port offset");
    }

    /**
     * Writes the host management port as the default management port shifted by {@code offset}.
     *
     * @param server domain server the operation is executed against
     * @param offset value added to {@link #DOMAIN_MANAGEMENT_PORT_DEFAULT}
     * @throws OperationFailedException if the management operation fails
     */
    private void setHostInterfacePort(DomainServer server, int offset) throws OperationFailedException {
        final int shiftedPort = DOMAIN_MANAGEMENT_PORT_DEFAULT + offset;
        final SocketBinding managementBinding = new SocketBinding(null, DOMAIN_MANAGEMENT_PORT_PROPERTY, shiftedPort);
        setHostInterfacePort(server, managementBinding);
    }

    /**
     * Writes the {@code port} attribute of the {@code http-interface} management interface of the
     * {@code primary} host, using the binding's property name and port value as an expression.
     *
     * @param server        domain server the operation is executed against
     * @param socketBinding binding supplying the property name and port value
     * @throws OperationFailedException if the management operation fails
     */
    private void setHostInterfacePort(DomainServer server, SocketBinding socketBinding) throws OperationFailedException {
        final PathAddress httpInterfaceAddress = PathAddress.pathAddress("host", PRIMARY_HOST_NAME)
                .append("core-service", "management")
                .append("management-interface", "http-interface");
        final ModelNode writePortOp = createEmptyOperation("write-attribute", httpInterfaceAddress);

        writePortOp.get("name").set("port");
        writePortOp.get("value").set(asExpression(socketBinding.property, socketBinding.portValue));
        server.execute(writePortOp, "Set default port offset");
    }

    /**
     * Integer convenience overload of {@link #asExpression(String, String)}.
     *
     * @param key   property name, may be {@code null} or blank
     * @param value default value of the expression
     * @return the expression, or the plain value when no key is given
     */
    private String asExpression(String key, int value) {
        return asExpression(key, Integer.toString(value));
    }

    /**
     * Builds a {@code ${key:value}} expression.
     *
     * @param key   property name; when {@code null} or blank, no expression is built
     * @param value default value of the expression
     * @return the formatted expression, or {@code value} itself when the key is missing
     */
    private String asExpression(String key, String value) {
        final boolean hasKey = key != null && !key.trim().isEmpty();
        return hasKey ? String.format(JBOSS_PROPERTY_SUB, key, value) : value;
    }

    /**
     * Reads the names of the {@code server-config} children of the given host.
     *
     * @param server        domain server the operation is executed against
     * @param hostGroupName name of the host to query
     * @return the server-config names, in the order returned by the server
     * @throws OperationFailedException if the management operation fails
     */
    private List<String> readServerConfigs(DomainServer server, String hostGroupName) throws OperationFailedException {
        final PathAddress hostAddress = PathAddress.pathAddress("host", hostGroupName);
        final ModelNode readNamesOp = createEmptyOperation("read-children-names", hostAddress);
        readNamesOp.get("child-type").set("server-config");

        final ModelNode result = server.execute(readNamesOp, "Set default port offset");
        final List<String> names = new ArrayList<>();
        for (ModelNode child : result.asList()) {
            names.add(child.asString());
        }
        return names;
    }

    /**
     * Adds {@code offset} to the current {@code socket-binding-port-offset} of one server-config
     * and writes the sum back.
     *
     * @param server        domain server the operations are executed against
     * @param offset        value added to the server's current port offset
     * @param serverName    name of the server-config to update
     * @param hostGroupName name of the host owning the server-config
     * @throws OperationFailedException if a management operation fails
     */
    private void updatePortOffsetForServer(DomainServer server, int offset, String serverName, String hostGroupName) throws OperationFailedException {
        final PathAddress serverConfigAddress = PathAddress.pathAddress("host", hostGroupName).append("server-config", serverName);

        // read the offset currently configured for this server
        final ModelNode readOffsetOp = createEmptyOperation("read-attribute", serverConfigAddress);
        readOffsetOp.get("name").set("socket-binding-port-offset");
        final int currentOffset = server.execute(readOffsetOp, "Set default port offset").asInt();

        // write back the shifted offset
        final ModelNode writeOffsetOp = createEmptyOperation("write-attribute", serverConfigAddress);
        writeOffsetOp.get("name").set("socket-binding-port-offset");
        writeOffsetOp.get("value").set(currentOffset + offset);
        server.execute(writeOffsetOp, "Set default port offset");
    }

    /**
     * @return {@link #TASK_NAME_KEY}
     */
    @Override
    public String getName() {
        return TASK_NAME_KEY;
    }

    /**
     * @return the fixed identifier {@code change-port-configurations}
     */
    @Override
    public String getSerializationName() {
        return "change-port-configurations";
    }

    /**
     * @return {@link Config}, the configuration type this task reads from the installation data
     */
    @Override
    public Class<? extends InstallationData.PostInstallConfig> getConfigClass() {
        return Config.class;
    }

    /**
     * Port configuration consumed by {@link PortConfigurationTask}: a port offset, socket
     * bindings grouped by configuration key, and an optional management interface binding.
     * Instances can be serialized to and deserialized from XML events.
     */
    public static class Config implements InstallationData.PostInstallConfig {
        // XML attribute holding the offset on the port-config element.
        public static final String OFFSET_TAG = "offset";
        // Root element of the serialized configuration.
        public static final String PORT_CONFIG_TAG = "port-config";
        // Element wrapping the management interface binding.
        public static final String MANAGEMENT_INTERFACE_TAG = "management-interface";
        // Element wrapping the bindings of one configuration key.
        public static final String SOCKET_BINDINGS_TAG = "socket-bindings";
        // Element describing a single binding.
        public static final String SOCKET_BINDING_TAG = "socket-binding";
        // Attributes of the socket-binding element.
        public static final String KEY_ATTRIBUTE = "key";
        public static final String PROPERTY_ATTRIBUTE = "property";
        public static final String MCAST_ADDRESS_ATTRIBUTE = "mcastAddress";
        public static final String PORT_VALUE_ATTRIBUTE = "portValue";
        public static final String MCAST_PORT_ATTRIBUTE = "mcastPort";
        // Attribute of the socket-bindings element naming the configuration key.
        public static final String CONFIG_ATTRIBUTE = "config";
        // Port offset; the task treats 0 as "no offset, apply individual ports".
        private int offset;
        // Socket bindings keyed by configuration name.
        private Map<String, List<SocketBinding>> socketBindings = new HashMap<>();
        // Management interface binding, or null when none is configured.
        private SocketBinding managementInterface = null;

        /**
         * Creates an empty configuration (offset 0, no bindings); state is expected to be
         * filled in afterwards, e.g. by {@link #deserialize}.
         */
        public Config() {
            // no-op for deserializer
        }

        /**
         * Creates a configuration with the given port offset and no bindings.
         *
         * @param offset port offset to store
         */
        public Config(int offset) {
            this.offset = offset;
        }

        /**
         * @return the stored port offset
         */
        public int getOffset() {
            return offset;
        }

        /**
         * @param config configuration key
         * @return the bindings stored under {@code config}, or {@code null} if none are stored
         */
        public List<SocketBinding> getPorts(String config) {
            return socketBindings.get(config);
        }

        /**
         * Stores the bindings for a configuration key, replacing any previous entry.
         *
         * @param config   configuration key
         * @param bindings bindings to store; the list is stored as-is, not copied
         */
        public void setPorts(String config, List<SocketBinding> bindings) {
            socketBindings.put(config, bindings);
        }

        /**
         * Removes the bindings stored under the given configuration key, if any.
         *
         * @param config configuration key
         */
        public void removePorts(String config) {
            socketBindings.remove(config);
        }

        /**
         * Sets the management interface binding.
         *
         * @param socketBinding binding to store
         */
        public void setManagementInterfacePort(SocketBinding socketBinding) {
            this.managementInterface = socketBinding;
        }

        /**
         * Clears the management interface binding.
         */
        public void removeManagementInterfacePort() {
            this.managementInterface = null;
        }

        /**
         * @return the management interface binding, or {@code null} when none is set
         */
        public SocketBinding getManagementInterface() {
            return managementInterface;
        }

        /**
         * Serializes this configuration as a sequence of XML events.
         *
         * <p>Layout: a {@code port-config} element carrying the offset attribute, an optional
         * {@code management-interface} element with one binding, and one {@code socket-bindings}
         * element per stored configuration key.
         *
         * @param eventFactory factory used to create the events
         * @param variables    not used by this implementation
         * @return a reader over the generated events
         */
        @Override
        public XMLEventReader serialize(XMLEventFactory eventFactory, Set<String> variables) {
            final String prefix = InstallationDataSerializer.PREFIX;
            final String namespace = InstallationDataSerializer.NS;
            final ArrayList<XMLEvent> output = new ArrayList<>();

            output.add(eventFactory.createStartElement(prefix, namespace, PORT_CONFIG_TAG));
            output.add(eventFactory.createAttribute(OFFSET_TAG, ""+ offset));

            if (managementInterface != null) {
                output.add(eventFactory.createStartElement(prefix, namespace, MANAGEMENT_INTERFACE_TAG));
                output.addAll(serializeSocketBinding(eventFactory, managementInterface));
                output.add(eventFactory.createEndElement(prefix, namespace, MANAGEMENT_INTERFACE_TAG));
            }

            for (Map.Entry<String, List<SocketBinding>> group : socketBindings.entrySet()) {
                output.add(eventFactory.createStartElement(prefix, namespace, SOCKET_BINDINGS_TAG));
                output.add(eventFactory.createAttribute(CONFIG_ATTRIBUTE, group.getKey()));
                for (SocketBinding binding : group.getValue()) {
                    output.addAll(serializeSocketBinding(eventFactory, binding));
                }
                output.add(eventFactory.createEndElement(prefix, namespace, SOCKET_BINDINGS_TAG));
            }

            output.add(eventFactory.createEndElement(prefix, namespace, PORT_CONFIG_TAG));

            return new ListXMLEventReader(output);
        }

        /**
         * Serializes one binding as a {@code socket-binding} element.
         *
         * <p>Key, property, multicast address and port value are written only when non-null; the
         * multicast port is always written.
         *
         * @param eventFactory    factory used to create the events
         * @param socketBindining binding to serialize
         * @return the start element, its attributes and the end element, in order
         */
        private List<XMLEvent> serializeSocketBinding(XMLEventFactory eventFactory, SocketBinding socketBindining) {
            final String key = socketBindining.getKey();
            final String property = socketBindining.getProperty();
            final String mcastAddress = socketBindining.getMcastAddress();
            final Integer portValue = socketBindining.getPortValue();

            final List<XMLEvent> output = new ArrayList<>();
            output.add(eventFactory.createStartElement(InstallationDataSerializer.PREFIX, InstallationDataSerializer.NS, SOCKET_BINDING_TAG));
            if (key != null) {
                output.add(eventFactory.createAttribute(KEY_ATTRIBUTE, key));
            }
            if (property != null) {
                output.add(eventFactory.createAttribute(PROPERTY_ATTRIBUTE, property));
            }
            if (mcastAddress != null) {
                output.add(eventFactory.createAttribute(MCAST_ADDRESS_ATTRIBUTE, mcastAddress));
            }
            if (portValue != null) {
                output.add(eventFactory.createAttribute(PORT_VALUE_ATTRIBUTE, portValue.toString()));
            }
            output.add(eventFactory.createAttribute(MCAST_PORT_ATTRIBUTE, Integer.toString(socketBindining.getMcastPort())));
            output.add(eventFactory.createEndElement(InstallationDataSerializer.PREFIX, InstallationDataSerializer.NS, SOCKET_BINDING_TAG));
            return output;
        }

        /**
         * Restores this configuration from XML events.
         *
         * <p>Recognized start elements: {@code port-config} (reads the offset attribute),
         * {@code management-interface} (reads its nested binding) and {@code socket-bindings}
         * (reads all nested bindings and stores them under the {@code config} attribute value).
         * Other top-level start elements are ignored; an element other than
         * {@code socket-binding} nested in the two container elements is rejected.
         *
         * @param reader           source of XML events
         * @param variableResolver not used by this implementation
         * @throws AutomaticInstallationParsingException if the stream cannot be read or contains
         *                                               an unexpected nested element
         */
        @Override
        public void deserialize(XMLEventReader reader, BiFunction<String, String, String> variableResolver) throws AutomaticInstallationParsingException {
            try {
                while (reader.hasNext()) {
                    final XMLEvent event = reader.nextEvent();
                    if (!event.isStartElement()) {
                        continue;
                    }

                    final StartElement element = event.asStartElement();
                    switch (element.getName().getLocalPart()) {
                        case PORT_CONFIG_TAG:
                            this.offset = Integer.parseInt(element.getAttributeByName(new QName(OFFSET_TAG)).getValue());
                            break;
                        case MANAGEMENT_INTERFACE_TAG:
                            // consume events up to the closing management-interface tag
                            while (reader.hasNext()) {
                                final XMLEvent nested = reader.nextEvent();
                                if (nested.isEndElement() && nested.asEndElement().getName().getLocalPart().equals(MANAGEMENT_INTERFACE_TAG)) {
                                    break;
                                }
                                if (!nested.isStartElement()) {
                                    continue;
                                }
                                final StartElement nestedElement = nested.asStartElement();
                                if (!nestedElement.getName().getLocalPart().equals(SOCKET_BINDING_TAG)) {
                                    throw InstallationDataSerializer.unexpectedElement(nested.asStartElement());
                                }
                                this.managementInterface = deserializeSocketBinding(nestedElement);
                            }
                            break;
                        case SOCKET_BINDINGS_TAG: {
                            final String configKey = element.getAttributeByName(new QName(CONFIG_ATTRIBUTE)).getValue();
                            final List<SocketBinding> parsed = new ArrayList<>();
                            // consume events up to the closing socket-bindings tag
                            while (reader.hasNext()) {
                                final XMLEvent nested = reader.nextEvent();
                                if (nested.isEndElement() && nested.asEndElement().getName().getLocalPart().equals(SOCKET_BINDINGS_TAG)) {
                                    break;
                                }
                                if (!nested.isStartElement()) {
                                    continue;
                                }
                                if (!nested.asStartElement().getName().getLocalPart().equals(SOCKET_BINDING_TAG)) {
                                    throw InstallationDataSerializer.unexpectedElement(nested.asStartElement());
                                }
                                parsed.add(deserializeSocketBinding(nested.asStartElement()));
                            }
                            socketBindings.put(configKey, parsed);
                            break;
                        }
                        default:
                            // unrelated elements are skipped
                            break;
                    }
                }
            } catch (XMLStreamException e) {
                throw InstallationDataSerializer.unableToParse(e);
            }
        }

        /**
         * Builds a {@link SocketBinding} from the attributes of a {@code socket-binding} element.
         *
         * <p>All attributes are optional. A missing {@code portValue} yields a {@code null} port;
         * a missing {@code mcastPort} yields 0, the same default the three-argument
         * {@link SocketBinding} constructor uses. Previously a missing {@code mcastPort} failed
         * with a {@link NumberFormatException}.
         *
         * @param bindingElement start element of the binding
         * @return the parsed binding
         * @throws NumberFormatException if a present port attribute is not a valid integer
         */
        private SocketBinding deserializeSocketBinding(StartElement bindingElement) {
            final String key = readBindingAttribute(bindingElement, KEY_ATTRIBUTE);
            final String property = readBindingAttribute(bindingElement, PROPERTY_ATTRIBUTE);
            final String mcastAddress = readBindingAttribute(bindingElement, MCAST_ADDRESS_ATTRIBUTE);
            final String portValueText = readBindingAttribute(bindingElement, PORT_VALUE_ATTRIBUTE);
            final String mcastPortText = readBindingAttribute(bindingElement, MCAST_PORT_ATTRIBUTE);

            final Integer portValue = portValueText != null ? Integer.valueOf(portValueText) : null;
            // a hand-written non-multicast binding naturally omits mcastPort
            final int mcastPort = mcastPortText != null ? Integer.parseInt(mcastPortText) : 0;

            return new SocketBinding(key, property, mcastAddress, mcastPort, portValue);
        }

        /**
         * Reads an unqualified attribute from an element.
         *
         * @param bindingElement element to read from
         * @param key            local name of the attribute
         * @return the attribute value, or {@code null} when the attribute is absent
         */
        private String readBindingAttribute(StartElement bindingElement, String key) {
            final Attribute attribute = bindingElement.getAttributeByName(new QName(key));
            if (attribute == null) {
                return null;
            }
            return attribute.getValue();
        }
    }


    /**
     * Immutable description of one socket binding: its key, the property name used when the
     * value is written as an expression, an optional port value, and optional multicast
     * address/port.
     */
    public static class SocketBinding {
        private final String key;
        private final String property;
        private final String mcastAddress;
        private final Integer portValue;
        private final int mcastPort;

        /**
         * Creates a plain (non-multicast) binding: no multicast address, multicast port 0.
         *
         * @param key          binding key
         * @param portProperty property name associated with the port
         * @param portValue    port value
         */
        public SocketBinding(String key, String portProperty, int portValue) {
            this(key, portProperty, null, 0, portValue);
        }

        /**
         * Creates a binding with all fields given explicitly.
         *
         * @param key          binding key
         * @param property     property name associated with the binding
         * @param mcastAddress multicast address, or {@code null} for a plain binding
         * @param mcastPort    multicast port
         * @param port         port value, may be {@code null}
         */
        public SocketBinding(String key, String property, String mcastAddress, int mcastPort, Integer port) {
            this.key = key;
            this.property = property;
            this.mcastAddress = mcastAddress;
            this.mcastPort = mcastPort;
            this.portValue = port;
        }

        /** @return the binding key */
        public String getKey() {
            return this.key;
        }

        /** @return the property name associated with the binding */
        public String getProperty() {
            return this.property;
        }

        /** @return the multicast address, or {@code null} when none is set */
        public String getMcastAddress() {
            return this.mcastAddress;
        }

        /** @return the port value, or {@code null} when none is set */
        public Integer getPortValue() {
            return this.portValue;
        }

        /** @return the multicast port */
        public int getMcastPort() {
            return this.mcastPort;
        }

        /** Two bindings are equal when they are of the same class and all five fields match. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o == null || o.getClass() != getClass()) {
                return false;
            }
            final SocketBinding other = (SocketBinding) o;
            if (mcastPort != other.mcastPort) {
                return false;
            }
            return Objects.equals(key, other.key)
                    && Objects.equals(property, other.property)
                    && Objects.equals(mcastAddress, other.mcastAddress)
                    && Objects.equals(portValue, other.portValue);
        }

        @Override
        public int hashCode() {
            return Objects.hash(key, property, mcastAddress, portValue, mcastPort);
        }
    }
}
