package resolvers

import (
	"context"
	"math"
	"strings"
	"time"

	"github.com/graph-gophers/graphql-go"
	"github.com/pkg/errors"
	"github.com/stackrox/rox/central/graphql/resolvers/inputtypes"
	"github.com/stackrox/rox/central/metrics"
	"github.com/stackrox/rox/central/vulnerabilityrequest/common"
	vulnReqUtils "github.com/stackrox/rox/central/vulnerabilityrequest/utils"
	v1 "github.com/stackrox/rox/generated/api/v1"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/features"
	"github.com/stackrox/rox/pkg/grpc/authz/interceptor"
	pkgMetrics "github.com/stackrox/rox/pkg/metrics"
	"github.com/stackrox/rox/pkg/search"
	"github.com/stackrox/rox/pkg/search/paginated"
	"github.com/stackrox/rox/pkg/utils"
)

// init registers the vulnerability-request part of the GraphQL schema: the
// input types, object types, mutations and queries declared below. Every
// registration result is handed to utils.Must, which presumably panics if any
// of them reports an error.
//
// Several referenced types (VulnReqScope, FalsePositiveVulnRequest, SlimUser,
// RequestComment, VulnerabilityRequest_Scope, VulnerabilityRequest_CVEs,
// FalsePositiveRequest, Deployment, Image, Pagination) are not declared here;
// presumably they are registered elsewhere in this package.
func init() {
	schema := getBuilder()
	utils.Must(
		// Input types consumed by the mutations below.
		schema.AddInput("VulnReqExpiry", []string{
			"expiresOn: Time",
			"expiresWhenFixed: Boolean",
		}),
		schema.AddInput("DeferVulnRequest", []string{
			"comment: String",
			"cve: String",
			"expiresOn: Time",
			"expiresWhenFixed: Boolean",
			"scope: VulnReqScope",
		}),

		// Output types.
		schema.AddType("DeferralRequest", []string{
			"expiresOn: Time",
			"expiresWhenFixed: Boolean!",
		}),
		schema.AddType("VulnerabilityRequest", []string{
			"id: ID!",
			"targetState: String!",
			"status: String!",
			"expired: Boolean!",
			"requestor: SlimUser",
			"approvers: [SlimUser!]!",
			"createdAt: Time",
			// NOTE(review): unlike the other fields this name is capitalized.
			// Renaming it would change the public schema, so it is left as-is.
			"LastUpdated: Time",
			"comments: [RequestComment!]!",
			"scope: VulnerabilityRequest_Scope",
			"deferralReq: DeferralRequest",
			"falsePositiveReq: FalsePositiveRequest",
			"updatedDeferralReq: DeferralRequest",
			"cves: VulnerabilityRequest_CVEs",

			// Derived fields: computed by VulnerabilityRequestResolver methods
			// (via vulnReqQueryMgr) rather than read from the stored request.

			"deploymentCount(query: String): Int!",
			"imageCount(query: String): Int!",

			"deployments(query: String, pagination: Pagination): [Deployment!]!",
			"images(query: String, pagination: Pagination): [Image!]!",
		}),

		// Mutations. Each resolver routes its work through processWithAuditLog
		// so that an audit message can be sent for the operation.
		schema.AddMutation("deferVulnerability(request: DeferVulnRequest!): VulnerabilityRequest!"),
		schema.AddMutation("markVulnerabilityFalsePositive(request: FalsePositiveVulnRequest!): VulnerabilityRequest!"),
		schema.AddMutation("approveVulnerabilityRequest(requestID: ID!, comment: String!): VulnerabilityRequest!"),
		schema.AddMutation("denyVulnerabilityRequest(requestID: ID!, comment: String!): VulnerabilityRequest!"),
		schema.AddMutation("updateVulnerabilityRequest(requestID: ID!, comment: String!, expiry: VulnReqExpiry!): VulnerabilityRequest!"),
		schema.AddMutation("undoVulnerabilityRequest(requestID: ID!): VulnerabilityRequest!"),
		schema.AddMutation("deleteVulnerabilityRequest(requestID: ID!): Boolean!"),

		// Read-only queries.
		schema.AddQuery("vulnerabilityRequest(id: ID!): VulnerabilityRequest"),
		schema.AddQuery("vulnerabilityRequests(query: String, requestIDSelector: String, pagination: Pagination): [VulnerabilityRequest!]!"),
		schema.AddQuery("vulnerabilityRequestsCount(query: String): Int!"),
	)
}

