// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pgsql

import (
	"database/sql"
	"encoding/json"
	"reflect"
	"time"

	"github.com/guregu/null/zero"
	log "github.com/sirupsen/logrus"

	"github.com/coreos/clair/database"
	"github.com/coreos/clair/ext/versionfmt"
	"github.com/coreos/clair/pkg/commonerr"
)

// compareStringLists returns, in their order of first appearance, the strings
// that occur in X but not in Y. A string repeated in X is reported only once.
// The result is never nil, even when it is empty.
func compareStringLists(X, Y []string) []string {
	// seen starts as the set of Y's members and grows with every string
	// already emitted, which is what de-duplicates the output.
	seen := make(map[string]struct{}, len(Y))
	for i := range Y {
		seen[Y[i]] = struct{}{}
	}

	onlyInX := make([]string, 0)
	for _, candidate := range X {
		if _, known := seen[candidate]; !known {
			onlyInX = append(onlyInX, candidate)
			seen[candidate] = struct{}{}
		}
	}

	return onlyInX
}

// compareStringListsInBoth returns, in the order they appear in X, the strings
// that are present in both X and Y. Each common string is reported once, even
// if X contains it several times. The result is never nil.
func compareStringListsInBoth(X, Y []string) []string {
	// pending holds the members of Y that have not been matched yet.
	pending := make(map[string]bool, len(Y))
	for _, y := range Y {
		pending[y] = true
	}

	common := make([]string, 0)
	for _, x := range X {
		if !pending[x] {
			continue
		}

		common = append(common, x)
		// Drop the match so a later duplicate in X is not reported again.
		delete(pending, x)
	}

	return common
}

// ListVulnerabilities returns at most limit vulnerabilities belonging to the
// namespace called namespaceName, together with a continuation ID.
//
// startID is forwarded to the query as its second argument (presumably the ID
// to start listing from - confirm against searchVulnerabilityByNamespace).
// The query is asked for limit+1 rows: rows beyond limit are not returned, but
// the ID of the last such row is handed back as the second result. That result
// is -1 when no row beyond limit was seen, and also on every error path.
//
// commonerr.ErrNotFound is returned when the namespace lookup yields ID 0.
func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) {
	defer observeQueryTime("listVulnerabilities", "all", time.Now())

	// Resolve the namespace first so an unknown one is reported as not found.
	var namespaceID int
	if err := pgSQL.QueryRow(searchNamespace, namespaceName).Scan(&namespaceID); err != nil {
		return nil, -1, handleError("searchNamespace", err)
	}
	if namespaceID == 0 {
		return nil, -1, commonerr.ErrNotFound
	}

	// Fetch one row more than requested to learn whether another page exists.
	rows, err := pgSQL.Query(searchVulnerabilityBase+searchVulnerabilityByNamespace, namespaceName, startID, limit+1)
	if err != nil {
		return nil, -1, handleError("searchVulnerabilityByNamespace", err)
	}
	defer rows.Close()

	var (
		page   []database.Vulnerability
		nextID = -1
	)
	// scanned is the 1-based position of the row currently being processed.
	for scanned := 1; rows.Next(); scanned++ {
		var v database.Vulnerability

		if err := rows.Scan(
			&v.ID,
			&v.Name,
			&v.Namespace.ID,
			&v.Namespace.Name,
			&v.Namespace.VersionFormat,
			&v.Description,
			&v.Link,
			&v.Severity,
			&v.Metadata,
		); err != nil {
			return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err)
		}

		if scanned > limit {
			// Row past the requested page: remember where to resume.
			nextID = v.ID
			continue
		}
		page = append(page, v)
	}

	if err := rows.Err(); err != nil {
		return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err)
	}

	return page, nextID, nil
}

// FindVulnerability looks up the vulnerability called name in the namespace
// called namespaceName, without requesting the "for update" variant of the
// query. It delegates to findVulnerability using pgSQL itself as the queryer.
func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) {
	const forUpdate = false
	return findVulnerability(pgSQL, namespaceName, name, forUpdate)
}

// findVulnerability runs the namespace-and-name vulnerability search on
// queryer and scans the result, including its FixedIn list, through
// scanVulnerability.
//
// When forUpdate is true, searchVulnerabilityForUpdate is appended to the
// statement, and the label used for error reporting is extended to match.
func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) {
	defer observeQueryTime("findVulnerability", "all", time.Now())

	label := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName"
	statement := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName
	if forUpdate {
		label += "+searchVulnerabilityForUpdate"
		statement += searchVulnerabilityForUpdate
	}

	row := queryer.QueryRow(statement, namespaceName, name)
	return scanVulnerability(queryer, label, row)
}

