package cache

import (
	"fmt"

	"github.com/stackrox/rox/central/vulnerabilityrequest/common"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/set"
	"github.com/stackrox/rox/pkg/stringutils"
	"github.com/stackrox/rox/pkg/sync"
)

// slimRequest is the trimmed-down form of a vulnerability request that the cache keeps.
// It holds the request's ID (requestID), the vulnerability state the request asks for
// (targetState), and the set of CVE IDs the request covers (cves).
type slimRequest struct {
	requestID   string
	targetState storage.VulnerabilityState
	cves        map[string]struct{}
}

// vulnReqCacheImpl is an in-memory index of vulnerability requests that can be looked up
// by scope string and by request ID. Its exported methods acquire lock before touching either map.
type vulnReqCacheImpl struct {
	// vulnReqByScope maps a scope string to the active requests cached for that scope,
	// keyed by request ID; the values are the slimmed-down request objects.
	vulnReqByScope map[string]map[string]*slimRequest
	// scopeByVulnReqs maps an active vulnerability request ID to the scope string
	// under which that request is stored in vulnReqByScope.
	scopeByVulnReqs map[string]string

	// lock guards vulnReqByScope and scopeByVulnReqs.
	lock sync.RWMutex
}

// GetVulnsWithState returns the CVEs covered by cached requests that apply to the image
// identified by registry, remote and tag, mapped to the target state of the applicable request.
// Three scopes are consulted: the global scope, the image across all tags, and the exact tag.
// The returned map is never nil; it is empty when no cached request applies.
func (c *vulnReqCacheImpl) GetVulnsWithState(registry, remote, tag string) map[string]storage.VulnerabilityState {
	// Ordered from the broadest scope to the narrowest, so that a narrower scope overwrites
	// whatever a broader scope recorded for the same CVE.
	scopesBroadToNarrow := []string{
		common.MatchAll,
		imageNameToScopeStr(registry, remote, common.MatchAll),
		imageNameToScopeStr(registry, remote, tag),
	}

	c.lock.RLock()
	defer c.lock.RUnlock()

	statesByCVE := make(map[string]storage.VulnerabilityState)
	for _, scope := range scopesBroadToNarrow {
		// A scope without cached requests yields a nil map, and ranging over it is a no-op.
		for _, cached := range c.vulnReqByScope[scope] {
			if cached == nil {
				continue
			}
			for cve := range cached.cves {
				statesByCVE[cve] = cached.targetState
			}
		}
	}
	return statesByCVE
}

// GetEffectiveVulnReqIDForImage returns the ID of a cached request that covers cve for the image
// identified by registry, remote and tag. Scopes are searched from the narrowest (exact tag) to
// the broadest (global), and the first request found wins. It returns "" when no cached request
// covers the CVE.
func (c *vulnReqCacheImpl) GetEffectiveVulnReqIDForImage(registry, remote, tag, cve string) string {
	// Ordered from the narrowest scope to the broadest because the narrowest scope takes precedence.
	scopesNarrowToBroad := []string{
		imageNameToScopeStr(registry, remote, tag),
		imageNameToScopeStr(registry, remote, common.MatchAll),
		common.MatchAll,
	}

	c.lock.RLock()
	defer c.lock.RUnlock()

	for _, scope := range scopesNarrowToBroad {
		// A scope without cached requests yields a nil map, and ranging over it is a no-op.
		for _, cached := range c.vulnReqByScope[scope] {
			if cached == nil {
				continue
			}
			// The first hit is the answer: any later hit can only come from an equal or broader scope.
			if _, covered := cached.cves[cve]; covered {
				return cached.requestID
			}
		}
	}
	return ""
}

// Add inserts the request into the cache while holding the write lock.
// It returns whatever addNoLock reports: true when the request was inserted, false otherwise.
func (c *vulnReqCacheImpl) Add(request *storage.VulnerabilityRequest) bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	added := c.addNoLock(request)
	return added
}

// GetEffectiveVulnStateForImage computes the state of each of the given CVEs for the image
// identified by registry, remote and tag. Every CVE starts out as OBSERVED and is then updated
// from the cached requests of the exact-tag, all-tags and global scopes via processSlimRequest.
// It returns nil when registry or remote is empty.
func (c *vulnReqCacheImpl) GetEffectiveVulnStateForImage(cves []string, registry, remote, tag string) map[string]storage.VulnerabilityState {
	if registry == "" || remote == "" {
		return nil
	}

	// OBSERVED doubles as the "no request seen yet" marker that processSlimRequest checks for.
	states := make(map[string]storage.VulnerabilityState)
	for _, cve := range cves {
		states[cve] = storage.VulnerabilityState_OBSERVED
	}
	// CVEs still open for evaluation; processSlimRequest removes entries from this set.
	pending := set.NewStringSet(cves...)

	// Ordered from the narrowest scope to the broadest because the narrowest scope takes precedence.
	scopesNarrowToBroad := []string{
		imageNameToScopeStr(registry, remote, tag),
		imageNameToScopeStr(registry, remote, common.MatchAll),
		common.MatchAll,
	}

	c.lock.RLock()
	defer c.lock.RUnlock()

	for _, scope := range scopesNarrowToBroad {
		// A scope without cached requests yields a nil map, and ranging over it is a no-op.
		for _, cached := range c.vulnReqByScope[scope] {
			processSlimRequest(cached, scope, pending, states)
		}
	}
	return states
}