// processWithAuditLog invokes handler and then, if an AuditLogger is configured
// on the resolver, sends an ad hoc audit message describing req, method and the
// error (if any) returned by handler. The handler's response and error are
// returned unchanged; because the response is typed as interface{}, callers
// must type-assert it back to the concrete type they expect.
//
// GraphQL mutations need this wrapper because audit messages are emitted
// automatically only for gRPC calls, not for GraphQL.
//
// The audit message is sent from a goroutine that is not awaited and that
// receives the caller's ctx. NOTE(review): confirm SendAdhocAuditMessage does
// not rely on ctx remaining un-cancelled after the GraphQL request completes.
func (resolver *Resolver) processWithAuditLog(ctx context.Context, req interface{}, method string, handler func() (interface{}, error)) (interface{}, error) {
	result, handlerErr := handler()
	if auditLogger := resolver.AuditLogger; auditLogger != nil {
		go auditLogger.SendAdhocAuditMessage(ctx, req, method, interceptor.AuthStatus{}, handlerErr)
	}
	return result, handlerErr
}

// DeferVulnerability starts the workflow to defer a vulnerability.
//
// The GraphQL input is converted to its v1 API form exactly once. That single
// value is both recorded by processWithAuditLog and turned into the stored
// request, so the audited request is the one that was actually processed.
// Authorization is checked by writeVulnerabilityRequests inside the audited
// handler, so denied attempts are audited too.
func (resolver *Resolver) DeferVulnerability(
	ctx context.Context,
	args struct{ Request inputtypes.DeferVulnRequest },
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "DeferVulnerability")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}

	v1Req := args.Request.AsV1DeferralRequest()
	resp, err := resolver.processWithAuditLog(ctx, v1Req, "DeferVulnerability", func() (interface{}, error) {
		if err := writeVulnerabilityRequests(ctx); err != nil {
			return nil, err
		}
		req := vulnReqUtils.V1DeferVulnRequestToVulnReq(ctx, v1Req)
		if err := resolver.vulnReqMgr.Create(ctx, req); err != nil {
			return nil, err
		}
		return resolver.wrapVulnerabilityRequest(req, nil)
	})

	// The comma-ok assertion yields nil (instead of panicking) when the handler
	// failed before producing a resolver and resp is a nil interface.
	result, _ := resp.(*VulnerabilityRequestResolver)
	return result, err
}

// MarkVulnerabilityFalsePositive starts the workflow to mark a vulnerability as
// a false positive.
//
// The GraphQL input is converted to its v1 API form exactly once. That single
// value is both recorded by processWithAuditLog and turned into the stored
// request, so the audited request is the one that was actually processed.
// Authorization is checked by writeVulnerabilityRequests inside the audited
// handler, so denied attempts are audited too.
func (resolver *Resolver) MarkVulnerabilityFalsePositive(
	ctx context.Context,
	args struct{ Request inputtypes.FalsePositiveVulnRequest },
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "MarkVulnerabilityFalsePositive")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}

	v1Req := args.Request.AsV1FalsePositiveRequest()
	resp, err := resolver.processWithAuditLog(ctx, v1Req, "MarkVulnerabilityFalsePositive", func() (interface{}, error) {
		if err := writeVulnerabilityRequests(ctx); err != nil {
			return nil, err
		}
		req := vulnReqUtils.V1FalsePositiveRequestToVulnReq(ctx, v1Req)
		if err := resolver.vulnReqMgr.Create(ctx, req); err != nil {
			return nil, err
		}
		return resolver.wrapVulnerabilityRequest(req, nil)
	})

	// The comma-ok assertion yields nil (instead of panicking) when the handler
	// failed before producing a resolver and resp is a nil interface.
	result, _ := resp.(*VulnerabilityRequestResolver)
	return result, err
}

