package requestmgr

import (
	"context"
	"fmt"
	"time"

	"github.com/pkg/errors"
	componentCVEEdgeDataStore "github.com/stackrox/rox/central/componentcveedge/datastore"
	cveDataStore "github.com/stackrox/rox/central/cve/datastore"
	deploymentDataStore "github.com/stackrox/rox/central/deployment/datastore"
	imgDataStore "github.com/stackrox/rox/central/image/datastore"
	imgCVEEdgeDataStore "github.com/stackrox/rox/central/imagecveedge/datastore"
	"github.com/stackrox/rox/central/reprocessor"
	"github.com/stackrox/rox/central/role/resources"
	"github.com/stackrox/rox/central/sensor/service/connection"
	"github.com/stackrox/rox/central/vulnerabilityrequest/cache"
	"github.com/stackrox/rox/central/vulnerabilityrequest/common"
	vulnReqDataStore "github.com/stackrox/rox/central/vulnerabilityrequest/datastore"
	"github.com/stackrox/rox/central/vulnerabilityrequest/utils"
	"github.com/stackrox/rox/central/vulnerabilityrequest/validator"
	v1 "github.com/stackrox/rox/generated/api/v1"
	"github.com/stackrox/rox/generated/internalapi/central"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/batcher"
	"github.com/stackrox/rox/pkg/concurrency"
	"github.com/stackrox/rox/pkg/errorhelpers"
	"github.com/stackrox/rox/pkg/features"
	"github.com/stackrox/rox/pkg/sac"
	"github.com/stackrox/rox/pkg/search"
	deploymentOptionsMap "github.com/stackrox/rox/pkg/search/options/deployments"
	"github.com/stackrox/rox/pkg/search/scoped"
	"github.com/stackrox/rox/pkg/uuid"
)

var (
	// batchSize is the batch size handed to the batcher when vulnerability requests are
	// loaded from the datastore into a cache (see buildCache).
	batchSize = 1000

	// allAccessCtx is a background context with all access; the manager uses it for its
	// internal image, deployment and vulnerability request datastore calls.
	// allVulnApproverAccessSac gives the processor access as an approver (read and read-write
	// on the VulnerabilityManagementApprovals resource) so that it can properly expire requests.
	allAccessCtx             = sac.WithAllAccess(context.Background())
	allVulnApproverAccessSac = sac.WithGlobalAccessScopeChecker(context.Background(),
		sac.AllowFixedScopes(
			sac.AccessModeScopeKeys(storage.Access_READ_ACCESS, storage.Access_READ_WRITE_ACCESS),
			sac.ResourceScopeKeys(resources.VulnerabilityManagementApprovals)))

	// clusterIDField is the deployment search option for the cluster ID. Its FieldPath is used
	// to look up the highlighted cluster ID in deployment search results.
	clusterIDField = deploymentOptionsMap.OptionsMap.MustGet(search.ClusterID.String())
)

// managerImpl manages the lifecycle of vulnerability requests: creation, approval, denial,
// undo and expiry. It keeps two in-memory caches of requests in sync with the datastore and
// runs a background processor that retires deferrals once they expire or become fixable.
type managerImpl struct {
	// Datastores and collaborators used to resolve the images and deployments affected by a
	// request and to trigger their reprocessing.
	deployments       deploymentDataStore.DataStore
	images            imgDataStore.DataStore
	imageCVEEdges     imgCVEEdgeDataStore.DataStore
	cves              cveDataStore.DataStore
	componentCVEEdges componentCVEEdgeDataStore.DataStore
	vulnReqs          vulnReqDataStore.DataStore
	connManager       connection.Manager
	reprocessor       reprocessor.Loop
	// activeReqCache receives requests once they are approved (see SnoozeVulnerabilityOnRequest);
	// pendingReqCache receives requests and request updates that are still pending.
	activeReqCache    cache.VulnReqCache
	pendingReqCache   cache.VulnReqCache

	// Intervals at which the background processor checks for timed deferrals that have expired
	// and for "expires when fixed" deferrals that have become fixable.
	reObserveTimedDeferralsTickerDuration     time.Duration
	reObserveWhenFixedDeferralsTickerDuration time.Duration

	// stopSig is signaled by Stop to ask the background processor to exit; stopped is signaled
	// by the processor when it returns.
	stopSig concurrency.Signal
	stopped concurrency.Signal
}

