// Copyright 2019 The OPA Authors.  All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package e2e

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/sirupsen/logrus"

	"github.com/open-policy-agent/opa/runtime"
	"github.com/open-policy-agent/opa/util"
)

// defaultAddr is the listening address given to the test server unless a
// test overrides it. Port 0 lets the OS pick a random open port.
const defaultAddr = ":0"

// NewAPIServerTestParams returns a runtime.Params value pre-populated with
// just enough defaults to start the server:
//   - listening on defaultAddr (a random open port),
//   - "debug" level logging in "json-pretty" format,
//   - a graceful shutdown period of 10 (seconds).
//
// Callers can, and usually should, customize the result for their test case.
func NewAPIServerTestParams() runtime.Params {
	p := runtime.NewParams()

	addrs := []string{defaultAddr}
	p.Addrs = &addrs
	p.Logging = runtime.LoggingConfig{Format: "json-pretty", Level: "debug"}
	p.GracefulShutdownPeriod = 10 // seconds

	return p
}

// TestRuntime holds metadata and provides helper methods
// to interact with the runtime being tested.
//
// Fields:
//   - Params: the runtime.Params the runtime was built from.
//   - Runtime: the OPA runtime under test.
//   - Ctx / Cancel: the context the test runner passes to Runtime.Serve, and
//     the function it calls to stop the server once the tests finish.
//   - Client: the HTTP client used by request helpers such as UploadPolicy
//     and HealthCheck.
//   - url: the base URL lazily computed and cached by URL(); guarded by urlMtx.
//
// Construct values with NewTestRuntime or WrapRuntime. Both initialize
// urlMtx, which URL() locks, so a zero-value TestRuntime is not usable.
type TestRuntime struct {
	Params  runtime.Params
	Runtime *runtime.Runtime
	Ctx     context.Context
	Cancel  context.CancelFunc
	Client  *http.Client
	url     string
	diagURL string
	urlMtx  *sync.Mutex
}

// NewTestRuntime creates a new OPA runtime from params and returns a
// TestRuntime wrapping it. The runtime is bound to a cancellable context
// that is stored on the result as Ctx/Cancel; RunTests serves with Ctx and
// calls Cancel to stop the server when the tests are done.
//
// If the runtime cannot be created, the context is cancelled immediately
// and an error is returned.
func NewTestRuntime(params runtime.Params) (*TestRuntime, error) {
	ctx, cancel := context.WithCancel(context.Background())

	rt, err := runtime.NewRuntime(ctx, params)
	if err != nil {
		// Nobody else holds cancel; release the context's resources here.
		cancel()
		return nil, fmt.Errorf("unable to create new runtime: %s", err)
	}

	return &TestRuntime{
		Params:  params,
		Runtime: rt,
		Ctx:     ctx,
		Cancel:  cancel,
		Client:  &http.Client{},
		urlMtx:  new(sync.Mutex),
	}, nil
}

// WrapRuntime builds a TestRuntime around an already-constructed runtime.
// Params is copied from rt. ctx and cancel are stored as Ctx and Cancel;
// the test runner serves with Ctx and calls Cancel to stop the server.
// A fresh default http.Client is attached for the request helpers.
func WrapRuntime(ctx context.Context, cancel context.CancelFunc, rt *runtime.Runtime) *TestRuntime {
	wrapped := new(TestRuntime)
	wrapped.Params = rt.Params
	wrapped.Runtime = rt
	wrapped.Ctx = ctx
	wrapped.Cancel = cancel
	wrapped.Client = &http.Client{}
	wrapped.urlMtx = new(sync.Mutex)
	return wrapped
}

// RunAPIServerTests will start the OPA runtime serving with a given
// configuration. This is essentially a wrapper for `m.Run()` that
// handles starting and stopping the local API server. Unlike RunTests,
// logrus output is always discarded, regardless of the `test.v` flag.
// The return value is what should be used as the code in `os.Exit` in
// the `TestMain` function.
//
// Deprecated: Use RunTests instead.
func (t *TestRuntime) RunAPIServerTests(m *testing.M) int {
	return t.runTests(m, true)
}

// RunAPIServerBenchmarks will start the OPA runtime and do
// `m.Run()` similar to how RunAPIServerTests works. This
// will suppress logging output to prevent the tests
// from being overly verbose. If log output is desired set
// the `test.v` flag.
//
// Deprecated: Use RunTests instead.
func (t *TestRuntime) RunAPIServerBenchmarks(m *testing.M) int {
	return t.runTests(m, !testing.Verbose())
}

// RunTests starts the OPA runtime with its configured parameters, runs the
// package's tests/benchmarks through `m.Run()`, and then stops the server.
// Log output is discarded unless the `test.v` flag is set. The returned
// value is the exit code to hand to `os.Exit` from `TestMain`.
func (t *TestRuntime) RunTests(m *testing.M) int {
	quiet := !testing.Verbose()
	return t.runTests(m, quiet)
}

