/*-
 * Copyright 2015 Square Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package main

import (
	"encoding/json"
	"errors"
	"net"
	"net/url"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/square/ghostunnel/proxy"
	"github.com/stretchr/testify/assert"
)

// TestIntegrationMain is the entry point used by the integration test suite.
// It is wrapped in a regular test case so that test coverage is recorded for
// the ghostunnel instance under test. It does nothing unless the environment
// variable GHOSTUNNEL_INTEGRATION_TEST is "true". In that case it decodes a
// JSON array of arguments from GHOSTUNNEL_INTEGRATION_ARGS and passes them to
// run(). It returns when run() returns, when exitFunc is invoked, or after a
// 10 minute timeout (which panics).
func TestIntegrationMain(t *testing.T) {
	isIntegration := os.Getenv("GHOSTUNNEL_INTEGRATION_TEST")

	// Catch panics to make sure test exits normally and writes coverage
	// even if we got a crash (we might want to test error cases).
	// Note: recover only observes panics raised on this goroutine, not panics
	// raised inside the goroutine that calls run() below.
	defer func() {
		if err := recover(); err != nil {
			t.Error(err)
		}
	}()

	if isIntegration != "true" {
		return
	}

	finished := make(chan bool, 1)
	once := &sync.Once{}

	// override exit function for test, to make sure calls to exitFunc() don't
	// actually terminate the process and kill the test w/o capturing results.
	exitFunc = func(exit int) {
		// Only the first exit code is reported.
		once.Do(func() {
			if exit != 0 {
				t.Errorf("exit code from ghostunnel: %d", exit)
			}
		})
		finished <- true
		select {} // block, mimicking a process exit that never returns
	}

	var wrappedArgs []string
	err := json.Unmarshal([]byte(os.Getenv("GHOSTUNNEL_INTEGRATION_ARGS")), &wrappedArgs)
	panicOnError(err)

	go func() {
		if runErr := run(wrappedArgs); runErr != nil {
			t.Errorf("got error from run: %s", runErr)
		}
		finished <- true
	}()

	// time.After fires exactly once; time.Tick would create a ticker that is
	// never stopped.
	select {
	case <-finished:
		return
	case <-time.After(10 * time.Minute):
		panic("timed out")
	}
}

// TestInitLoggerQuiet calls initLogger without syslog, passing []string{"all"}
// as the second argument. It checks that the call succeeds and that the
// package-level logger is replaced by a non-nil value.
func TestInitLoggerQuiet(t *testing.T) {
	before := logger
	assert.Nil(t, initLogger(false, []string{"all"}))

	after := logger
	assert.NotEqual(t, before, after, "should have updated logger object")
	assert.NotNil(t, logger, "logger should never be nil after init")
}

// TestInitLoggerSyslog calls initLogger with syslog enabled and checks that
// the package-level logger is replaced by a non-nil value. If syslog setup
// returns an error, the test is skipped rather than failed.
func TestInitLoggerSyslog(t *testing.T) {
	originalLogger := logger
	err := initLogger(true, []string{})
	updatedLogger := logger
	if err != nil {
		// Tests running in containers often don't have access to syslog,
		// so we can't depend on syslog being available for testing. If we
		// get an error from the syslog setup we just log it and skip the test.
		// Skipf = Logf + SkipNow, and SkipNow never returns.
		t.Skipf("Error setting up syslog for test, skipping: %s", err)
	}
	assert.NotEqual(t, originalLogger, updatedLogger, "should have updated logger object")
	assert.NotNil(t, logger, "logger should never be nil after init")
}

// TestPanicOnError verifies that panicOnError panics when given a non-nil error.
func TestPanicOnError(t *testing.T) {
	// The deferred check runs after panicOnError unwinds (or returns normally).
	defer func() {
		if r := recover(); r == nil {
			t.Error("panicOnError should panic, but did not")
		}
	}()

	failure := errors.New("error")
	panicOnError(failure)
}

// TestFlagValidation checks that validateFlags rejects several invalid global
// flag settings. It mutates package-level flag values. When it finishes,
// enableProf is false, statusAddress and metricsURL are empty, and
// timeoutDuration is 10s.
func TestFlagValidation(t *testing.T) {
	// --enable-pprof without --status.
	*enableProf, *statusAddress = true, ""
	assert.NotNil(t, validateFlags(nil), "--enable-pprof implies --status")

	// A bare IP is not accepted as --metrics-url.
	*enableProf = false
	*metricsURL = "127.0.0.1"
	assert.NotNil(t, validateFlags(nil), "invalid --metrics-url should be rejected")
	*metricsURL = ""

	// A zero connect timeout is rejected.
	*timeoutDuration = 0
	assert.NotNil(t, validateFlags(nil), "invalid --connect-timeout should be rejected")
	*timeoutDuration = 10 * time.Second
}

// TestServerFlagValidation runs serverValidateFlags against a sequence of
// invalid server-mode flag combinations and asserts that each one is rejected.
// The cases share and mutate package-level flag state, and each step builds on
// the state left by the previous one, so their order matters.
func TestServerFlagValidation(t *testing.T) {
	// No access control flag at all.
	*serverAllowAll = false
	*serverAllowedCNs = nil
	*serverAllowedOUs = nil
	*serverAllowedDNSs = nil
	*serverAllowedIPs = nil
	*serverAllowedURIs = nil
	err := serverValidateFlags()
	assert.NotNil(t, err, "invalid access control flags accepted")

	// --allow-all combined with each more specific allow flag.
	*serverAllowAll = true
	*serverAllowedCNs = []string{"test"}
	err = serverValidateFlags()
	assert.NotNil(t, err, "--allow-all and --allow-cn are mutually exclusive")

	*serverAllowedCNs = nil
	*serverAllowedOUs = []string{"test"}
	err = serverValidateFlags()
	assert.NotNil(t, err, "--allow-all and --allow-ou are mutually exclusive")

	*serverAllowedOUs = nil
	*serverAllowedDNSs = []string{"test"}
	err = serverValidateFlags()
	assert.NotNil(t, err, "--allow-all and --allow-dns-san are mutually exclusive")

	*serverAllowedDNSs = nil
	*serverAllowedIPs = []net.IP{net.IPv4(0, 0, 0, 0)}
	err = serverValidateFlags()
	assert.NotNil(t, err, "--allow-all and --allow-ip-san are mutually exclusive")

	// --disable-authentication combined with --allow-all.
	*serverAllowedIPs = nil
	*serverAllowAll = true
	*serverDisableAuth = true
	err = serverValidateFlags()
	assert.NotNil(t, err, "--disable-authentication mutually exclusive with --allow-all and other server access control flags")

	// --disable-authentication combined with a specific allow flag instead of
	// --allow-all, so both halves of the mutual-exclusion rule are exercised.
	*serverAllowAll = false
	*serverAllowedCNs = []string{"test"}
	*serverDisableAuth = true
	err = serverValidateFlags()
	assert.NotNil(t, err, "--disable-authentication mutually exclusive with --allow-all and other server access control flags")
	*serverAllowedCNs = nil
	*serverAllowAll = true

	// --allow-all plus --allow-cn with a keystore configured.
	*keystorePath = "file"
	*serverAllowedCNs = []string{"test"}
	*serverDisableAuth = false
	err = serverValidateFlags()
	assert.NotNil(t, err, "--allow-all mutually exclusive with other access control flags")
	*serverAllowedCNs = nil

	// Non-local target without the unsafe-target flag.
	*serverAllowAll = false
	*serverUnsafeTarget = false
	*serverForwardAddress = "foo.com"
	err = serverValidateFlags()
	assert.NotNil(t, err, "unsafe target should be rejected")

	// --cert given without --key.
	*certPath = "file"
	err = serverValidateFlags()
	assert.NotNil(t, err, "--cert also requires --key or should error")
	*certPath = ""

	// Conflicting identity sources.
	test := "test"
	*keystorePath = "file"
	keychainIdentity = &test
	err = serverValidateFlags()
	assert.NotNil(t, err, "--keystore and --keychain-identity can't be set at the same time")

	*keystorePath = ""
	*certPath = "file"
	*keyPath = "file"
	err = serverValidateFlags()
	assert.NotNil(t, err, "--cert and --keychain-identity can't be set at the same time")
	*certPath = ""
	*keyPath = ""
	keychainIdentity = nil

	// Access control flags while authentication is disabled.
	*keystorePath = "test"
	*serverDisableAuth = true
	*serverAllowAll = true
	err = serverValidateFlags()
	assert.NotNil(t, err, "can't use access control flags if auth is disabled")
	*serverDisableAuth = false

	*serverForwardAddress = "example.com:443"
	err = serverValidateFlags()
	assert.NotNil(t, err, "should reject non-local address if unsafe flag not set")

	*enabledCipherSuites = "ABC"
	*serverForwardAddress = "127.0.0.1:8080"
	err = serverValidateFlags()
	assert.NotNil(t, err, "invalid cipher suite option should be rejected")

	// Leave shared flag state in a neutral configuration for later tests.
	*enabledCipherSuites = "AES,CHACHA"
	*serverForwardAddress = ""
	*serverAllowAll = false
	*keystorePath = ""
}

// TestClientFlagValidation runs clientValidateFlags against a sequence of
// invalid client-mode flag combinations and asserts that each one is rejected.
// The cases share and mutate package-level flag state, so their order matters.
func TestClientFlagValidation(t *testing.T) {
	// Non-local listen address without the unsafe-listen flag.
	*keystorePath = "file"
	*clientUnsafeListen = false
	*clientListenAddress = "0.0.0.0:8080"
	err := clientValidateFlags()
	assert.NotNil(t, err, "unsafe listen should be rejected")

	*clientDisableAuth = true
	err = clientValidateFlags()
	assert.NotNil(t, err, "--keystore can't be used with --disable-authentication")
	*clientDisableAuth = false

	test := "test"
	keychainIdentity = &test
	err = clientValidateFlags()
	assert.NotNil(t, err, "--keystore can't be used with --keychain-identity")
	keychainIdentity = nil

	*enabledCipherSuites = "ABC"
	*clientListenAddress = "127.0.0.1:8080"
	err = clientValidateFlags()
	assert.NotNil(t, err, "invalid cipher suite option should be rejected")

	// A well-formed URL whose scheme should not be accepted as a connect proxy.
	invalidURL, err := url.Parse("ftp://invalid")
	if err != nil {
		t.Fatalf("parsing test proxy URL: %s", err)
	}
	*enabledCipherSuites = "AES"
	*clientConnectProxy = invalidURL
	err = clientValidateFlags()
	assert.NotNil(t, err, "invalid connect proxy option should be rejected")

	// Clear the bad proxy so the next case can only fail for the reason its
	// message states (missing identity), not because of the leftover proxy.
	*clientConnectProxy = nil

	*clientDisableAuth = false
	*keystorePath = ""
	err = clientValidateFlags()
	assert.NotNil(t, err, "one of --keystore or --disable-authentication is required")
}

// TestAllowsLocalhost checks that, with serverUnsafeTarget cleared,
// validateUnixOrLocalhost accepts loopback addresses (hostname, IPv4 and IPv6
// forms) and unix socket paths.
func TestAllowsLocalhost(t *testing.T) {
	*serverUnsafeTarget = false

	allowed := []struct{ addr, msg string }{
		{"localhost:1234", "localhost should be allowed"},
		{"127.0.0.1:1234", "127.0.0.1 should be allowed"},
		{"[::1]:1234", "[::1] should be allowed"},
		{"unix:/tmp/foo", "unix:/tmp/foo should be allowed"},
	}
	for _, c := range allowed {
		assert.True(t, validateUnixOrLocalhost(c.addr), c.msg)
	}
}

// TestDisallowsFooDotCom checks that, with serverUnsafeTarget cleared,
// validateUnixOrLocalhost rejects non-loopback hosts. This includes hostnames
// that merely contain "localhost" as a substring, and a public IPv4 address.
func TestDisallowsFooDotCom(t *testing.T) {
	*serverUnsafeTarget = false

	rejected := []struct{ addr, msg string }{
		{"foo.com:1234", "foo.com should be disallowed"},
		{"alocalhost.com:1234", "alocalhost.com should be disallowed"},
		{"localhost.com.foo.com:1234", "localhost.com.foo.com should be disallowed"},
		{"74.122.190.83:1234", "random ip address should be disallowed"},
	}
	for _, c := range rejected {
		assert.False(t, validateUnixOrLocalhost(c.addr), c.msg)
	}
}

// TestServerBackendDialerError checks that serverBackendDialer returns an
// error when the forward address is malformed ("invalid", with no port).
func TestServerBackendDialerError(t *testing.T) {
	*serverForwardAddress = "invalid"
	_, dialErr := serverBackendDialer()
	assert.NotNil(t, dialErr, "invalid forward address should not have dialer")
}

// TestInvalidCABundle runs the "server" command with /dev/null (an empty file)
// as the CA bundle and expects run to return an error.
func TestInvalidCABundle(t *testing.T) {
	args := []string{
		"server",
		"--cacert", "/dev/null",
		"--target", "localhost:8080",
		"--keystore", "keystore.p12",
		"--listen", "localhost:8080",
	}
	runErr := run(args)
	assert.NotNil(t, runErr, "invalid CA bundle should exit with error")
}

// TestParseUnixOrTcpAddress checks that parseUnixOrTCPAddress splits a
// "unix:" path and a host:port TCP address into their network, address and
// host parts without error. It also checks that malformed host/port input is
// rejected.
func TestParseUnixOrTcpAddress(t *testing.T) {
	network, address, host, err := parseUnixOrTCPAddress("unix:/tmp/foo")
	if err != nil {
		t.Errorf("unexpected error: %s", err)
	}
	if network != "unix" {
		t.Errorf("unexpected network: %s", network)
	}
	if address != "/tmp/foo" {
		t.Errorf("unexpected address: %s", address)
	}
	if host != "" {
		t.Errorf("unexpected host: %s", host)
	}

	network, address, host, err = parseUnixOrTCPAddress("localhost:8080")
	if err != nil {
		t.Errorf("unexpected error: %s", err)
	}
	if network != "tcp" {
		t.Errorf("unexpected network: %s", network)
	}
	if address != "localhost:8080" {
		t.Errorf("unexpected address: %s", address)
	}
	if host != "localhost" {
		t.Errorf("unexpected host: %s", host)
	}

	// Missing port.
	_, _, _, err = parseUnixOrTCPAddress("localhost")
	assert.NotNil(t, err, "was able to parse invalid host/port")

	// Out-of-range IP octets and port.
	_, _, _, err = parseUnixOrTCPAddress("256.256.256.256:99999")
	assert.NotNil(t, err, "was able to parse invalid host/port")
}

// TestProxyLoggingFlags checks how proxyLoggerFlags maps names to flags. Each
// name ("conns", "conn-errs", "handshake-errs") clears the matching bit from
// proxy.LogEverything, and an empty name clears nothing.
// assert.Equal takes (expected, actual); keeping that order makes failure
// output label the values correctly.
func TestProxyLoggingFlags(t *testing.T) {
	assert.Equal(t, proxy.LogEverything, proxyLoggerFlags([]string{""}))
	assert.Equal(t, proxy.LogEverything&^proxy.LogConnections, proxyLoggerFlags([]string{"conns"}))
	assert.Equal(t, proxy.LogEverything&^proxy.LogConnectionErrors, proxyLoggerFlags([]string{"conn-errs"}))
	assert.Equal(t, proxy.LogEverything&^proxy.LogHandshakeErrors, proxyLoggerFlags([]string{"handshake-errs"}))
	assert.Equal(t, proxy.LogConnectionErrors, proxyLoggerFlags([]string{"conns", "handshake-errs"}))
	assert.Equal(t, proxy.LogConnections, proxyLoggerFlags([]string{"conn-errs", "handshake-errs"}))
	assert.Equal(t, proxy.LogHandshakeErrors, proxyLoggerFlags([]string{"conns", "conn-errs"}))
}