// Start builds the vulnerability request caches and launches the background processor that
// retires expired and now-fixable deferrals. It does nothing when the VulnRiskManagement
// feature is disabled. A cache build failure is logged but does not prevent the processor
// from starting.
func (m *managerImpl) Start() {
	if features.VulnRiskManagement.Enabled() {
		cacheErr := m.buildCache()
		if cacheErr != nil {
			log.Errorf("Could not build vulnerability request cache. Vulnerability snoozing and unsnoozing may not work correctly: %v", cacheErr)
		}
		go m.runExpiredDeferralsProcessor()
	}
}

// Stop signals the background processor to exit and blocks until it has done so.
//
// When the VulnRiskManagement feature is disabled, Start returns before launching the
// processor, so nothing would signal m.stopped. In that case Stop only raises the stop signal
// and returns immediately instead of waiting forever.
func (m *managerImpl) Stop() {
	m.stopSig.Signal()
	if !features.VulnRiskManagement.Enabled() {
		// The processor goroutine was never started; there is nothing to wait for.
		return
	}
	m.stopped.Wait()
}

// Create validates a new vulnerability request, rejects it if an approved request already
// covers the same CVE and scope, assigns it a fresh ID and stores it. Requests that are
// pending after creation are also added to the pending request cache.
//
// Validation failures are wrapped in ErrInvalidArgs and uniqueness failures in ErrAlreadyExists.
func (m *managerImpl) Create(ctx context.Context, req *storage.VulnerabilityRequest) error {
	if validationErr := validator.ValidateNewSuppressVulnRequest(req); validationErr != nil {
		return errors.Wrap(errorhelpers.ErrInvalidArgs, validationErr.Error())
	}

	// Look up every active, approved request for the same CVE(s) ...
	similarScopeQuery := utils.GetQueryForApprovedReqsWithSimilarScope(req.GetCves().GetIds()...)
	existing, err := m.vulnReqs.SearchRawRequests(ctx, similarScopeQuery)
	if err != nil {
		return errors.Wrap(err, "could not search for other vulnerability requests")
	}
	// ... and make sure this CVE + scope combination has not been handled already.
	if duplicateErr := validator.ValidateSuppressVulnRequestIsUnique(req, existing); duplicateErr != nil {
		return errors.Wrap(errorhelpers.ErrAlreadyExists, duplicateErr.Error())
	}

	req.Id = uuid.NewV4().String()
	err = m.vulnReqs.AddRequest(ctx, req)
	if err != nil {
		return errors.Wrap(err, "could not create vulnerability request")
	}

	if utils.IsPending(req) {
		m.pendingReqCache.Add(req)
	}
	return nil
}

// Approve marks the request with the given id as APPROVED, recording the mandatory comment,
// and then snoozes the vulnerabilities covered by the request. It returns the updated request.
//
// A nil reqParams or an empty comment yields ErrInvalidArgs. The status update happens before
// the snooze, so a snooze failure is reported after the status has already been written.
func (m *managerImpl) Approve(ctx context.Context, id string, reqParams *common.VulnRequestParams) (*storage.VulnerabilityRequest, error) {
	if reqParams == nil || len(reqParams.Comment) == 0 {
		return nil, errors.Wrap(errorhelpers.ErrInvalidArgs, "comment must be provided")
	}

	approved, err := m.vulnReqs.UpdateRequestStatus(ctx, id, reqParams.Comment, storage.RequestStatus_APPROVED)
	if err == nil {
		err = m.SnoozeVulnerabilityOnRequest(ctx, approved)
	}
	if err != nil {
		return nil, errors.Wrapf(err, "approving vulnerability request %s", id)
	}
	return approved, nil
}