// ApproveVulnerabilityRequest approves the vulnerability request identified by
// args.RequestID and attaches args.Comment to the approval. The operation is
// audited through processWithAuditLog; the permission check
// (writeVulnerabilityApprovals) runs inside the audited handler.
func (resolver *Resolver) ApproveVulnerabilityRequest(
	ctx context.Context,
	args struct {
		RequestID graphql.ID
		Comment   string
	},
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "ApproveVulnerabilityRequest")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}

	approval := &v1.ApproveVulnRequest{
		Id:      string(args.RequestID),
		Comment: args.Comment,
	}

	approve := func() (interface{}, error) {
		if permErr := writeVulnerabilityApprovals(ctx); permErr != nil {
			return nil, permErr
		}
		params := &common.VulnRequestParams{Comment: approval.Comment}
		approved, approveErr := resolver.vulnReqMgr.Approve(ctx, approval.Id, params)
		if approveErr != nil {
			return nil, approveErr
		}
		return resolver.wrapVulnerabilityRequest(approved, nil)
	}

	out, err := resolver.processWithAuditLog(ctx, approval, "ApproveVulnerabilityRequest", approve)
	if out == nil {
		return nil, err
	}
	return out.(*VulnerabilityRequestResolver), err
}

// DenyVulnerabilityRequest denies the vulnerability request identified by
// args.RequestID and attaches args.Comment to the denial. The operation is
// audited through processWithAuditLog; the permission check
// (writeVulnerabilityApprovals) runs inside the audited handler.
func (resolver *Resolver) DenyVulnerabilityRequest(
	ctx context.Context,
	args struct {
		RequestID graphql.ID
		Comment   string
	},
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "DenyVulnerabilityRequest")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}

	denial := &v1.DenyVulnRequest{
		Id:      string(args.RequestID),
		Comment: args.Comment,
	}

	deny := func() (interface{}, error) {
		if permErr := writeVulnerabilityApprovals(ctx); permErr != nil {
			return nil, permErr
		}
		params := &common.VulnRequestParams{Comment: denial.Comment}
		denied, denyErr := resolver.vulnReqMgr.Deny(ctx, denial.Id, params)
		if denyErr != nil {
			return nil, denyErr
		}
		return resolver.wrapVulnerabilityRequest(denied, nil)
	}

	out, err := resolver.processWithAuditLog(ctx, denial, "DenyVulnerabilityRequest", deny)
	if out == nil {
		return nil, err
	}
	return out.(*VulnerabilityRequestResolver), err
}

// UpdateVulnerabilityRequest updates the vulnerability request identified by
// args.RequestID. Currently only the expiry of a deferral request can be
// changed; args.Comment is passed along with the new expiry. The operation is
// audited through processWithAuditLog; the permission check
// (writeVulnerabilityRequestsOrApprovals) runs inside the audited handler.
func (resolver *Resolver) UpdateVulnerabilityRequest(
	ctx context.Context,
	args struct {
		RequestID graphql.ID
		Comment   string
		Expiry    inputtypes.VulnReqExpiry
	},
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "UpdateVulnerabilityRequest")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}

	update := &v1.UpdateVulnRequest{
		Id:      string(args.RequestID),
		Comment: args.Comment,
		Expiry:  args.Expiry.AsRequestExpiry(),
	}

	apply := func() (interface{}, error) {
		if permErr := writeVulnerabilityRequestsOrApprovals(ctx); permErr != nil {
			return nil, permErr
		}
		params := &common.VulnRequestParams{Comment: update.Comment, Expiry: update.Expiry}
		// wrapVulnerabilityRequest passes a non-nil error straight through.
		return resolver.wrapVulnerabilityRequest(resolver.vulnReqMgr.UpdateExpiry(ctx, update.Id, params))
	}

	out, err := resolver.processWithAuditLog(ctx, update, "UpdateVulnerabilityRequest", apply)
	if out == nil {
		return nil, err
	}
	return out.(*VulnerabilityRequestResolver), err
}