// URL returns the base URL the server is listening on, or an empty string
// if the server hasn't started listening yet. The first non-empty result is
// cached, so the URL is not expected to change over the lifetime of the
// TestRuntime. Runtimes configured with more than one address only report
// the first one; tests needing the others should convert Runtime.Addrs()
// themselves (e.g. with AddrToURL).
//
// If the address cannot be turned into a URL, the error is printed and the
// process exits with status 1.
//
// The cached value is only ever read and written while holding urlMtx, so
// concurrent calls do not race on it.
func (t *TestRuntime) URL() string {
	t.urlMtx.Lock()
	defer t.urlMtx.Unlock()

	if t.url != "" {
		return t.url
	}

	addrs := t.Runtime.Addrs()
	if len(addrs) == 0 {
		// Not listening yet; leave the cache empty so the next call retries.
		return ""
	}

	// Just pick the first one; if a test was configured with >1 address it
	// needs to determine the other URLs itself.
	parsed, err := t.AddrToURL(addrs[0])
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}

	t.url = parsed
	return t.url
}

// AddrToURL converts a listening address, as configured on the runtime,
// into a full URL string. The address may be a fully qualified URL
// ("http://foo"), a bare host or IP ("localhost"), host:port, or just a
// port (":8181"), which is taken to mean localhost. When the address has no
// scheme, "https://" is used if the runtime has a certificate configured
// and "http://" otherwise.
func (t *TestRuntime) AddrToURL(addr string) (string, error) {
	full := addr
	if strings.HasPrefix(full, ":") {
		full = "localhost" + full
	}

	if !strings.Contains(full, "://") {
		if t.Params.Certificate == nil {
			full = "http://" + full
		} else {
			full = "https://" + full
		}
	}

	u, err := url.Parse(full)
	if err != nil {
		return "", fmt.Errorf("failed to parse listening address of server: %s", err)
	}
	return u.String(), nil
}

// runTests starts t.Runtime serving in the background with t.Ctx, waits
// for it to pass a health check, runs m.Run(), and then cancels the context
// and waits for Serve to return. It returns m.Run()'s exit code. It returns
// 1 instead if the server never became ready, or if Serve returned an error
// while the tests otherwise passed. When suppressLogs is true, logrus output
// is discarded.
func (t *TestRuntime) runTests(m *testing.M, suppressLogs bool) int {
	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if Serve returns before anyone is receiving.
	done := make(chan error, 1)
	go func() {
		// Suppress the logrus output produced by the server.
		if suppressLogs {
			logrus.SetOutput(ioutil.Discard)
		}
		done <- t.Runtime.Serve(t.Ctx)
	}()

	// NOTE(review): the original author reported that this goroutine sees a
	// different logger, so the output is set here as well. logrus's standard
	// logger is process-global, so this may be redundant; kept to preserve
	// behavior in case the runtime reconfigures logging. Confirm before removing.
	if suppressLogs {
		logrus.SetOutput(ioutil.Discard)
	}

	// Wait for the server to be ready.
	if err := t.WaitForServer(); err != nil {
		// logrus may be discarded, so report directly to stderr. Stop the
		// server instead of leaving it running behind a failed run.
		fmt.Fprintln(os.Stderr, err)
		t.Cancel()
		<-done
		return 1
	}

	// Actually run the unit tests/benchmarks.
	errc := m.Run()

	// Stop the API server and wait for Serve to return.
	t.Cancel()
	if err := <-done; err != nil && errc == 0 {
		// Even if the tests passed, return an error code if the server
		// encountered an error.
		errc = 1
	}

	return errc
}

// WaitForServer blocks until the server has a listening address and passes
// a health check. It polls every 100ms for up to 100 attempts (about 10
// seconds) and returns an error if the server is still not ready after that.
func (t *TestRuntime) WaitForServer() error {
	const (
		pollInterval = 100 * time.Millisecond
		maxAttempts  = 100 // 10 seconds before we give up
	)

	for attempt := 0; attempt < maxAttempts; attempt++ {
		// An empty URL means the server isn't listening yet; only health
		// check once there is an address to talk to.
		if u := t.URL(); u != "" && t.HealthCheck(u) == nil {
			logrus.Infof("Test server ready and listening on: %s", u)
			return nil
		}
		time.Sleep(pollInterval)
	}
	return fmt.Errorf("API Server not ready in time")
}

// UploadPolicy uploads the policy source read from policy under the given
// name via the v1 policy API (PUT /v1/policies/<name>), using t.Client.
// It returns an error if the request cannot be built or sent, or if the
// server does not answer 200 OK.
func (t *TestRuntime) UploadPolicy(name string, policy io.Reader) error {
	req, err := http.NewRequest(http.MethodPut, t.URL()+"/v1/policies/"+name, policy)
	if err != nil {
		return fmt.Errorf("Unexpected error creating request: %s", err)
	}
	resp, err := t.Client.Do(req)
	if err != nil {
		return fmt.Errorf("Failed to PUT the test policy: %s", err)
	}
	// The body must always be closed, otherwise the connection is leaked.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("Unexpected response: %d %s", resp.StatusCode, resp.Status)
	}
	return nil
}