// findVulnerabilityByIDWithDeleted loads the vulnerability whose database ID
// is id, together with its FixedIn list, using the searchVulnerabilityByID
// query.
//
// NOTE(review): the name suggests that vulnerabilities marked as deleted are
// returned too; that depends on the SQL text, which is not visible here.
func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) {
	defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now())

	row := pgSQL.QueryRow(searchVulnerabilityBase+searchVulnerabilityByID, id)
	return scanVulnerability(pgSQL, "searchVulnerabilityBase+searchVulnerabilityByID", row)
}

// scanVulnerability reads a single vulnerability out of vulnerabilityRow and
// then loads its FixedIn list through queryer.
//
// queryName is used as the prefix of the label passed to handleError when
// scanning vulnerabilityRow fails.
//
// commonerr.ErrNotFound is returned when the scanned vulnerability has ID 0.
// On any error the returned database.Vulnerability holds whatever had been
// populated up to that point; callers should only rely on it when the error
// is nil.
func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) {
	var vulnerability database.Vulnerability

	err := vulnerabilityRow.Scan(
		&vulnerability.ID,
		&vulnerability.Name,
		&vulnerability.Namespace.ID,
		&vulnerability.Namespace.Name,
		&vulnerability.Namespace.VersionFormat,
		&vulnerability.Description,
		&vulnerability.Link,
		&vulnerability.Severity,
		&vulnerability.Metadata,
	)

	if err != nil {
		return vulnerability, handleError(queryName+".Scan()", err)
	}

	if vulnerability.ID == 0 {
		return vulnerability, commonerr.ErrNotFound
	}

	// Query the FixedIn FeatureVersion now.
	rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID)
	if err != nil {
		// This is a failure to run the query, not to scan a row, so it gets
		// its own label instead of sharing the ".Scan()" one used below.
		return vulnerability, handleError("searchVulnerabilityFixedIn", err)
	}
	defer rows.Close()

	for rows.Next() {
		// The zero.* types are used so that the "no value" case can be
		// detected with IsZero below instead of failing the scan.
		var featureVersionID zero.Int
		var featureVersionVersion zero.String
		var featureVersionFeatureName zero.String

		err := rows.Scan(
			&featureVersionVersion,
			&featureVersionID,
			&featureVersionFeatureName,
		)

		if err != nil {
			return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err)
		}

		// Rows without a feature ID carry no fix information and are skipped.
		if !featureVersionID.IsZero() {
			// Note that the ID we fill in featureVersion is actually a Feature ID, and not
			// a FeatureVersion ID.
			featureVersion := database.FeatureVersion{
				Model: database.Model{ID: int(featureVersionID.Int64)},
				Feature: database.Feature{
					Model:     database.Model{ID: int(featureVersionID.Int64)},
					Namespace: vulnerability.Namespace,
					Name:      featureVersionFeatureName.String,
				},
				Version: featureVersionVersion.String,
			}
			vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion)
		}
	}

	if err := rows.Err(); err != nil {
		return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err)
	}

	return vulnerability, nil
}

// InsertVulnerabilities stores the given vulnerabilities one after the other
// and stops at the first one that fails, returning its error. Vulnerabilities
// inserted before the failure are not undone by this function.
//
// A FixedIn entry may leave its Feature.Namespace empty: it then takes the
// namespace of the vulnerability it belongs to. For a vulnerability that
// already exists, setting the fixed version of a feature to
// versionfmt.MinVersion states that the vulnerability doesn't affect that
// feature anymore.
func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error {
	for i := range vulnerabilities {
		if err := pgSQL.insertVulnerability(vulnerabilities[i], false, generateNotifications); err != nil {
			return err
		}
	}
	return nil
}