// UndoVulnerabilityRequest undoes (retires) the vulnerability request identified
// by args.RequestID. The request itself is not deleted. The operation is audited
// through processWithAuditLog; the permission check
// (writeVulnerabilityRequestsOrApprovals) runs inside the audited handler.
func (resolver *Resolver) UndoVulnerabilityRequest(
	ctx context.Context,
	args struct{ RequestID graphql.ID },
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "UndoVulnerabilityRequest")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}

	target := &v1.ResourceByID{Id: string(args.RequestID)}
	undo := func() (interface{}, error) {
		if permErr := writeVulnerabilityRequestsOrApprovals(ctx); permErr != nil {
			return nil, permErr
		}
		undone, undoErr := resolver.vulnReqMgr.Undo(ctx, target.Id, nil)
		if undoErr != nil {
			return nil, undoErr
		}
		return resolver.wrapVulnerabilityRequest(undone, nil)
	}

	out, err := resolver.processWithAuditLog(ctx, target, "UndoVulnerabilityRequest", undo)
	if out == nil {
		return nil, err
	}
	return out.(*VulnerabilityRequestResolver), err
}

// DeleteVulnerabilityRequest deletes the vulnerability request identified by
// args.RequestID and reports true on success. The operation is audited through
// processWithAuditLog; the permission check (writeVulnerabilityRequests) runs
// inside the audited handler.
func (resolver *Resolver) DeleteVulnerabilityRequest(
	ctx context.Context,
	args struct{ RequestID graphql.ID },
) (bool, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "DeleteVulnerabilityRequest")

	if !features.VulnRiskManagement.Enabled() {
		return false, errors.New("Vulnerability Risk Management is not enabled")
	}

	target := &v1.ResourceByID{Id: string(args.RequestID)}
	outcome, err := resolver.processWithAuditLog(ctx, target, "DeleteVulnerabilityRequest", func() (interface{}, error) {
		if permErr := writeVulnerabilityRequests(ctx); permErr != nil {
			return false, permErr
		}
		if delErr := resolver.vulnReqMgr.Delete(ctx, target.Id); delErr != nil {
			return false, delErr
		}
		return true, nil
	})

	// A nil interface or anything other than a bool is reported as false.
	deleted, _ := outcome.(bool)
	return deleted, err
}

// VulnerabilityRequest looks up a single vulnerability request by ID. It
// returns (nil, nil) when no request with that ID exists in the store.
func (resolver *Resolver) VulnerabilityRequest(ctx context.Context, args struct{ graphql.ID }) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "VulnerabilityRequest")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}
	if err := readVulnerabilityRequestsOrApprovals(ctx); err != nil {
		return nil, err
	}

	stored, found, err := resolver.vulnReqStore.Get(ctx, string(args.ID))
	switch {
	case err != nil:
		return nil, err
	case !found:
		return nil, nil
	}
	return resolver.wrapVulnerabilityRequest(stored, nil)
}

