package org.jboss.eap.util.xp.patch.stream.manager;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

import javax.xml.namespace.QName;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLStreamConstants;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;

import org.jboss.staxmapper.XMLElementReader;
import org.jboss.staxmapper.XMLExtendedStreamReader;
import org.jboss.staxmapper.XMLMapper;

/**
 * @author <a href="mailto:kabir.khan@jboss.com">Kabir Khan</a>
 */
public class PatchXml implements XMLStreamConstants, XMLElementReader<PatchXml.Validator> {
    private static final XMLMapper MAPPER = XMLMapper.Factory.create();
    private static final PatchXml XML = new PatchXml();
    private static final XMLInputFactory INPUT_FACTORY = XMLInputFactory.newInstance();

    static {
        MAPPER.registerRootElement(new QName(Namespace.PATCH_1_0.getNamespace(), "patch"), XML);
        MAPPER.registerRootElement(new QName(Namespace.PATCH_1_1.getNamespace(), "patch"), XML);
        MAPPER.registerRootElement(new QName(Namespace.PATCH_1_2.getNamespace(), "patch"), XML);
    }

    public enum Namespace {

        PATCH_1_0("urn:jboss:patch:1.0"),
        PATCH_1_1("urn:jboss:patch:1.1"),
        PATCH_1_2("urn:jboss:patch:1.2"),
        UNKNOWN(null),
        ;

        private final String namespace;
        Namespace(String namespace) {
            this.namespace = namespace;
        }

        public String getNamespace() {
            return namespace;
        }

        static Map<String, Namespace> elements = new HashMap<String, Namespace>();
        static {
            for(Namespace element : Namespace.values()) {
                if(element != UNKNOWN) {
                    elements.put(element.namespace, element);
                }
            }
        }

        static Namespace forUri(String name) {
            final Namespace element = elements.get(name);
            return element == null ? UNKNOWN : element;
        }
    }

    private PatchXml() {
        //
    }

    public static void validateBasePatch(final InputStream stream, ServerVersion minimumServerVersion) throws XMLStreamException {
        Validator validator = parse(stream);
        validator.validateBasePatch(minimumServerVersion);
    }

    public static void validateXpPatch(final InputStream stream, FileSet toFileSet) throws XMLStreamException {
        Validator validator = parse(stream);
        validator.validateXpPatch(toFileSet);
    }

    static Validator parse(final InputStream stream) throws XMLStreamException {
        XMLStreamReader reader = getXMLInputFactory().createXMLStreamReader(stream);
        try {
            final Validator validator = new Validator();
            MAPPER.parseDocument(validator, reader);
            return validator;
        } finally {
            reader.close();
        }
    }

    static InputStream findPatchXmlInZip(Path patchFile)  {
        byte[] buffer = new byte[1024];
        try (ZipInputStream zin = new ZipInputStream(new BufferedInputStream(new FileInputStream(patchFile.toFile())))) {
            ZipEntry entry = zin.getNextEntry();
            while (entry != null) {
                try {
                    if (!entry.isDirectory()) {
                        if (entry.getName().equals("patch.xml")) {
                            try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
                                int len;
                                while ((len = zin.read(buffer)) > 0) {
                                    out.write(buffer, 0, len);
                                }
                                return new ByteArrayInputStream(out.toByteArray());
                            }
                        }
                    }
                } finally {
                    zin.closeEntry();
                    entry = zin.getNextEntry();
                }
            }
        } catch (IOException e) {
            throw ManagerLogger.LOGGER.patchFileIsNotAnArchive(e, patchFile);
        }