// insertVulnerability inserts the given vulnerability, or a new revision of
// it when one with the same namespace and name already exists.
//
// Parameters:
//   - vulnerability: must have a Name and a Namespace.Name. FixedIn entries
//     with an empty Feature.Namespace.Name inherit the vulnerability's
//     namespace (this writes into the caller's FixedIn slice elements);
//     entries naming a different namespace are rejected.
//   - onlyFixedIn: only the FixedIn list is taken from the argument, every
//     other field is copied from the existing vulnerability, which must exist.
//   - generateNotification: when true, createNotification is called with the
//     existing and the new vulnerability IDs.
//
// For an existing vulnerability the FixedIn argument is treated as a diff
// against the stored list (see applyFixedInDiff); if neither the metadata nor
// the FixedIn list changes, nothing is written.
//
// Returns a bad-request error for invalid parameters, commonerr.ErrNotFound
// when onlyFixedIn is set and the vulnerability does not exist, or the error
// of the failing database step.
func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error {
	tf := time.Now()

	// Verify parameters
	if vulnerability.Name == "" || vulnerability.Namespace.Name == "" {
		return commonerr.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace")
	}

	for i := range vulnerability.FixedIn {
		fifv := &vulnerability.FixedIn[i]

		if fifv.Feature.Namespace.Name == "" {
			// As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's
			// Namespace.
			fifv.Feature.Namespace = vulnerability.Namespace
		} else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name {
			msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability"
			log.Warning(msg)
			return commonerr.NewBadRequestError(msg)
		}
	}

	// We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities.
	defer observeQueryTime("insertVulnerability", "all", tf)

	// Begin transaction.
	tx, err := pgSQL.Begin()
	if err != nil {
		// No transaction was started, so there is nothing to roll back (and
		// tx must not be used).
		return handleError("insertVulnerability.Begin()", err)
	}
	// Guarantee that every early return below releases the transaction.
	// Once Commit has been called, this Rollback has no effect.
	defer tx.Rollback()

	// Find existing vulnerability and its Vulnerability_FixedIn_Features (for update).
	existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, true)
	if err != nil && err != commonerr.ErrNotFound {
		return err
	}

	if onlyFixedIn {
		// Because this call tries to update FixedIn FeatureVersion, import all other data from the
		// existing one.
		if existingVulnerability.ID == 0 {
			return commonerr.ErrNotFound
		}

		fixedIn := vulnerability.FixedIn
		vulnerability = existingVulnerability
		vulnerability.FixedIn = fixedIn
	}

	if existingVulnerability.ID != 0 {
		// castMetadata normalizes the crafted metadata so that it can be
		// compared with the metadata that was read from the database.
		updateMetadata := vulnerability.Description != existingVulnerability.Description ||
			vulnerability.Link != existingVulnerability.Link ||
			vulnerability.Severity != existingVulnerability.Severity ||
			!reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata)

		// Construct the entire list of FixedIn FeatureVersion, by using the
		// the FixedIn list of the old vulnerability.
		//
		// TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the
		// existing vulnerability in order to make metadata updates much faster.
		var updateFixedIn bool
		vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn)

		if !updateMetadata && !updateFixedIn {
			// Nothing was written in this transaction, so a failing Commit
			// loses no data; its error is deliberately ignored and the
			// deferred Rollback takes care of the cleanup.
			_ = tx.Commit()
			return nil
		}

		// Mark the old vulnerability as non latest.
		_, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name)
		if err != nil {
			return handleError("removeVulnerability", err)
		}
	} else {
		// The vulnerability is new, we don't want to have any
		// versionfmt.MinVersion as they are only used for diffing existing
		// vulnerabilities.
		var fixedIn []database.FeatureVersion
		for _, fv := range vulnerability.FixedIn {
			if fv.Version != versionfmt.MinVersion {
				fixedIn = append(fixedIn, fv)
			}
		}
		vulnerability.FixedIn = fixedIn
	}

	// Find or insert Vulnerability's Namespace.
	// Note that this is called on pgSQL and is not given tx.
	namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace)
	if err != nil {
		return err
	}

	// Insert vulnerability.
	err = tx.QueryRow(
		insertVulnerability,
		namespaceID,
		vulnerability.Name,
		vulnerability.Description,
		vulnerability.Link,
		&vulnerability.Severity,
		&vulnerability.Metadata,
	).Scan(&vulnerability.ID)

	if err != nil {
		return handleError("insertVulnerability", err)
	}

	// Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now.
	err = pgSQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn)
	if err != nil {
		return err
	}

	// Create a notification.
	if generateNotification {
		err = createNotification(tx, existingVulnerability.ID, vulnerability.ID)
		if err != nil {
			return err
		}
	}

	// Commit transaction.
	if err = tx.Commit(); err != nil {
		return handleError("insertVulnerability.Commit()", err)
	}

	return nil
}

// castMetadata round-trips the given database.MetadataMap through JSON so that
// its values end up with the types encoding/json produces when decoding into
// interface{}. This makes a hand-built MetadataMap comparable with one that
// was loaded from the database.
//
// Encoding and decoding errors are ignored; the result is whatever the decoder
// left in the map.
func castMetadata(m database.MetadataMap) database.MetadataMap {
	normalized := make(database.MetadataMap)

	encoded, _ := json.Marshal(m)
	_ = json.Unmarshal(encoded, &normalized)

	return normalized
}