// VulnerabilityRequests returns all vulnerability requests that match args.Query.
// An absent or empty query matches every request.
//
// args.RequestIDSelector, when non-empty, is a comma-separated list of request
// IDs. It is ANDed with the query to restrict the results. Whitespace around
// each ID is ignored, so "id1, id2" selects the same requests as "id1,id2".
// args.Pagination is applied to the combined query, with math.MaxInt32 passed
// to paginated.FillPagination as the limit bound.
func (resolver *Resolver) VulnerabilityRequests(ctx context.Context,
	args struct {
		Query             *string
		RequestIDSelector *string
		Pagination        *inputtypes.Pagination
	}) ([]*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "VulnerabilityRequests")

	if !features.VulnRiskManagement.Enabled() {
		return nil, errors.New("Vulnerability Risk Management is not enabled")
	}
	if err := readVulnerabilityRequestsOrApprovals(ctx); err != nil {
		return nil, err
	}

	rawQuery := ""
	if args.Query != nil {
		rawQuery = *args.Query
	}
	parsedQuery, err := search.ParseQuery(rawQuery, search.MatchAllIfEmpty())
	if err != nil {
		return nil, err
	}

	if args.RequestIDSelector != nil && *args.RequestIDSelector != "" {
		requestIDs := strings.Split(*args.RequestIDSelector, ",")
		// Without trimming, " id2" in "id1, id2" would never match a document ID.
		for i, id := range requestIDs {
			requestIDs[i] = strings.TrimSpace(id)
		}
		parsedQuery = search.ConjunctionQuery(
			search.NewQueryBuilder().AddDocIDs(requestIDs...).ProtoQuery(),
			parsedQuery,
		)
	}

	// Fill in pagination.
	paginated.FillPagination(parsedQuery, args.Pagination.AsV1Pagination(), math.MaxInt32)

	response, err := resolver.vulnReqStore.SearchRawRequests(ctx, parsedQuery)
	if err != nil {
		return nil, err
	}
	return resolver.wrapVulnerabilityRequests(response, nil)
}

// VulnerabilityRequestsCount returns how many vulnerability requests match the
// query in args. The query is built with AsV1QueryOrEmpty, so an absent query
// is accepted.
func (resolver *Resolver) VulnerabilityRequestsCount(ctx context.Context, args RawQuery) (int32, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "VulnerabilityRequestsCount")

	if !features.VulnRiskManagement.Enabled() {
		return 0, errors.New("Vulnerability Risk Management is not enabled")
	}
	if err := readVulnerabilityRequestsOrApprovals(ctx); err != nil {
		return 0, err
	}

	query, parseErr := args.AsV1QueryOrEmpty()
	if parseErr != nil {
		return 0, parseErr
	}

	total, countErr := resolver.vulnReqStore.Count(ctx, query)
	if countErr != nil {
		return 0, countErr
	}
	return int32(total), nil
}

// VulnerabilityRequestResolver resolves the GraphQL fields of a single
// vulnerability request.
type VulnerabilityRequestResolver struct {
	// root is the top-level resolver, used to reach shared dependencies.
	root *Resolver
	// data is the stored vulnerability request being exposed.
	data *storage.VulnerabilityRequest
}

// wrapVulnerabilityRequest binds a stored request to a resolver rooted at this
// Resolver. If err is non-nil or value is nil, it returns (nil, err) instead.
func (resolver *Resolver) wrapVulnerabilityRequest(value *storage.VulnerabilityRequest, err error) (*VulnerabilityRequestResolver, error) {
	if err == nil && value != nil {
		return &VulnerabilityRequestResolver{root: resolver, data: value}, nil
	}
	return nil, err
}

// wrapVulnerabilityRequests binds each stored request to a resolver rooted at
// this Resolver, preserving order. A non-nil err yields (nil, err); an empty
// input yields (nil, nil).
func (resolver *Resolver) wrapVulnerabilityRequests(values []*storage.VulnerabilityRequest, err error) ([]*VulnerabilityRequestResolver, error) {
	if err != nil {
		return nil, err
	}
	if len(values) == 0 {
		return nil, nil
	}
	wrapped := make([]*VulnerabilityRequestResolver, len(values))
	for i, v := range values {
		wrapped[i] = &VulnerabilityRequestResolver{root: resolver, data: v}
	}
	return wrapped, nil
}