        return null;
    }

    private static XMLInputFactory getXMLInputFactory() throws XMLStreamException {
        final XMLInputFactory inputFactory = INPUT_FACTORY;
        setIfSupported(inputFactory, XMLInputFactory.IS_VALIDATING, Boolean.FALSE);
        setIfSupported(inputFactory, XMLInputFactory.SUPPORT_DTD, Boolean.FALSE);
        return inputFactory;
    }

    private static void setIfSupported(final XMLInputFactory inputFactory, final String property, final Object value) {
        if (inputFactory.isPropertySupported(property)) {
            inputFactory.setProperty(property, value);
        }
    }



    @Override
    public void readElement(XMLExtendedStreamReader reader, Validator validator) throws XMLStreamException {
        final int count = reader.getAttributeCount();
        for (int i = 0; i < count; i++) {
            final String value = reader.getAttributeValue(i);
            switch (reader.getAttributeLocalName(i)) {
                case "id": {
                    validator.setPatchId(value);
                }
                break;
            }
        }

        while (reader.hasNext() && reader.nextTag() != END_ELEMENT) {
            final String localName = reader.getLocalName();
            switch (localName) {
                case "upgrade": {
                    parseTopLevelUpgradeElement(reader, validator);
                }
                break;
                case "element": {
                    parsePatchElement(reader, validator);
                }
                break;
                default:
                    consumeChildren(reader);
            }
        }
    }

    private void parseTopLevelUpgradeElement(XMLExtendedStreamReader reader, Validator validator) throws XMLStreamException {
        // Note: <upgrade> happens on two levels. This is the top one
        String patchStreamName = null;
        final int count = reader.getAttributeCount();
        for (int i = 0; i < count; i++) {
            final String value = reader.getAttributeValue(i);
            switch (reader.getAttributeLocalName(i)) {
                case "name": {
                    validator.setPatchStreamName(value);
                }
                break;
                case "to-version": {
                    ServerVersion serverVersion = ServerVersion.parse(value);
                    validator.setPatchVersion(serverVersion);
                }
                break;
            }
        }

        while (reader.hasNext() && reader.nextTag() != END_ELEMENT) {
            consumeChildren(reader);
        }
    }

    private void parsePatchElement(XMLExtendedStreamReader reader, Validator validator) throws XMLStreamException {
        String id = null;
        final int count = reader.getAttributeCount();
        for (int i = 0; i < count; i++) {
            final String value = reader.getAttributeValue(i);
            switch (reader.getAttributeLocalName(i)) {
                case "id": {
                    id = value;
                }
                break;
            }
        }

        String layerName = null;
        while (reader.hasNext() && reader.nextTag() != END_ELEMENT) {
            final String localName = reader.getLocalName();
            switch (localName) {
                case "upgrade": {
                    layerName = getLayerFromUpgradeElement(reader);
                }
                break;
                default:
                    consumeChildren(reader);
            }
        }

        validator.addLayerIdAndName(id, layerName);
    }

    private String getLayerFromUpgradeElement(XMLExtendedStreamReader reader) throws XMLStreamException {
        // Note: <upgrade> happens on two levels. This is the one nested under patch/element
        String layerName = null;
        final int count = reader.getAttributeCount();
        for (int i = 0; i < count; i++) {
            final String value = reader.getAttributeValue(i);
            switch (reader.getAttributeLocalName(i)) {
                case "name": {
                    layerName = value;
                }
                break;
            }
        }

        while (reader.hasNext() && reader.nextTag() != END_ELEMENT) {
            consumeChildren(reader);
        }

        return layerName;
    }

    public void consumeChildren(XMLExtendedStreamReader reader) throws XMLStreamException {

        while (reader.hasNext()) {
            int type = reader.next();
            switch (type) {
                case START_ELEMENT:
                    consumeChildren(reader);
                    break;
                case END_ELEMENT:
                    return;
                default:
                    // Do nothing, easy!
            }
        }
    }



    static class Validator {
        private static final Pattern BASE_PATCH_ID = Pattern.compile("jboss-eap-7\\.\\d+\\.\\d+\\.CP");
        private static final Pattern XP_PATCH_ID = Pattern.compile("jboss-eap-xp-\\d+\\.\\d+\\.\\d+\\.CP");
        private String patchId;
        private final List<String> layerIds = new ArrayList<>();
        private final List<String> layerNames = new ArrayList<>();
        private String patchStreamName;
        private ServerVersion patchVersion;

        public Validator() {
        }

        public void setPatchId(String patchName) {
            this.patchId = patchName;
        }

        public void addLayerIdAndName(String layerId, String layerName) {
            layerIds.add(layerId);
            layerNames.add(layerName);
        }

        public void setPatchStreamName(String patchStreamName) {
            this.patchStreamName = patchStreamName;
        }

        public void setPatchVersion(ServerVersion patchVersion) {
            this.patchVersion = patchVersion;
        }

        public void validateBasePatch(ServerVersion minimumServerVersion) {
            validatePatch(ManagerArgsParser.ARG_BASE_PATCH, BASE_PATCH_ID, "layer-base-", "base");

            if (minimumServerVersion.compareTo(patchVersion) > 0) {
                throw ManagerLogger.LOGGER.incompatibleBasePatchVersion(patchVersion, minimumServerVersion);
            }
        }

        public boolean isBasePatch() {
            return BASE_PATCH_ID.matcher(patchId).matches();
        }

        public boolean isXpPatch() {
            return XP_PATCH_ID.matcher(patchId).matches();
        }

        public String getPatchStreamName() {
            return patchStreamName;
        }

        public void validateXpPatch(FileSet toFileSet) {
            validatePatch(ManagerArgsParser.ARG_XP_PATCH, XP_PATCH_ID, "layer-", "microprofile");
            if (toFileSet != null) {
                // For the unit tests we don't pass this in
                Set<String> patchStreams = new HashSet<>();
                for (Path path : toFileSet.getPatchStreamFiles()) {
                    String confFileName = path.getFileName().toString();
                    if (confFileName.endsWith(".conf")) {
                        int index = confFileName.lastIndexOf(".conf");
                        confFileName = confFileName.substring(0, index);
                        patchStreams.add(confFileName);
                    }
                }
                if (!patchStreams.contains(patchStreamName)) {
                    throw ManagerLogger.LOGGER.invalidXpPatchForManagerVersion(patchStreamName, patchStreams);
                }
            }
        }

        private void validatePatch(String patchArg, Pattern patchIdPattern, String layerIdPrefix, String layer) {
            if (patchId == null || !patchIdPattern.matcher(patchId).matches()) {
                throw ManagerLogger.LOGGER.invalidPatchIdForPatch(patchArg, patchId);
            }

            if (layerIds.size() != 1 || layerNames.size() != 1) {
                // layerNames will be the same size as layerIds according to the patching xsd
                throw ManagerLogger.LOGGER.patchElementsNotEqualsOneInPatch(patchArg);
            }

            if (!layerIds.get(0).equals(layerIdPrefix + patchId)) {
                throw ManagerLogger.LOGGER.badPatchElementIdInPatch(patchArg, layerIds.get(0));
            }

            if (!layerNames.get(0).equals(layer)) {
                throw ManagerLogger.LOGGER.badUpgradeElementNameInPatch(patchArg, layerNames.get(0), layer);
            }
        }
    }
}