// Delete removes the request with the given id from the datastore and, on success, from the
// pending request cache. The datastore error is returned unchanged.
func (m *managerImpl) Delete(ctx context.Context, id string) error {
	err := m.vulnReqs.RemoveRequest(ctx, id)
	if err == nil {
		// Active requests cannot be deleted; only pending requests and pending request updates
		// can. The active cache therefore never needs to be touched here.
		m.pendingReqCache.Remove(id)
	}
	return err
}

// Deny marks the request with the given id as DENIED, recording the mandatory comment, and
// removes it from the pending request cache. It returns the updated request.
//
// A nil reqParams or an empty comment yields ErrInvalidArgs, mirroring Approve; previously a
// nil reqParams caused a nil pointer dereference.
func (m *managerImpl) Deny(ctx context.Context, id string, reqParams *common.VulnRequestParams) (*storage.VulnerabilityRequest, error) {
	if reqParams == nil || reqParams.Comment == "" {
		return nil, errors.Wrap(errorhelpers.ErrInvalidArgs, "comment must be provided")
	}
	req, err := m.vulnReqs.UpdateRequestStatus(ctx, id, reqParams.Comment, storage.RequestStatus_DENIED)
	if err != nil {
		return nil, errors.Wrapf(err, "denying vulnerability request %s", id)
	}
	// Request (Request update) that is not approved can only be in pending cache.
	m.pendingReqCache.Remove(req.GetId())
	// No need to unsnooze or snooze the vulns. A denial state is reached only from pending state which has
	// no effect on policy or risk workflow.
	return req, nil
}

// Undo marks the request with the given id as inactive and un-snoozes the vulnerabilities it
// covered. It returns the updated request. reqParams is accepted but not read by this method.
func (m *managerImpl) Undo(ctx context.Context, id string, reqParams *common.VulnRequestParams) (*storage.VulnerabilityRequest, error) {
	undone, err := m.vulnReqs.MarkRequestInactive(ctx, id, "[System Generated] Request undone")
	if err == nil {
		err = m.UnSnoozeVulnerabilityOnRequest(ctx, undone)
	}
	if err != nil {
		return nil, errors.Wrapf(err, "undoing vulnerability request %s", id)
	}
	return undone, nil
}

// UpdateExpiry records a new expiry for the request with the given id and adds the updated
// request to the pending request cache. It returns the updated request.
//
// A nil reqParams or a nil Expiry yields ErrInvalidArgs; previously a nil reqParams caused a
// nil pointer dereference.
func (m *managerImpl) UpdateExpiry(ctx context.Context, id string, reqParams *common.VulnRequestParams) (*storage.VulnerabilityRequest, error) {
	// Currently, only deferral requests can be updated
	if reqParams == nil || reqParams.Expiry == nil {
		return nil, errors.Wrap(errorhelpers.ErrInvalidArgs, "nothing to update for request - at least expiry must be provided")
	}
	req, err := m.vulnReqs.UpdateRequestExpiry(ctx, id, reqParams.Comment, reqParams.Expiry)
	if err != nil {
		return nil, errors.Wrapf(err, "updating vulnerability request %s", id)
	}
	m.pendingReqCache.Add(req)
	// No need to unsnooze or snooze the vulns. This update is still pending, therefore, it does not take effect
	// until approval. Meanwhile, the original request config remains in effect if it was approved.
	return req, nil
}

