package searcher

import (
	"context"

	"github.com/stackrox/rox/central/role/resources"
	"github.com/stackrox/rox/central/vulnerabilityrequest/datastore/internal/store"
	"github.com/stackrox/rox/central/vulnerabilityrequest/index"
	v1 "github.com/stackrox/rox/generated/api/v1"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/sac"
	"github.com/stackrox/rox/pkg/search"
	"github.com/stackrox/rox/pkg/search/blevesearch"
	"github.com/stackrox/rox/pkg/search/paginated"
)

var (
	// defaultSortOption sorts results by creation time. formatSearcher installs
	// it through paginated.WithDefaultSortOption.
	defaultSortOption = &v1.QuerySortOption{Field: search.CreatedTime.String()}

	// requesterOrApproverSAC is the access check shared by every public search
	// method. It covers both the VulnerabilityManagementRequests and the
	// VulnerabilityManagementApprovals resources. It is queried with
	// ReadAllowedToAny, so presumably read access to either one suffices.
	requesterOrApproverSAC = sac.ForResources(
		sac.ForResource(resources.VulnerabilityManagementRequests),
		sac.ForResource(resources.VulnerabilityManagementApprovals),
	)
)

// searcherImpl backs vulnerability-request search. Every public method first
// checks requesterOrApproverSAC, then runs the query through searcher. Where
// full objects are needed, it loads them from store.
type searcherImpl struct {
	// store loads full VulnerabilityRequest objects by ID (see searchRequests).
	store store.Store
	// indexer is not referenced by any method in this file.
	// NOTE(review): presumably kept for whoever constructs the searcher; confirm.
	indexer index.Indexer
	// searcher executes queries for Count, Search and searchRequests.
	searcher search.Searcher
}

// Count reports how many vulnerability requests match q.
// It returns the error when the access check fails. It returns 0 with a nil
// error when the caller may not read either the requests or the approvals
// resource.
func (s *searcherImpl) Count(ctx context.Context, q *v1.Query) (int, error) {
	allowed, err := requesterOrApproverSAC.ReadAllowedToAny(ctx)
	if err != nil {
		return 0, err
	}
	if !allowed {
		return 0, nil
	}
	return s.searcher.Count(ctx, q)
}

// Search runs q and returns the raw search results without loading the
// underlying objects. Callers lacking read access get a nil slice and a nil
// error.
func (s *searcherImpl) Search(ctx context.Context, q *v1.Query) ([]search.Result, error) {
	allowed, err := requesterOrApproverSAC.ReadAllowedToAny(ctx)
	if err != nil {
		return nil, err
	}
	if !allowed {
		return nil, nil
	}
	return s.getSearchResults(ctx, q)
}

// SearchRequests runs q and converts each matching VulnerabilityRequest that
// still exists in the store into a v1.SearchResult. Callers lacking read access
// get a nil slice and a nil error.
func (s *searcherImpl) SearchRequests(ctx context.Context, q *v1.Query) ([]*v1.SearchResult, error) {
	allowed, err := requesterOrApproverSAC.ReadAllowedToAny(ctx)
	if err != nil || !allowed {
		return nil, err
	}

	requests, matches, err := s.searchRequests(ctx, q)
	if err != nil {
		return nil, err
	}
	return convertMany(requests, matches), nil
}

// SearchRawRequests runs q and returns the matching VulnerabilityRequest
// objects loaded from the store. The per-match search metadata is discarded.
// Callers lacking read access get a nil slice and a nil error.
func (s *searcherImpl) SearchRawRequests(ctx context.Context, q *v1.Query) ([]*storage.VulnerabilityRequest, error) {
	allowed, err := requesterOrApproverSAC.ReadAllowedToAny(ctx)
	switch {
	case err != nil:
		return nil, err
	case !allowed:
		return nil, nil
	}

	requests, _, err := s.searchRequests(ctx, q)
	return requests, err
}

// getSearchResults forwards q to the wrapped searcher. It performs no access
// check of its own; the exported callers do that first.
func (s *searcherImpl) getSearchResults(ctx context.Context, q *v1.Query) ([]search.Result, error) {
	results, err := s.searcher.Search(ctx, q)
	return results, err
}

// searchRequests runs q, loads the matching requests from the store, and
// returns them together with their search results. Results whose objects the
// store reported as missing are dropped, so the two slices stay index-aligned.
// On any error, both slices are nil.
func (s *searcherImpl) searchRequests(ctx context.Context, q *v1.Query) ([]*storage.VulnerabilityRequest, []search.Result, error) {
	results, err := s.searcher.Search(ctx, q)
	if err != nil {
		return nil, nil, err
	}

	ids := search.ResultsToIDs(results)
	requests, missing, err := s.store.GetMany(ctx, ids)
	if err != nil {
		return nil, nil, err
	}
	return requests, search.RemoveMissingResults(results, missing), nil
}

// convertMany pairs each request with the search result at the same index and
// converts the pair via convertOne. results must be at least as long as
// vulnRequests; a shorter results slice panics on the index.
func convertMany(vulnRequests []*storage.VulnerabilityRequest, results []search.Result) []*v1.SearchResult {
	converted := make([]*v1.SearchResult, 0, len(vulnRequests))
	for i := range vulnRequests {
		converted = append(converted, convertOne(vulnRequests[i], &results[i]))
	}
	return converted
}

// convertOne builds a VULN_REQUEST v1.SearchResult from a request and its
// search result, copying the result's matches and score.
func convertOne(vulnRequest *storage.VulnerabilityRequest, result *search.Result) *v1.SearchResult {
	id := vulnRequest.GetId()
	converted := &v1.SearchResult{
		Category: v1.SearchCategory_VULN_REQUEST,
		Id:       id,
		Score:    result.Score,
	}
	// Requests do not have names, so the ID is reused as the name.
	converted.Name = id
	converted.FieldToMatches = search.GetProtoMatchesMap(result.Matches)
	return converted
}

// formatSearcher turns the indexer's unsafe searcher into a search.Searcher.
// It wraps the searcher for SAC filtering, adds pagination, and installs
// defaultSortOption (creation time) as the default sort.
func formatSearcher(unsafeSearcher blevesearch.UnsafeSearcher) search.Searcher {
	return paginated.WithDefaultSortOption(
		paginated.Paginated(blevesearch.WrapUnsafeSearcherAsSearcher(unsafeSearcher)),
		defaultSortOption,
	)
}
