/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.shared.testing;

import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;

import javax.annotation.Nonnull;
import javax.net.ServerSocketFactory;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import com.unboundid.asn1.ASN1OctetString;
import com.unboundid.ldap.listener.InMemoryDirectoryServer;
import com.unboundid.ldap.listener.InMemoryDirectoryServerConfig;
import com.unboundid.ldap.listener.InMemoryListenerConfig;
import com.unboundid.ldap.listener.InMemoryRequestHandler;
import com.unboundid.ldap.listener.InMemorySASLBindHandler;
import com.unboundid.ldap.sdk.BindResult;
import com.unboundid.ldap.sdk.Control;
import com.unboundid.ldap.sdk.DN;
import com.unboundid.ldap.sdk.LDAPException;
import com.unboundid.ldap.sdk.LDAPResult;
import com.unboundid.ldap.sdk.OperationType;
import com.unboundid.ldap.sdk.ResultCode;
import com.unboundid.ldif.LDIFException;
import com.unboundid.ldif.LDIFReader;
import com.unboundid.util.ssl.SSLUtil;

import net.shibboleth.shared.annotation.ParameterName;
import net.shibboleth.shared.annotation.constraint.Positive;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.primitive.LoggerFactory;

import org.slf4j.Logger;
import org.springframework.core.io.Resource;

/**
 * Manages an instance of the in-memory directory server for unit testing.
 *
 * <p>The server listens on {@code localhost}. It accepts simple binds as {@code cn=Directory Manager} with the
 * password {@code password}. It accepts every DIGEST-MD5 and EXTERNAL SASL bind without checking credentials.
 * Sockets accepted by the listener are tracked so tests can check how many connections remain open via
 * {@link #openConnectionCount()}.</p>
 */
public class InMemoryDirectory {

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(InMemoryDirectory.class);

    /** Directory server. */
    @Nonnull private final InMemoryDirectoryServer directoryServer;

    /** Server socket factory to track created sockets. */
    @Nonnull private final CustomServerSocketFactory customServerSocketFactory;

    /**
     * Constructor without STARTTLS support.
     *
     * @param baseDNs to use in the directory server
     * @param ldif the LDIF resource to be imported
     * @param port port to listen on
     *
     * @throws RuntimeException if the in-memory directory cannot be created
     */
    public InMemoryDirectory(@ParameterName(name="baseDNs") @Nonnull final String[] baseDNs,
                             @ParameterName(name="ldif") @Nonnull final Resource ldif,
                             @ParameterName(name="port") @Positive final int port) {
        this(baseDNs, ldif, port, false);
    }

    /**
     * Constructor without STARTTLS support.
     *
     * @param baseDNs to use in the directory server
     * @param ldif the LDIF resource to be imported
     * @param port port to listen on
     * @param requireAuthForSearch whether to require authentication in order to perform searches
     *
     * @throws RuntimeException if the in-memory directory cannot be created
     */
    public InMemoryDirectory(@ParameterName(name="baseDNs") @Nonnull final String[] baseDNs,
                             @ParameterName(name="ldif") @Nonnull final Resource ldif,
                             @ParameterName(name="port") @Positive final int port,
                             @ParameterName(name="requireAuthForSearch") final boolean requireAuthForSearch) {
        Constraint.isNotNull(ldif, "LDIF resource cannot be null");
        customServerSocketFactory = new CustomServerSocketFactory();
        try {
            // No STARTTLS socket factory: the listener serves plain LDAP only.
            directoryServer = createDirectoryServer(baseDNs, ldif, port, requireAuthForSearch,
                customServerSocketFactory, null);
        } catch (final Exception e) {
            throw new RuntimeException("Error creating directory server", e);
        }
    }

    /**
     * Constructor with STARTTLS support.
     *
     * @param baseDNs to use in the directory server
     * @param ldif the LDIF resource to be imported
     * @param port port to listen on
     * @param keystore to use for startTLS
     * @param truststore to use for startTLS
     *
     * @throws RuntimeException if the in-memory directory cannot be created
     */
    public InMemoryDirectory(@ParameterName(name="baseDNs") @Nonnull final String[] baseDNs,
                             @ParameterName(name="ldif") @Nonnull final Resource ldif,
                             @ParameterName(name="port") @Positive final int port,
                             @ParameterName(name="keystore") @Nonnull final Resource keystore,
                             @ParameterName(name="truststore") @Nonnull final Optional<Resource> truststore) {
        this(baseDNs, ldif, port, false, keystore, truststore);
    }