// SnoozeVulnerabilityOnRequest snoozes the CVE for the scope specified by the request.
// Snoozed vulns won't result in a policy violation nor will they be included in risk calculation.
//
// The request must be approved and not expired, otherwise an error is returned. The request is
// moved from the pending cache to the active cache, the vulnerability state of every matching
// image-CVE edge is updated, and reprocessing of the affected images and deployments is started
// on a separate goroutine. The context argument is ignored; datastore calls use allAccessCtx.
func (m *managerImpl) SnoozeVulnerabilityOnRequest(_ context.Context, request *storage.VulnerabilityRequest) error {
	// Only snooze the vulns if the request was approved and not expired
	if request.GetExpired() || request.GetStatus() != storage.RequestStatus_APPROVED {
		return errors.Errorf("vulnerability request %s not approved or expired", request.GetId())
	}
	// Add to the activeReqCache first because the request could be for images not detected in system.
	m.pendingReqCache.Remove(request.GetId())
	m.activeReqCache.Add(request)

	// Search for images matching the scope.
	// Validation of image-cve existence is performed by the image-cve datastore.
	imageIDs, err := m.getImagesIDsForVulnRequest(request)
	if err != nil {
		return errors.Wrapf(err, "could not fetch images matching vulnerability request %s", request.GetId())
	}
	if len(imageIDs) == 0 {
		return nil
	}

	cves := request.GetCves().GetIds()
	for _, imageID := range imageIDs {
		img, found, err := m.images.GetImageMetadata(allAccessCtx, imageID)
		if err != nil {
			return errors.Wrapf(err, "could not snooze vulnerabilities for request %s", request.GetId())
		}
		if !found {
			continue
		}
		// Determine the effective state for the cves in the image scope.
		cveStateMap := m.activeReqCache.GetEffectiveVulnStateForImage(cves, img.GetName().GetRegistry(), img.GetName().GetRemote(), img.GetName().GetTag())
		for _, cve := range cves {
			if err := m.imageCVEEdges.UpdateVulnerabilityState(allAccessCtx, cve, []string{imageID}, cveStateMap[cve]); err != nil {
				return errors.Wrapf(err, "could not snooze vulnerabilities for request %s", request.GetId())
			}
		}
	}

	go m.reprocessAffectedEntities(request.GetId(), imageIDs...)
	return nil
}

// UnSnoozeVulnerabilityOnRequest unsnoozes the CVE for the scope specified by the request,
// unless another request that is still active causes this CVE to remain snoozed.
//
// The request is dropped from both caches, the vulnerability state of every matching image-CVE
// edge is recomputed from the remaining active requests, and reprocessing of the affected
// images and deployments is started on a separate goroutine. The context argument is ignored;
// datastore calls use allAccessCtx.
func (m *managerImpl) UnSnoozeVulnerabilityOnRequest(_ context.Context, request *storage.VulnerabilityRequest) error {
	requestID := request.GetId()

	// The pending cache may hold an entry created by a deferral expiry update (which puts the
	// request into APPROVED_PENDING_UPDATE); make sure it is removed along with the active one.
	m.pendingReqCache.Remove(requestID)
	m.activeReqCache.Remove(requestID)

	// Match images by scope rather than by image+cve combination.
	// The image-cve datastore validates that the image-cve edge exists.
	matchedImages, err := m.getImagesIDsForVulnRequest(request)
	if err != nil {
		return errors.Wrapf(err, "could not fetch images matching vulnerability request %s", requestID)
	}
	if len(matchedImages) == 0 {
		return nil
	}

	cveIDs := request.GetCves().GetIds()
	for _, id := range matchedImages {
		image, exists, err := m.images.GetImageMetadata(allAccessCtx, id)
		if err != nil {
			return errors.Wrapf(err, "could not un-snooze vulnerabilities for request %s", requestID)
		}
		if !exists {
			continue
		}
		// Work out the effective state of each CVE for this image from the active requests.
		name := image.GetName()
		states := m.activeReqCache.GetEffectiveVulnStateForImage(cveIDs, name.GetRegistry(), name.GetRemote(), name.GetTag())
		for _, cveID := range cveIDs {
			err := m.imageCVEEdges.UpdateVulnerabilityState(allAccessCtx, cveID, []string{id}, states[cveID])
			if err != nil {
				return errors.Wrapf(err, "could not un-snooze vulnerabilities for request %s", requestID)
			}
		}
	}

	go m.reprocessAffectedEntities(requestID, matchedImages...)
	return nil
}

