/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.shared.spring.security.factory;

import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.KeyStore.PrivateKeyEntry;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.Security;
import java.util.Objects;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.springframework.beans.factory.FactoryBean;

import net.shibboleth.shared.annotation.constraint.NotEmpty;
import net.shibboleth.shared.primitive.LoggerFactory;

/**
 * Spring bean factory for extracting a {@link PrivateKey} from a PKCS#11 keystore.
 * 
 * <p>This relies on the SunPKCS11 provider. The configured PKCS#11 {@link Provider} is held in a static field: it
 * is built once, from the configuration of the first factory instance that needs it, registered with
 * {@link Security}, and then shared by every instance of this class.</p>
 * 
 * <p>The key password is used both to load the keystore and as the protection parameter for the key entry.</p>
 */
public class PKCS11PrivateKeyFactoryBean implements FactoryBean<PrivateKey> {

    /** The name for the base PKCS#11 provider. */
    @Nonnull @NotEmpty private static final String UNCONFIGURED_PROVIDER_NAME = "SunPKCS11";

    /** Lock guarding lazy construction of the shared {@link #provider} and {@link #providerConfig}. */
    @Nonnull private static final Object PROVIDER_LOCK = new Object();

    /** Singleton {@link Provider} for all instances of this factory; guarded by {@link #PROVIDER_LOCK}. */
    @Nullable private static Provider provider;

    /** The PKCS#11 configuration {@link #provider} was built from; guarded by {@link #PROVIDER_LOCK}. */
    @Nullable private static String providerConfig;

    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(PKCS11PrivateKeyFactoryBean.class);

    /** PKCS#11 provider parameter string. */
    @Nullable private String pkcs11Config;

    /** Alias for the private key. */
    @Nullable private String keyAlias;

    /** Password for the private key. */
    @Nullable private String keyPassword;

    /** The singleton instance of the private key produced by this factory. */
    @Nullable private PrivateKey key;

    /**
     * Returns the PKCS#11 configuration.
     * 
     * @return returns the PKCS#11 configuration.
     */
    @Nullable public String getPkcs11Config() {
        return pkcs11Config;
    }

    /**
     * Sets the PKCS#11 configuration to use.
     * 
     * <p>Only the configuration of the first instance to build the shared provider takes effect.</p>
     * 
     * @param config the PKCS#11 configuration to use
     */
    public void setPkcs11Config(@Nullable final String config) {
        pkcs11Config = config;
    }

    /**
     * Gets the key alias in use.
     * 
     * @return returns the key alias in use
     */
    @Nullable public String getKeyAlias() {
        return keyAlias;
    }

    /**
     * Sets the key alias to use.
     * 
     * @param alias the key alias to use
     */
    public void setKeyAlias(@Nullable final String alias) {
        keyAlias = alias;
    }

    /**
     * Gets the key password in use.
     * 
     * @return returns the key password in use
     */
    @Nullable public String getKeyPassword() {
        return keyPassword;
    }

    /**
     * Set the key password to use.
     * 
     * @param password the key password to use
     */
    public void setKeyPassword(@Nullable final String password) {
        keyPassword = password;
    }

    /**
     * Gets the singleton PKCS#11 {@link Provider}, building it on first use.
     * 
     * <p>The constructed {@link Provider} is also added to the system's list of providers. Construction is
     * serialized across all instances because the provider lives in a static field. If the provider already exists
     * but was built from a different configuration than this instance's, a warning is logged and the existing
     * provider is returned unchanged.</p>
     * 
     * @return the singleton {@link Provider}
     * @throws NoSuchProviderException if the base SunPKCS11 provider is not installed
     * @throws GeneralSecurityException if the provider must be built and no PKCS#11 configuration is set
     * @throws Exception if something else goes wrong building the {@link Provider}
     */
    @Nonnull private Provider getProvider() throws Exception {
        synchronized (PROVIDER_LOCK) {
            if (provider == null) {
                final var baseProvider = Security.getProvider(UNCONFIGURED_PROVIDER_NAME);
                if (baseProvider == null) {
                    throw new NoSuchProviderException("could not acquire PKCS#11 bridge: "
                            + UNCONFIGURED_PROVIDER_NAME);
                }
                if (pkcs11Config == null) {
                    // Provider.configure(null) fails with an uninformative NullPointerException.
                    throw new GeneralSecurityException("PKCS#11 configuration was null");
                }
                final Provider configured = baseProvider.configure(pkcs11Config);
                Security.addProvider(configured);
                // Publish only after registration succeeded so a failure is retried on the next call.
                provider = configured;
                providerConfig = pkcs11Config;
            } else if (!Objects.equals(providerConfig, pkcs11Config)) {
                log.warn("PKCS#11 provider already built from configuration '{}'; ignoring configuration '{}'",
                        providerConfig, pkcs11Config);
            }
            assert provider != null;
            return provider;
        }
    }

    /**
     * Gets a PKCS#11 {@link KeyStore} from the {@link Provider}.
     * 
     * <p>The keystore is loaded with no input stream and with the key password (if set).</p>
     * 
     * @return the {@link KeyStore}
     * @throws Exception if something goes wrong building the keystore
     */
    @Nonnull private KeyStore getKeyStore() throws Exception {
        final KeyStore keystore = KeyStore.getInstance("PKCS11", getProvider());

        log.debug("Initializing PKCS11 keystore");
        keystore.load(null, keyPassword != null ? keyPassword.toCharArray() : null);
        return keystore;
    }

    /**
     * Returns the private key, loading it from the PKCS#11 keystore on the first call and caching it thereafter.
     * 
     * @return the private key stored under the configured alias
     * @throws GeneralSecurityException if the key password or alias is not set, the entry does not exist, or the
     *      entry is not a private key entry
     * @throws Exception if something else goes wrong acquiring the provider or keystore
     */
    @Override
    @Nonnull public PrivateKey getObject() throws Exception {
        if (key == null) {
            final String password = keyPassword;
            if (password == null) {
                throw new GeneralSecurityException("Key password was null");
            }
            final String alias = keyAlias;
            if (alias == null) {
                // KeyStore.getEntry rejects a null alias with a bare NullPointerException; fail fast and clearly
                // before touching the token.
                throw new GeneralSecurityException("Key alias was null");
            }

            final KeyStore keystore = getKeyStore();
            final KeyStore.Entry keyEntry =
                    keystore.getEntry(alias, new KeyStore.PasswordProtection(password.toCharArray()));
            if (keyEntry == null) {
                throw new GeneralSecurityException("entry " + alias + " not found");
            }

            if (keyEntry instanceof PrivateKeyEntry) {
                key = ((PrivateKeyEntry) keyEntry).getPrivateKey();
            } else {
                throw new GeneralSecurityException("entry " + alias + " is not a private key entry");
            }
        }

        assert key != null;
        return key;
    }

    @Override
    @Nonnull public Class<?> getObjectType() {
        return PrivateKey.class;
    }

    @Override
    public boolean isSingleton() {
        return true;
    }

}