    /**
     * Constructor with STARTTLS support.
     *
     * <p>Both the keystore and the optional truststore are read as JKS files protected by the password
     * {@code changeit}.</p>
     *
     * @param baseDNs to use in the directory server
     * @param ldif the LDIF resource to be imported
     * @param port port to listen on
     * @param requireAuthForSearch whether to require authentication in order to perform searches
     * @param keystore to use for startTLS
     * @param truststore to use for startTLS
     *
     * @throws RuntimeException if the in-memory directory cannot be created
     */
    public InMemoryDirectory(@ParameterName(name="baseDNs") @Nonnull final String[] baseDNs,
                             @ParameterName(name="ldif") @Nonnull final Resource ldif,
                             @ParameterName(name="port") @Positive final int port,
                             @ParameterName(name="requireAuthForSearch") final boolean requireAuthForSearch,
                             @ParameterName(name="keystore") @Nonnull final Resource keystore,
                             @ParameterName(name="truststore") @Nonnull final Optional<Resource> truststore) {
        Constraint.isNotNull(ldif, "LDIF resource cannot be null");
        customServerSocketFactory = new CustomServerSocketFactory();
        try {
            final KeyManager[] keyManagers = getKeyManagerFactory(keystore).getKeyManagers();
            // A null trust manager array lets SSLUtil fall back to its own default trust handling.
            final TrustManager[] trustManagers = truststore.isPresent() ?
                getTrustManagerFactory(truststore.get()).getTrustManagers() : null;
            final SSLSocketFactory startTLSSocketFactory =
                new SSLUtil(keyManagers, trustManagers).createSSLSocketFactory();
            directoryServer = createDirectoryServer(baseDNs, ldif, port, requireAuthForSearch,
                customServerSocketFactory, startTLSSocketFactory);
        } catch (final Exception e) {
            throw new RuntimeException("Error creating directory server", e);
        }
    }

    /**
     * Builds the server configuration shared by all constructors and creates the server. It then populates the
     * server from the supplied LDIF.
     *
     * @param baseDNs to use in the directory server
     * @param ldif the LDIF resource to be imported
     * @param port port to listen on
     * @param requireAuthForSearch whether to require authentication in order to perform searches
     * @param serverSocketFactory factory used by the listener to create its server socket
     * @param startTLSSocketFactory factory used for STARTTLS, or {@code null} to disable STARTTLS
     * @return the created (not yet listening) directory server
     *
     * @throws LDAPException if the configuration is invalid or the LDIF cannot be imported
     * @throws IOException if {@code localhost} cannot be resolved or the LDIF resource cannot be read
     */
    @Nonnull private static InMemoryDirectoryServer createDirectoryServer(@Nonnull final String[] baseDNs,
            @Nonnull final Resource ldif, final int port, final boolean requireAuthForSearch,
            @Nonnull final ServerSocketFactory serverSocketFactory, final SSLSocketFactory startTLSSocketFactory)
            throws LDAPException, IOException {
        final InMemoryDirectoryServerConfig config = new InMemoryDirectoryServerConfig(baseDNs);
        final InMemoryListenerConfig listenerConfig =
            new InMemoryListenerConfig(
                "default",
                InetAddress.getByName("localhost"),
                port,
                serverSocketFactory,
                null,
                startTLSSocketFactory);
        config.setListenerConfigs(listenerConfig);
        config.addAdditionalBindCredentials("cn=Directory Manager", "password");
        if (requireAuthForSearch) {
            config.setAuthenticationRequiredOperationTypes(OperationType.SEARCH);
        }
        addSuccessSaslBindHandlers(config);
        final InMemoryDirectoryServer server = new InMemoryDirectoryServer(config);
        // Scope the reader so the underlying resource stream is always released.
        try (LDIFReader reader = new LDIFReader(ldif.getInputStream())) {
            server.importFromLDIF(true, reader);
        }
        return server;
    }

    /**
     * Adds DIGEST-MD5 and EXTERNAL SASL bind handlers that always return success.
     *
     * @param config to add SASL bind handlers to
     */
    private static void addSuccessSaslBindHandlers(@Nonnull final InMemoryDirectoryServerConfig config) {
        config.addSASLBindHandler(alwaysSuccessfulSaslBindHandler("DIGEST-MD5"));
        config.addSASLBindHandler(alwaysSuccessfulSaslBindHandler("EXTERNAL"));
    }

    /**
     * Creates a SASL bind handler for the given mechanism. The handler reports success for every bind request and
     * ignores the bind DN and credentials.
     *
     * @param mechanism SASL mechanism name the handler is registered for
     * @return the bind handler
     */
    @Nonnull private static InMemorySASLBindHandler alwaysSuccessfulSaslBindHandler(@Nonnull final String mechanism) {
        return new InMemorySASLBindHandler() {
            @Override
            public String getSASLMechanismName() {
                return mechanism;
            }

            @Override
            public BindResult processSASLBind(final InMemoryRequestHandler handler, final int messageID,
                                              final DN bindDN, final ASN1OctetString credentials,
                                              final List<Control> controls) {
                // return success for all bind requests using this mechanism
                return new BindResult(new LDAPResult(messageID, ResultCode.SUCCESS));
            }
        };
    }