// reprocessAffectedEntities broadcasts an image cache invalidation for the affected images and
// then starts deployment reprocessing on a separate goroutine. A failure of the image step is
// logged and does not stop the deployment step.
func (m *managerImpl) reprocessAffectedEntities(requestID string, affectedImages ...string) {
	// Invalidating the Secured Cluster image cache causes the image pull cycle to run, which in
	// turn triggers image risk calculation, so risk is not recalculated here.
	invalidateErr := m.reprocessImage(requestID, affectedImages...)
	if invalidateErr != nil {
		log.Errorf("Could not fetch Secured Cluster image cache keys in response to vuln request %q: %v", requestID, invalidateErr)
	}
	go m.reprocessDeployments(requestID, affectedImages...)
}

// reprocessDeployments asks each connected cluster to reprocess the deployments that use any of
// the affected images and then requests a risk reprocess for all of those deployments.
// Clusters without an active connection are skipped for the sensor message, but their
// deployments are still included in the risk reprocessing request. Failures are logged, not
// returned.
func (m *managerImpl) reprocessDeployments(requestID string, affectedImages ...string) {
	// The re-processing will happen anyways at the next re-processing interval.
	depsByCluster, err := m.getAffectedDeployments(affectedImages...)
	if err != nil {
		log.Errorf("Cannot reprocess deployments. "+
			"Could not get deployment affected by vuln request %q: %v", requestID, err)
		return
	}

	var allDeps []string
	for cluster, deps := range depsByCluster {
		allDeps = append(allDeps, deps...)
		conn := m.connManager.GetConnection(cluster)
		if conn == nil {
			continue
		}
		if err := conn.InjectMessage(allAccessCtx, getReprocessDeploymentMsg(deps...)); err != nil {
			// Include the underlying error so the failure can actually be diagnosed.
			log.Errorf("Could not send request to reprocess deployments affected by vuln request %q: %v", requestID, err)
		}
	}
	// Reprocessor throttles the requests to reprocess deployments once every the reprocessing interval.
	m.reprocessor.ReprocessRiskForDeployments(allDeps...)
}

// reprocessImage broadcasts an InvalidateImageCache message to all sensors containing one key
// (image ID and full name) per affected image found in the image datastore. Images that are
// not found are skipped; a datastore error aborts before anything is broadcast. requestID is
// not read by this method.
func (m *managerImpl) reprocessImage(requestID string, affectedImages ...string) error {
	keys := make([]*central.InvalidateImageCache_ImageKey, 0, len(affectedImages))
	for _, id := range affectedImages {
		img, exists, err := m.images.GetImage(allAccessCtx, id)
		if err != nil {
			return errors.Wrap(err, "could not get image for reprocessing")
		}
		if exists {
			key := &central.InvalidateImageCache_ImageKey{
				ImageId:       id,
				ImageFullName: img.GetName().GetFullName(),
			}
			keys = append(keys, key)
		}
	}

	invalidation := &central.InvalidateImageCache{ImageKeys: keys}
	m.connManager.BroadcastMessage(&central.MsgToSensor{
		Msg: &central.MsgToSensor_InvalidateImageCache{InvalidateImageCache: invalidation},
	})
	return nil
}