// applyFixedInDiff merges diff into currentList, matching entries by feature
// name, and returns the merged list together with a flag telling whether
// anything changed.
//
//   - A feature that is only in diff is added, unless its version is
//     versionfmt.MinVersion.
//   - A feature in both lists is removed when the diff version is
//     versionfmt.MinVersion, and replaced when the versions differ.
//
// The merged list is built from a map, so its order is unspecified; it is nil
// when no entry remains.
func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) {
	merged, existingNames := createFeatureVersionNameMap(currentList)
	incoming, incomingNames := createFeatureVersionNameMap(diff)

	changed := false

	// Features mentioned by the diff only.
	for _, name := range compareStringLists(incomingNames, existingNames) {
		candidate := incoming[name]
		// MinVersion only makes sense when a Feature is already fixed in some
		// version, in which case it would be handled by the loop below.
		if candidate.Version != versionfmt.MinVersion {
			merged[name] = candidate
			changed = true
		}
	}

	// Features mentioned by both the diff and the current list.
	for _, name := range compareStringListsInBoth(incomingNames, existingNames) {
		candidate := incoming[name]

		switch {
		case candidate.Version == versionfmt.MinVersion:
			// MinVersion means that the Feature doesn't affect the Vulnerability anymore.
			delete(merged, name)
			changed = true
		case candidate.Version != merged[name].Version:
			// The version got updated.
			merged[name] = candidate
			changed = true
		}
	}

	// Flatten the merged map back into a slice.
	var result []database.FeatureVersion
	for _, fv := range merged {
		result = append(result, fv)
	}

	return result, changed
}

// createFeatureVersionNameMap indexes the given FeatureVersions by their
// Feature.Name. It returns the index and the list of names in input order.
// When a name occurs more than once, it appears that many times in the list
// and the index keeps the last FeatureVersion carrying it.
func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) {
	byName := map[string]database.FeatureVersion{}
	names := make([]string, 0, len(features))

	for _, fv := range features {
		name := fv.Feature.Name
		byName[name] = fv
		names = append(names, name)
	}

	return byName, names
}

// insertVulnerabilityFixedInFeatureVersions records, for the vulnerability
// identified by vulnerabilityID, every entry of fixedIn in
// Vulnerability_FixedIn_Feature, and calls linkVulnerabilityToFeatureVersions
// for each newly created entry so that Vulnerability_Affects_FeatureVersion
// follows.
//
// Features whose ID is 0 are first found or inserted through
// pgSQL.insertFeature, and the resulting ID is written back into fixedIn.
// The remaining statements run on tx. If taking the lock fails, tx is rolled
// back here; on every other error tx is left untouched for the caller.
func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error {
	defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now())

	// Insert or find the Features.
	// TODO(Quentin-M): Batch me.
	var err error
	for i := range fixedIn {
		feature := &fixedIn[i].Feature
		if feature.ID != 0 {
			continue
		}
		if feature.ID, err = pgSQL.insertFeature(*feature); err != nil {
			return err
		}
	}

	// Lock Vulnerability_Affects_FeatureVersion exclusively.
	// We want to prevent InsertFeatureVersion to modify it.
	promConcurrentLockVAFV.Inc()
	defer promConcurrentLockVAFV.Dec()

	lockStart := time.Now()
	_, err = tx.Exec(lockVulnerabilityAffects)
	observeQueryTime("insertVulnerability", "lock", lockStart)
	if err != nil {
		tx.Rollback()
		return handleError("insertVulnerability.lockVulnerabilityAffects", err)
	}

	for i := range fixedIn {
		fix := &fixedIn[i]

		var (
			fixedInID int
			created   bool
		)

		// Find or create entry in Vulnerability_FixedIn_Feature.
		row := tx.QueryRow(
			soiVulnerabilityFixedInFeature,
			vulnerabilityID, fix.Feature.ID,
			&fix.Version,
		)
		if err = row.Scan(&created, &fixedInID); err != nil {
			return handleError("insertVulnerabilityFixedInFeature", err)
		}

		// When the relationship between the feature and the vulnerability
		// already existed, Vulnerability_Affects_FeatureVersion needs no update.
		if created {
			// Insert Vulnerability_Affects_FeatureVersion.
			err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fix.Feature.ID, fix.Feature.Namespace.VersionFormat, fix.Version)
			if err != nil {
				return err
			}
		}
	}

	return nil
}