// ID returns the vulnerability request's identifier as a GraphQL ID.
func (vr *VulnerabilityRequestResolver) ID(ctx context.Context) graphql.ID {
	id := vr.data.GetId()
	return graphql.ID(id)
}

// TargetState returns the string form of the state the request asks the
// vulnerability to be moved into.
func (vr *VulnerabilityRequestResolver) TargetState(ctx context.Context) string {
	target := vr.data.GetTargetState()
	return target.String()
}

// Status returns the string form of the request's current status.
func (vr *VulnerabilityRequestResolver) Status(ctx context.Context) string {
	status := vr.data.GetStatus()
	return status.String()
}

// Expired reports whether the stored vulnerability request is marked expired.
func (vr *VulnerabilityRequestResolver) Expired(ctx context.Context) bool {
	isExpired := vr.data.GetExpired()
	return isExpired
}

// Requestor returns the user who submitted the vulnerability request.
func (vr *VulnerabilityRequestResolver) Requestor(ctx context.Context) (*slimUserResolver, error) {
	requestor := vr.data.GetRequestor()
	return vr.root.wrapSlimUser(requestor, true, nil)
}

// Approvers returns the users who approved the vulnerability request, if any.
func (vr *VulnerabilityRequestResolver) Approvers(ctx context.Context) ([]*slimUserResolver, error) {
	approvers := vr.data.GetApprovers()
	return vr.root.wrapSlimUsers(approvers, nil)
}

// CreatedAt returns the time at which the request was created.
func (vr *VulnerabilityRequestResolver) CreatedAt(ctx context.Context) (*graphql.Time, error) {
	created := vr.data.GetCreatedAt()
	return timestamp(created)
}

// LastUpdated returns the time at which the request was last modified.
func (vr *VulnerabilityRequestResolver) LastUpdated(ctx context.Context) (*graphql.Time, error) {
	updated := vr.data.GetLastUpdated()
	return timestamp(updated)
}

// Comments returns the comments attached to the request.
func (vr *VulnerabilityRequestResolver) Comments(ctx context.Context) ([]*requestCommentResolver, error) {
	comments := vr.data.GetComments()
	return vr.root.wrapRequestComments(comments, nil)
}

// Scope returns the scope the request applies to.
func (vr *VulnerabilityRequestResolver) Scope(ctx context.Context) (*vulnerabilityRequest_ScopeResolver, error) {
	scope := vr.data.GetScope()
	return vr.root.wrapVulnerabilityRequest_Scope(scope, true, nil)
}

// DeferralReq returns the deferral details of the request, or nil when the
// stored request carries none.
func (vr *VulnerabilityRequestResolver) DeferralReq(ctx context.Context) (*DeferralRequestResolver, error) {
	deferral := vr.data.GetDeferralReq()
	return vr.root.wrapDeferralRequest(deferral, nil)
}

// FalsePositiveReq returns the false-positive details of the request.
func (vr *VulnerabilityRequestResolver) FalsePositiveReq(ctx context.Context) (*falsePositiveRequestResolver, error) {
	fp := vr.data.GetFpRequest()
	return vr.root.wrapFalsePositiveRequest(fp, true, nil)
}

// UpdatedDeferralReq returns the updated deferral details of the request, or
// nil when the stored request carries none.
func (vr *VulnerabilityRequestResolver) UpdatedDeferralReq(ctx context.Context) (*DeferralRequestResolver, error) {
	updated := vr.data.GetUpdatedDeferralReq()
	return vr.root.wrapDeferralRequest(updated, nil)
}

// Cves returns the CVEs that the request applies to.
func (vr *VulnerabilityRequestResolver) Cves(ctx context.Context) (*vulnerabilityRequest_CVEsResolver, error) {
	cves := vr.data.GetCves()
	return vr.root.wrapVulnerabilityRequest_CVEs(cves, true, nil)
}