// getAffectedDeployments returns the IDs of the deployments that use any of the given images,
// grouped by cluster ID. The cluster ID is read from the highlighted search matches; results
// without one are logged and skipped. It returns a nil map when the search yields no results.
func (m *managerImpl) getAffectedDeployments(affectedImages ...string) (map[string][]string, error) {
	byImage := search.NewQueryBuilder().AddExactMatches(search.ImageSHA, affectedImages...).ProtoQuery()
	// The wildcard highlight makes the search return each deployment's cluster ID.
	withClusterID := search.NewQueryBuilder().AddStringsHighlighted(search.ClusterID, search.WildcardString).ProtoQuery()

	results, err := m.deployments.SearchDeployments(allAccessCtx, search.ConjunctionQuery(byImage, withClusterID))
	if err != nil {
		return nil, errors.Wrap(err, "could not get deployment results")
	}
	if len(results) == 0 {
		return nil, nil
	}

	grouped := make(map[string][]string)
	for _, result := range results {
		matches := result.FieldToMatches[clusterIDField.FieldPath].GetValues()
		if len(matches) == 0 {
			log.Errorf("No cluster ID found in fields for deployment %q", result.GetId())
			continue
		}
		clusterID := matches[0]
		grouped[clusterID] = append(grouped[clusterID], result.GetId())
	}
	return grouped, nil
}

// getImagesIDsForVulnRequest returns the IDs of the images matched by the affected-images query
// built from the request. Errors from building or running the query are returned unchanged.
func (m *managerImpl) getImagesIDsForVulnRequest(request *storage.VulnerabilityRequest) ([]string, error) {
	query, queryErr := utils.GetAffectedImagesQuery(request, nil)
	if queryErr != nil {
		return nil, queryErr
	}

	matches, searchErr := m.images.Search(allAccessCtx, query)
	if searchErr != nil {
		return nil, searchErr
	}

	ids := search.ResultsToIDs(matches)
	return ids, nil
}

// expireDeferrals marks each given deferral inactive and un-snoozes its vulnerabilities.
// A failure on one request is recorded and processing continues with the next; all recorded
// failures are returned together as a single error (nil when there were none).
func (m *managerImpl) expireDeferrals(deferrals []*storage.VulnerabilityRequest) error {
	failures := errorhelpers.NewErrorList("re-observing expired deferrals")
	for _, deferral := range deferrals {
		// Marking the request inactive is all that is needed to re-observe it.
		// NOTE: another request may still keep the vulnerability deferred (e.g. this one was
		// image scoped but a global one still exists).
		_, err := m.vulnReqs.MarkRequestInactive(allVulnApproverAccessSac, deferral.GetId(), "[System Generated] Request expired")
		if err != nil {
			failures.AddWrapf(err, "marking as inactive request %s", deferral.GetId())
			continue
		}

		err = m.UnSnoozeVulnerabilityOnRequest(allVulnApproverAccessSac, deferral)
		if err != nil {
			failures.AddWrapf(err, "unsnoozing vulns for request %s", deferral.GetId())
		}
	}
	return failures.ToError()
}

// getExpiredDeferrals returns the requests that are not yet marked expired, have an expiry time
// earlier than the current date (the bound is formatted as month/day/year plus time zone), and
// are in the APPROVED or APPROVED_PENDING_UPDATE state. It returns nil when nothing matches.
func (m *managerImpl) getExpiredDeferrals() ([]*storage.VulnerabilityRequest, error) {
	expiryBound := fmt.Sprintf("<%s", time.Now().Format("01/02/2006 MST"))

	unexpiredPastDue := search.NewQueryBuilder().
		AddGenericTypeLinkedFields([]search.FieldLabel{search.ExpiredRequest, search.RequestExpiryTime}, []interface{}{false, expiryBound}).
		ProtoQuery()
	approved := search.NewQueryBuilder().
		AddExactMatches(search.RequestStatus, storage.RequestStatus_APPROVED.String(), storage.RequestStatus_APPROVED_PENDING_UPDATE.String()).
		ProtoQuery()

	requests, err := m.vulnReqs.SearchRawRequests(allVulnApproverAccessSac, search.ConjunctionQuery(unexpiredPastDue, approved))
	if err != nil {
		return nil, err
	}
	if len(requests) == 0 {
		return nil, nil
	}
	return requests, nil
}