// linkVulnerabilityToFeatureVersions marks as affected by the vulnerability
// every FeatureVersion of the feature featureID whose version is strictly
// lower than fixedInVersion.
//
// Versions are compared with versionfmt.Compare using versionFormat. One row
// per affected FeatureVersion is inserted on tx with
// insertVulnerabilityAffectsFeatureVersion, referencing vulnerabilityID and
// fixedInID. The first query, scan, comparison or insert error is returned.
func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, versionFormat, fixedInVersion string) error {
	// Find every FeatureVersions of the Feature that the vulnerability affects.
	// TODO(Quentin-M): LIMIT
	rows, err := tx.Query(searchFeatureVersionByFeature, featureID)
	if err != nil {
		return handleError("searchFeatureVersionByFeature", err)
	}
	defer rows.Close()

	var vulnerable []database.FeatureVersion
	for rows.Next() {
		var candidate database.FeatureVersion

		if scanErr := rows.Scan(&candidate.ID, &candidate.Version); scanErr != nil {
			return handleError("searchFeatureVersionByFeature.Scan()", scanErr)
		}

		order, cmpErr := versionfmt.Compare(versionFormat, candidate.Version, fixedInVersion)
		if cmpErr != nil {
			return cmpErr
		}
		if order >= 0 {
			// Already at or above the fixed version: not affected.
			continue
		}

		// The version of the FeatureVersion is lower than the fixed version of
		// this vulnerability, thus, this FeatureVersion is affected by it.
		vulnerable = append(vulnerable, candidate)
	}
	if err = rows.Err(); err != nil {
		return handleError("searchFeatureVersionByFeature.Rows()", err)
	}
	// Close the result set explicitly before issuing further statements on tx
	// (presumably so that the connection is free for them); the deferred
	// Close above still covers the early returns.
	rows.Close()

	// Insert into Vulnerability_Affects_FeatureVersion.
	for i := range vulnerable {
		// TODO(Quentin-M): Batch me.
		if _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, vulnerable[i].ID, fixedInID); err != nil {
			return handleError("insertVulnerabilityAffectsFeatureVersion", err)
		}
	}

	return nil
}

// InsertVulnerabilityFixes applies fixes as a FixedIn diff to the existing
// vulnerability identified by vulnerabilityNamespace and vulnerabilityName.
// It calls insertVulnerability in "only FixedIn" mode with notifications
// enabled, so all other fields are taken from the stored vulnerability.
func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error {
	defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now())

	var update database.Vulnerability
	update.Name = vulnerabilityName
	update.Namespace.Name = vulnerabilityNamespace
	update.FixedIn = fixes

	const onlyFixedIn, generateNotification = true, true
	return pgSQL.insertVulnerability(update, onlyFixedIn, generateNotification)
}

// DeleteVulnerabilityFix removes the fix for the feature called featureName
// from the existing vulnerability identified by vulnerabilityNamespace and
// vulnerabilityName.
//
// It does so by submitting, in "only FixedIn" mode and with notifications
// enabled, a one-entry diff whose version is versionfmt.MinVersion, which
// applyFixedInDiff treats as a removal.
func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error {
	defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now())

	// The removal marker lives in the same namespace as the vulnerability.
	var removal database.FeatureVersion
	removal.Feature.Name = featureName
	removal.Feature.Namespace.Name = vulnerabilityNamespace
	removal.Version = versionfmt.MinVersion

	var update database.Vulnerability
	update.Name = vulnerabilityName
	update.Namespace.Name = vulnerabilityNamespace
	update.FixedIn = []database.FeatureVersion{removal}

	const onlyFixedIn, generateNotification = true, true
	return pgSQL.insertVulnerability(update, onlyFixedIn, generateNotification)
}

// DeleteVulnerability removes the vulnerability called name from the
// namespace called namespaceName and creates a notification for the removal,
// both inside a single transaction.
//
// It runs removeVulnerability, reads back the ID that statement returns, and
// passes that ID to createNotification with 0 as the new vulnerability ID.
// The error of the first failing step is returned, and the transaction is
// rolled back in that case.
func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error {
	defer observeQueryTime("DeleteVulnerability", "all", time.Now())

	// Begin transaction.
	tx, err := pgSQL.Begin()
	if err != nil {
		// No transaction was started, so there is nothing to roll back (and
		// tx must not be used).
		return handleError("DeleteVulnerability.Begin()", err)
	}
	// Guarantee that every early return below releases the transaction.
	// Once Commit has been called, this Rollback has no effect.
	defer tx.Rollback()

	var vulnerabilityID int
	err = tx.QueryRow(removeVulnerability, namespaceName, name).Scan(&vulnerabilityID)
	if err != nil {
		return handleError("removeVulnerability", err)
	}

	// Create a notification.
	err = createNotification(tx, vulnerabilityID, 0)
	if err != nil {
		return err
	}

	// Commit transaction.
	if err = tx.Commit(); err != nil {
		return handleError("DeleteVulnerability.Commit()", err)
	}

	return nil
}