// DeploymentCount returns how many deployments affected by this vulnerability
// request match the query in args. The caller must pass readDeployments.
func (vr *VulnerabilityRequestResolver) DeploymentCount(ctx context.Context, args RawQuery) (int32, error) {
	if err := readDeployments(ctx); err != nil {
		return 0, err
	}
	q, parseErr := args.AsV1QueryOrEmpty()
	if parseErr != nil {
		return 0, parseErr
	}
	requestID := vr.data.GetId()
	n, countErr := vr.root.vulnReqQueryMgr.DeploymentCount(ctx, requestID, q)
	if countErr != nil {
		return 0, countErr
	}
	return int32(n), nil
}

// ImageCount returns how many images affected by this vulnerability request
// match the query in args. The caller must pass readImages.
func (vr *VulnerabilityRequestResolver) ImageCount(ctx context.Context, args RawQuery) (int32, error) {
	if err := readImages(ctx); err != nil {
		return 0, err
	}
	q, parseErr := args.AsV1QueryOrEmpty()
	if parseErr != nil {
		return 0, parseErr
	}
	requestID := vr.data.GetId()
	n, countErr := vr.root.vulnReqQueryMgr.ImageCount(ctx, requestID, q)
	if countErr != nil {
		return 0, countErr
	}
	return int32(n), nil
}

// Deployments returns the deployments affected by this vulnerability request
// that match the (optionally paginated) query in args. The caller must pass
// readDeployments.
func (vr *VulnerabilityRequestResolver) Deployments(ctx context.Context, args PaginatedQuery) ([]*deploymentResolver, error) {
	if err := readDeployments(ctx); err != nil {
		return nil, err
	}
	q, parseErr := args.AsV1QueryOrEmpty()
	if parseErr != nil {
		return nil, parseErr
	}
	deployments, fetchErr := vr.root.vulnReqQueryMgr.Deployments(ctx, vr.data.GetId(), q)
	return vr.root.wrapDeployments(deployments, fetchErr)
}

// Images returns the images affected by this vulnerability request that match
// the (optionally paginated) query in args. The caller must pass readImages.
func (vr *VulnerabilityRequestResolver) Images(ctx context.Context, args PaginatedQuery) ([]*imageResolver, error) {
	if err := readImages(ctx); err != nil {
		return nil, err
	}
	q, parseErr := args.AsV1QueryOrEmpty()
	if parseErr != nil {
		return nil, parseErr
	}
	images, fetchErr := vr.root.vulnReqQueryMgr.Images(ctx, vr.data.GetId(), q)
	return vr.root.wrapImages(images, fetchErr)
}

// DeferralRequestResolver resolves the GraphQL fields (expiry settings) of a
// vulnerability deferral request.
type DeferralRequestResolver struct {
	// root is the top-level resolver, used to reach shared dependencies.
	root *Resolver
	// data is the stored deferral request being exposed.
	data *storage.DeferralRequest
}

// wrapDeferralRequest binds a stored deferral request to a resolver rooted at
// this Resolver. If err is non-nil or value is nil, it returns (nil, err).
func (resolver *Resolver) wrapDeferralRequest(value *storage.DeferralRequest, err error) (*DeferralRequestResolver, error) {
	if err == nil && value != nil {
		return &DeferralRequestResolver{root: resolver, data: value}, nil
	}
	return nil, err
}

// ExpiresOn returns the deferral's expiry timestamp, taken from the ExpiresOn
// field of the stored expiry. It is presumably unset for deferrals without a
// time-based expiry.
func (dr *DeferralRequestResolver) ExpiresOn(ctx context.Context) (*graphql.Time, error) {
	expiry := dr.data.GetExpiry()
	return timestamp(expiry.GetExpiresOn())
}

// ExpiresWhenFixed reports whether the deferral is configured to expire once the
// vulnerability becomes fixable.
func (dr *DeferralRequestResolver) ExpiresWhenFixed(ctx context.Context) bool {
	expiry := dr.data.GetExpiry()
	return expiry.GetExpiresWhenFixed()
}