// reObserveExpiredDeferrals retires all deferrals whose expiry time has passed. It does nothing
// once the processor has signaled that it stopped. Outcomes are logged; nothing is returned.
func (m *managerImpl) reObserveExpiredDeferrals() {
	if m.stopped.IsDone() {
		return
	}

	expired, err := m.getExpiredDeferrals()
	switch {
	case err != nil:
		log.Errorf("error retrieving expired deferral requests for reprocessing: %v", err)
		return
	case len(expired) == 0:
		return
	}

	if err := m.expireDeferrals(expired); err != nil {
		log.Errorf("Failed to retire expired deferral requests and re-observe associated vulnerabilities with error(s): %+v", err)
		return
	}
	log.Infof("Completed retiring %d expired deferral requests and re-observing deferred vulnerabilities", len(expired))
}

// getFixableDeferrals returns the "expires when fixed" deferrals (not yet marked expired, in
// the APPROVED or APPROVED_PENDING_UPDATE state) for which at least one of the request's CVEs
// is now fixable in an image matched by the request.
//
// Each request appears at most once in the result, even when several of its CVEs are fixable.
// Any datastore or query-building error aborts the scan and is returned.
func (m *managerImpl) getFixableDeferrals() ([]*storage.VulnerabilityRequest, error) {
	q := search.ConjunctionQuery(
		search.NewQueryBuilder().AddGenericTypeLinkedFields([]search.FieldLabel{search.ExpiredRequest, search.RequestExpiresWhenFixed}, []interface{}{false, true}).ProtoQuery(),
		search.NewQueryBuilder().AddExactMatches(search.RequestStatus, storage.RequestStatus_APPROVED.String(), storage.RequestStatus_APPROVED_PENDING_UPDATE.String()).ProtoQuery(),
	)
	results, err := m.vulnReqs.SearchRawRequests(allVulnApproverAccessSac, q)
	if err != nil || len(results) == 0 {
		return nil, err
	}
	var fixableReqs []*storage.VulnerabilityRequest
	for _, res := range results {
		cves := res.GetCves().GetIds()
		if len(cves) == 0 {
			continue
		}

		// The query depends only on the request, not on the individual CVE, so build it once.
		fixableQuery, err := utils.GetAffectedImagesQuery(res, search.NewQueryBuilder().AddBools(search.Fixable, true).ProtoQuery())
		if err != nil {
			return nil, err
		}

		for _, cve := range cves {
			// TODO: Determine if it's worth checking cvePkg.ContainsComponentBasedCVE(cve.GetTypes()) before doing this. It would involve going to the data store to read CVE data
			// This is only necessary if somehow there ended up being a deferral on a cluster CVE. Or if it goes from an image cve to node cve
			// TODO: Test what happens if it's a cluster cve
			cveScopedCtx := scoped.Context(allAccessCtx, scoped.Scope{
				ID:    cve,
				Level: v1.SearchCategory_VULNERABILITIES,
			})

			count, err := m.componentCVEEdges.Count(cveScopedCtx, fixableQuery)
			if err != nil {
				return nil, errors.Wrapf(err, "could not fetch cve component edge for cve %q for request %q", cve, res.GetId())
			}

			if count != 0 { // This CVE is fixable for this image (or for all images in the query)
				fixableReqs = append(fixableReqs, res)
				// One fixable CVE is enough; stop so the request is not queued more than once.
				break
			}
		}
	}

	return fixableReqs, nil
}