// processSlimRequest applies one cached request, stored under the given scope, to result.
// Only CVEs that are both covered by the request and still present in cveSet are considered.
// For each such CVE, isScopeSmaller decides whether the request's target state is written to
// result and whether the CVE is removed from cveSet so that later calls skip it.
// Both cveSet and result are modified in place. A nil slimReq is ignored.
func processSlimRequest(slimReq *slimRequest, scope string, cveSet set.StringSet, result map[string]storage.VulnerabilityState) {
	if slimReq == nil {
		return
	}
	// Only one CVE per request is allowed at the moment, but iterate to stay compatible with
	// the object definition.
	for cve := range slimReq.cves {
		if cveSet.Contains(cve) {
			// A CVE still in the OBSERVED state has not been claimed by any scope so far.
			untouched := result[cve] == storage.VulnerabilityState_OBSERVED
			takesPrecedence, tagSpecific := isScopeSmaller(scope, untouched)
			if takesPrecedence {
				result[cve] = slimReq.targetState
			}
			if tagSpecific {
				cveSet.Remove(cve)
			}
		}
	}
}

func isScopeSmaller(scope string, firstScope bool) (smaller bool, scopeHasTag bool) {
	imageName, tag := stringutils.Split2(scope, ":")
	// If the tag is not regex, it is the smallest scope since the only supported regex is `.*`.
	if tag != common.MatchAll {
		return true, true
	}
	// If this is the first encounter of the scope or image name is not a regex, this is the smallest scope thus far.
	if firstScope || imageName != common.MatchAll {
		return true, false
	}
	return false, false
}

// AddMany inserts all given requests into the cache under a single acquisition of the write lock.
// The per-request result of addNoLock is discarded.
func (c *vulnReqCacheImpl) AddMany(requests ...*storage.VulnerabilityRequest) {
	c.lock.Lock()
	defer c.lock.Unlock()

	for i := range requests {
		c.addNoLock(requests[i])
	}
}

// addNoLock indexes the request by ID and by scope and reports whether it was inserted.
// It returns false when the request ID is already tracked, or when the request's scope does not
// serialize to a scope string. The caller must hold c.lock for writing.
func (c *vulnReqCacheImpl) addNoLock(request *storage.VulnerabilityRequest) bool {
	id := request.GetId()
	// Test for presence of the key rather than comparing the stored scope against "".
	if _, tracked := c.scopeByVulnReqs[id]; tracked {
		return false
	}

	scope := toScopeStr(request.GetScope())
	if scope == "" {
		// toScopeStr yields "" for a scope that is neither global nor image. No lookup ever
		// queries that key, so caching the request would only leave an unreachable entry behind.
		return false
	}
	c.scopeByVulnReqs[id] = scope

	reqsForScope := c.vulnReqByScope[scope]
	if reqsForScope == nil {
		reqsForScope = make(map[string]*slimRequest)
		c.vulnReqByScope[scope] = reqsForScope
	}
	// Defensive: never overwrite an entry that is already stored for this scope.
	if reqsForScope[id] != nil {
		return false
	}

	ids := request.GetCves().GetIds()
	cveSet := make(map[string]struct{}, len(ids))
	for _, cve := range ids {
		cveSet[cve] = struct{}{}
	}
	reqsForScope[id] = &slimRequest{
		requestID:   id,
		targetState: request.GetTargetState(),
		cves:        cveSet,
	}
	return true
}

// Remove deletes the request with the given ID from the cache while holding the write lock.
// It returns whatever removeNoLock reports: true when the request was removed, false otherwise.
func (c *vulnReqCacheImpl) Remove(requestID string) bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	removed := c.removeNoLock(requestID)
	return removed
}

// RemoveMany deletes all requests with the given IDs from the cache under a single acquisition
// of the write lock. The per-ID result of removeNoLock is discarded; the method always returns true.
func (c *vulnReqCacheImpl) RemoveMany(requestIDs ...string) bool {
	c.lock.Lock()
	defer c.lock.Unlock()

	for i := range requestIDs {
		c.removeNoLock(requestIDs[i])
	}
	return true
}

// removeNoLock drops the request with the given ID from both indexes. It returns true when the
// request was found in both, and false when the ID is not tracked or its scope has no cached
// requests. A scope left without requests is removed from vulnReqByScope.
// The caller must hold c.lock for writing.
func (c *vulnReqCacheImpl) removeNoLock(requestID string) bool {
	// Test for presence of the key rather than comparing the value against "". A request whose
	// scope serializes to the empty string is still a map entry, and a value comparison would
	// make it impossible to remove.
	scope, tracked := c.scopeByVulnReqs[requestID]
	if !tracked {
		return false
	}
	delete(c.scopeByVulnReqs, requestID)

	reqMap := c.vulnReqByScope[scope]
	if reqMap == nil {
		return false
	}
	delete(reqMap, requestID)

	// Do not keep empty per-scope maps around.
	if len(reqMap) == 0 {
		delete(c.vulnReqByScope, scope)
	}
	return true
}

// toScopeStr serializes a request scope into the string used as cache key: common.MatchAll for
// a global scope, "<registry>/<remote>:<tag>" for an image scope, and "" when the scope is neither.
// A global scope takes priority when both are set.
func toScopeStr(scope *storage.VulnerabilityRequest_Scope) string {
	imgScope := scope.GetImageScope()
	switch {
	case scope.GetGlobalScope() != nil:
		return common.MatchAll
	case imgScope != nil:
		return imageNameToScopeStr(imgScope.GetRegistry(), imgScope.GetRemote(), imgScope.GetTag())
	default:
		return ""
	}
}

// imageNameToScopeStr builds the scope string "<registry>/<remote>:<tag>" from the image name
// components. The components are inserted verbatim, without validation or escaping.
func imageNameToScopeStr(registry, remote, tag string) string {
	const scopeFormat = "%s/%s:%s"
	return fmt.Sprintf(scopeFormat, registry, remote, tag)
}