    /**
     * Adds the supplied LDIF lines to the directory server.
     *
     * @param ldifLines to add
     * @throws RuntimeException if an error occurs adding the LDIF
     */
    public void add(final String... ldifLines) {
        try {
            directoryServer.add(ldifLines);
        } catch (final LDIFException | LDAPException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the number of accepted sockets that have not yet been closed.
     *
     * @return number of open sockets
     */
    public long openConnectionCount() {
        return customServerSocketFactory.sockets.stream().filter(s -> !s.isClosed()).count();
    }

    /**
     * Starts the directory server.
     *
     * @throws RuntimeException if the in-memory directory server cannot be started
     */
    public void start() {
        try {
            directoryServer.startListening();
        } catch (final LDAPException e) {
            throw new RuntimeException(e);
        }
        log.info("In-memory directory server started");
    }

    /**
     * Returns the port the server is listening on.
     *
     * @return  port number
     */
    public int getListenPort() {
        return directoryServer.getListenPort();
    }

    /**
     * Stops the directory server. Note that in general resources should be configured so that LDAP connections are
     * closed at the conclusion of a test method or test class.
     *
     * @param closeConnections whether to close existing connections
     */
    public void stop(final boolean closeConnections) {
        directoryServer.shutDown(closeConnections);
        log.info("In-memory directory server stopped");
    }

    /**
     * Creates a KeyManagerFactory from the supplied resource. A keystore password of "changeit" is assumed.
     *
     * @param keystore resource to read
     * @return key manager factory built from the keystore
     *
     * @throws GeneralSecurityException if the keystore password is incorrect
     * @throws IOException if the resource cannot be read
     */
    private static KeyManagerFactory getKeyManagerFactory(final Resource keystore)
            throws GeneralSecurityException, IOException {
        final KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(
            KeyManagerFactory.getDefaultAlgorithm());
        keyManagerFactory.init(loadKeyStore(keystore, "changeit"), "changeit".toCharArray());
        return keyManagerFactory;
    }

    /**
     * Creates a TrustManagerFactory from the supplied resource. A keystore password of "changeit" is assumed.
     *
     * @param keystore resource to read
     * @return trust manager factory built from the keystore
     *
     * @throws GeneralSecurityException if the keystore password is incorrect
     * @throws IOException if the resource cannot be read
     */
    private static TrustManagerFactory getTrustManagerFactory(final Resource keystore)
            throws GeneralSecurityException, IOException {
        final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(
            TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init(loadKeyStore(keystore, "changeit"));
        return trustManagerFactory;
    }

    /**
     * Creates a new JKS KeyStore from the supplied resource. The resource stream is closed before returning.
     *
     * @param keystore resource to read
     * @param password to unlock the keystore
     * @return keystore
     *
     * @throws GeneralSecurityException if the keystore cannot be created from the resource
     * @throws IOException if the resource cannot be read
     */
    private static KeyStore loadKeyStore(final Resource keystore, final String password)
            throws GeneralSecurityException, IOException {
        final KeyStore ks = KeyStore.getInstance("JKS");
        try (InputStream in = keystore.getInputStream()) {
            ks.load(in, password.toCharArray());
        }
        return ks;
    }

    /** ServerSocketFactory wrapper class to track created sockets. */
    private static class CustomServerSocketFactory extends ServerSocketFactory {

        /**
         * The sockets accepted so far. This list is written by the server's listener thread in
         * {@link CustomServerSocket#accept()} and read by test threads in {@link #openConnectionCount()}, so it
         * must be thread-safe.
         */
        @Nonnull private final List<Socket> sockets = new CopyOnWriteArrayList<>();

        /** {@inheritDoc} */
        @Override
        public ServerSocket createServerSocket(final int port) throws IOException {
            return new CustomServerSocket(port, 50, null);
        }

        /** {@inheritDoc} */
        @Override
        public ServerSocket createServerSocket(final int port, final int backlog) throws IOException {
            return new CustomServerSocket(port, backlog, null);
        }

        /** {@inheritDoc} */
        @Override
        public ServerSocket createServerSocket(final int port, final int backlog, final InetAddress ifAddress)
            throws IOException {
            return new CustomServerSocket(port, backlog, ifAddress);
        }

        /** ServerSocket wrapper class to track created sockets. */
        private class CustomServerSocket extends ServerSocket {

            /**
             * Constructor.
             *
             * @param port socket port
             * @param backlog listen backlog size
             * @param bindAddr socket address
             * 
             * @throws IOException on error
             */
            public CustomServerSocket(final int port, final int backlog, final InetAddress bindAddr)
                throws IOException {
                super(port, backlog, bindAddr);
            }

            /** {@inheritDoc} */
            @Override
            public Socket accept() throws IOException {
                final Socket socket = super.accept();
                sockets.add(socket);
                return socket;
            }
        }
    }

}