// reObserveFixableDeferrals retires all "expires when fixed" deferrals whose vulnerabilities
// have become fixable. It does nothing once the processor has signaled that it stopped.
// Outcomes are logged; nothing is returned.
func (m *managerImpl) reObserveFixableDeferrals() {
	if m.stopped.IsDone() {
		return
	}

	fixable, err := m.getFixableDeferrals()
	switch {
	case err != nil:
		log.Errorf("error retrieving deferral requests that are now fixable for reprocessing: %v", err)
		return
	case len(fixable) == 0:
		return
	}

	if err := m.expireDeferrals(fixable); err != nil {
		log.Errorf("Failed to retire now-fixable deferral requests and re-observe associated vulnerabilities with error(s): %+v", err)
		return
	}
	log.Infof("Completed retiring %d newly fixable deferral requests and re-observing deferred vulnerabilities", len(fixable))
}

// runExpiredDeferralsProcessor is the background loop started by Start. It runs both deferral
// checks once immediately and then again every time the corresponding ticker fires, each on
// its own goroutine. The loop exits when stopSig is signaled; on exit both tickers are stopped
// and m.stopped is signaled.
func (m *managerImpl) runExpiredDeferralsProcessor() {
	defer m.stopped.Signal()

	timedTicker := time.NewTicker(m.reObserveTimedDeferralsTickerDuration)
	defer timedTicker.Stop()

	whenFixedTicker := time.NewTicker(m.reObserveWhenFixedDeferralsTickerDuration)
	defer whenFixedTicker.Stop()

	// Do an initial pass right away instead of waiting for the first tick.
	go m.reObserveExpiredDeferrals()
	go m.reObserveFixableDeferrals()

	for {
		select {
		case <-m.stopSig.Done():
			return
		case <-timedTicker.C:
			go m.reObserveExpiredDeferrals()
		case <-whenFixedTicker.C:
			go m.reObserveFixableDeferrals()
		}
	}
}

// buildCache fills the active request cache and then the pending request cache from the
// datastore. It stops at the first error, so the pending cache is not filled if loading the
// active cache fails.
func (m *managerImpl) buildCache() error {
	// Active (approved) requests first.
	approvedResults, err := m.vulnReqs.Search(allAccessCtx, utils.GetActiveApprovedReqQuery())
	if err != nil {
		return errors.Wrap(err, "error retrieving keys from vuln request datastore")
	}
	err = buildCache(m.vulnReqs, m.activeReqCache, search.ResultsToIDs(approvedResults)...)
	if err != nil {
		return err
	}

	// Then the pending requests.
	pendingResults, err := m.vulnReqs.Search(allAccessCtx, utils.GetActivePendingReqQuery())
	if err != nil {
		return errors.Wrap(err, "error retrieving keys from vuln request datastore")
	}
	err = buildCache(m.vulnReqs, m.pendingReqCache, search.ResultsToIDs(pendingResults)...)
	if err != nil {
		return err
	}

	log.Info("[STARTUP] Successfully cached all vulnerability requests")
	return nil
}

// buildCache loads the requests with the given ids from the datastore in batches of batchSize
// and adds each batch to the cache, logging progress after every batch. It returns the first
// datastore error encountered; batches added before the error stay in the cache.
func buildCache(vulnReqs vulnReqDataStore.DataStore, cache cache.VulnReqCache, ids ...string) error {
	total := len(ids)
	batches := batcher.New(total, batchSize)
	for {
		start, end, ok := batches.Next()
		if !ok {
			break
		}
		batch, err := vulnReqs.GetMany(allAccessCtx, ids[start:end])
		if err != nil {
			return err
		}
		cache.AddMany(batch...)
		log.Infof("[STARTUP] Successfully cached %d/%d vulnerability requests", end, total)
	}
	return nil
}

// getReprocessDeploymentMsg wraps the given deployment IDs in a ReprocessDeployment message
// addressed to a sensor.
func getReprocessDeploymentMsg(deps ...string) *central.MsgToSensor {
	payload := &central.ReprocessDeployment{DeploymentIds: deps}
	wrapped := &central.MsgToSensor_ReprocessDeployment{ReprocessDeployment: payload}
	return &central.MsgToSensor{Msg: wrapped}
}