// UploadData PUTs the contents of data to /v1/data via the v1 data API,
// using t.Client like the other request helpers. It returns an error if
// the request cannot be built or sent, or if the server does not answer
// 204 No Content.
func (t *TestRuntime) UploadData(data io.Reader) error {
	req, err := http.NewRequest(http.MethodPut, t.URL()+"/v1/data", data)
	if err != nil {
		return fmt.Errorf("Unexpected error creating request: %s", err)
	}

	resp, err := t.Client.Do(req)
	if err != nil {
		return fmt.Errorf("Failed to PUT data: %s", err)
	}
	// The body must always be closed, otherwise the connection is leaked.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNoContent {
		return fmt.Errorf("Unexpected response: %d %s", resp.StatusCode, resp.Status)
	}
	return nil
}

// GetDataWithInput POSTs {"input": input} to the v1 data API and returns
// the full response body.
//
// path may be given with or without a leading "/" and with or without the
// leading "data" segment; "foo/bar", "/foo/bar" and "data/foo/bar" all
// query /v1/data/foo/bar. Only a whole "data" segment counts as the prefix,
// so a path such as "dataset/x" is queried as /v1/data/dataset/x.
//
// input is encoded with util.MustMarshalJSON; presumably (per the Must
// naming convention) that panics if input cannot be marshalled. TODO confirm.
func (t *TestRuntime) GetDataWithInput(path string, input interface{}) ([]byte, error) {
	inputPayload := util.MustMarshalJSON(map[string]interface{}{
		"input": input,
	})

	path = strings.TrimPrefix(path, "/")
	if path != "data" && !strings.HasPrefix(path, "data/") {
		path = "data/" + path
	}

	resp, err := t.GetDataWithRawInput(t.URL()+"/v1/"+path, bytes.NewReader(inputPayload))
	if err != nil {
		return nil, err
	}
	// The reader is an HTTP response body; close it so the connection is
	// released once we have read everything.
	if rc, ok := resp.(io.Closer); ok {
		defer rc.Close()
	}

	body, err := ioutil.ReadAll(resp)
	if err != nil {
		return nil, fmt.Errorf("unexpected error reading response body: %s", err)
	}

	return body, nil
}

// GetDataWithRawInput POSTs input as-is, with Content-Type application/json,
// to rawURL using t.Client. It is typically a v1 data API URL. On 200 OK it
// returns the response body. The returned reader also implements io.Closer,
// and callers should close it when done. For any other status, the body is
// closed and an error is returned.
func (t *TestRuntime) GetDataWithRawInput(rawURL string, input io.Reader) (io.Reader, error) {
	resp, err := t.Client.Post(rawURL, "application/json", input)
	if err != nil {
		return nil, fmt.Errorf("unexpected error: %s", err)
	}
	if resp.StatusCode != http.StatusOK {
		// Nobody will read this body; close it so the connection isn't leaked.
		resp.Body.Close()
		return nil, fmt.Errorf("unexpected response status: %d %s", resp.StatusCode, resp.Status)
	}
	return resp.Body, nil
}

// GetDataWithInputTyped calls GetDataWithInput and JSON-decodes the
// response body into response, which should be a pointer. It returns the
// request error, if any; otherwise it returns the result of json.Unmarshal.
func (t *TestRuntime) GetDataWithInputTyped(path string, input interface{}, response interface{}) error {
	raw, err := t.GetDataWithInput(path, input)
	if err == nil {
		err = json.Unmarshal(raw, response)
	}
	return err
}

// HealthCheck sends GET <baseURL>/health using t.Client and returns an error
// if the request fails or the server does not answer 200 OK. Any params are
// joined with "&" and appended verbatim (not escaped) as the query string,
// e.g. HealthCheck(u, "bundle=true") requests <u>/health?bundle=true.
func (t *TestRuntime) HealthCheck(baseURL string, params ...string) error {
	reqURL := baseURL + "/health"
	if len(params) > 0 {
		reqURL += "?" + strings.Join(params, "&")
	}
	req, err := http.NewRequest(http.MethodGet, reqURL, nil)
	if err != nil {
		return fmt.Errorf("unexpected error creating request: %s", err)
	}
	resp, err := t.Client.Do(req)
	if err != nil {
		return fmt.Errorf("unexpected error: %s", err)
	}
	// WaitForServer polls this repeatedly; close every body so polling
	// doesn't leak connections.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected response: %d %s", resp.StatusCode, resp.Status)
	}
	return nil
}
