package cert

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"os"
	"time"

	"github.com/spf13/cobra"
	pkgCommon "github.com/stackrox/rox/pkg/roxctl/common"
	"github.com/stackrox/rox/pkg/tlsutils"
	"github.com/stackrox/rox/pkg/utils"
	"github.com/stackrox/rox/roxctl/common/environment"
	"github.com/stackrox/rox/roxctl/common/flags"
	"github.com/stackrox/rox/roxctl/common/util"
)

// centralCertCommand carries everything the `cert` subcommand needs at run time.
type centralCertCommand struct {
	// filename is bound to the --output flag and names the destination for the
	// PEM-encoded certificate; the value "-" selects stdout.
	filename string

	// env is injected by Command and supplies the endpoint and server name to
	// connect to.
	env environment.Environment
	// timeout bounds the TLS dial; it is read from the timeout flag in construct.
	timeout time.Duration
}

// Command builds the "cert" cobra command. When run, it fetches the TLS
// certificate presented by the configured endpoint and writes it out in PEM
// form (to stdout by default, or to the file named by --output).
func Command(cliEnvironment environment.Environment) *cobra.Command {
	certCmd := &centralCertCommand{env: cliEnvironment}

	// run resolves flag-derived settings first, then performs the fetch.
	run := func(c *cobra.Command) error {
		if err := certCmd.construct(c); err != nil {
			return err
		}
		return certCmd.certs()
	}

	c := &cobra.Command{
		Use:  "cert",
		RunE: util.RunENoArgs(run),
	}
	c.Flags().StringVar(&certCmd.filename, "output", "-", "Filename to output PEM certificate to; '-' for stdout")
	flags.AddTimeout(c)
	return c
}

// construct fills in the settings that are only known once cobra has parsed
// the command line. Currently that is just the timeout used for the TLS dial.
// It never fails; the error return keeps the usual construct/run shape.
func (cmd *centralCertCommand) construct(cbr *cobra.Command) error {
	timeout := flags.Timeout(cbr)
	cmd.timeout = timeout
	return nil
}

// certs connects to the configured endpoint over TLS and exports the leaf
// (first) certificate the server presents. It prints a short summary of that
// certificate to stderr and writes the certificate itself in PEM form to
// cmd.filename, where "-" means stdout.
//
// Server certificate verification is skipped only when the user set the
// skip-TLS-validation flag. Otherwise, an untrusted endpoint fails the handshake.
func (cmd *centralCertCommand) certs() error {
	// Parse out the endpoint and server name for connecting to.
	endpoint, serverName, err := cmd.env.ConnectNames()
	if err != nil {
		return err
	}

	// Connect to the given server. We're not expecting the endpoint to be
	// trusted, but force the user to opt into insecure mode if needed.
	config := &tls.Config{
		InsecureSkipVerify: skipTLSValidation(),
		ServerName:         serverName,
	}
	ctx, cancel := context.WithTimeout(pkgCommon.Context(), cmd.timeout)
	defer cancel()
	conn, err := tlsutils.DialContext(ctx, "tcp", endpoint, config)
	if err != nil {
		return err
	}
	defer utils.IgnoreError(conn.Close)

	// Verify that at least 1 certificate was obtained from the connection.
	peerCerts := conn.ConnectionState().PeerCertificates
	if len(peerCerts) == 0 {
		return errors.New("server returned no certificates")
	}
	leaf := peerCerts[0]

	// Open the destination before printing anything, so an unusable output
	// path fails before any certificate information is emitted.
	out, closeOut, err := openCertOutput(cmd.filename)
	if err != nil {
		return err
	}

	// Print out information about the leaf cert to STDERR.
	writeCertInfo(os.Stderr, leaf)

	// Write out the leaf cert in PEM format. On failure, still release the
	// output so the file handle does not leak; the write error is the one
	// worth reporting.
	if err := writeCertPEM(out, leaf); err != nil {
		utils.IgnoreError(closeOut)
		return err
	}
	// Surface close errors: for a regular file they may report a delayed
	// write failure.
	return closeOut()
}

// openCertOutput resolves where the PEM certificate should be written.
// "-" selects stdout, which is returned with a no-op closer because this
// command does not own the process's standard output. Any other value is
// created with os.Create (truncating an existing file), and the returned
// closer closes that file.
func openCertOutput(filename string) (io.Writer, func() error, error) {
	if filename == "-" {
		return os.Stdout, func() error { return nil }, nil
	}
	f, err := os.Create(filename)
	if err != nil {
		return nil, nil, err
	}
	return f, f.Close, nil
}

// skipTLSValidation reports whether the user asked to bypass TLS certificate
// verification. An unset flag (nil pointer) is treated as false.
func skipTLSValidation() bool {
	skip := flags.SkipTLSValidation()
	return skip != nil && *skip
}

// writeCertPEM encodes the DER bytes of cert (cert.Raw) as a single
// "CERTIFICATE" PEM block and writes it to writer. Any error from the encoder
// or the underlying writer is returned unchanged.
func writeCertPEM(writer io.Writer, cert *x509.Certificate) error {
	return pem.Encode(writer, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
}

// writeCertInfo prints a human-readable summary of cert to writer: the issuer,
// the subject, and the validity window. Write errors are deliberately ignored,
// since this output is informational only.
func writeCertInfo(writer io.Writer, cert *x509.Certificate) {
	const summary = "Issuer:  %v\n" +
		"Subject: %v\n" +
		"Not valid before: %v\n" +
		"Not valid after:  %v\n"
	fmt.Fprintf(writer, summary, cert.Issuer, cert.Subject, cert.NotBefore, cert.NotAfter)
}
