package db

import (
	"fmt"

	"github.com/facebookincubator/nvdtools/wfn"

	"github.com/anchore/grype/grype/cpe"
	grypeDB "github.com/anchore/grype/grype/db/v5"
	"github.com/anchore/grype/grype/db/v5/namespace"
	"github.com/anchore/grype/grype/distro"
	"github.com/anchore/grype/grype/pkg"
	"github.com/anchore/grype/grype/vulnerability"
	"github.com/anchore/grype/internal/log"
	syftPkg "github.com/anchore/syft/syft/pkg"
)

// Compile-time check that *VulnerabilityProvider implements vulnerability.Provider.
var _ vulnerability.Provider = (*VulnerabilityProvider)(nil)

// VulnerabilityProvider answers vulnerability lookups (by ID, distro, language, or CPE)
// against a grype v5 vulnerability store reader.
//
// namespaceIndex is built once from the store's namespace list (see NewVulnerabilityProvider)
// and decides which namespaces are searched for a given distro, language, or CPE; reader is
// the store that the vulnerability records are fetched from.
type VulnerabilityProvider struct {
	namespaceIndex *namespace.Index
	reader         grypeDB.VulnerabilityStoreReader
}

// NewVulnerabilityProvider builds a VulnerabilityProvider on top of reader.
//
// The store's vulnerability namespaces are read once, here, and parsed into the index the
// lookup methods consult. An error is returned when the namespaces cannot be read from the
// store or cannot be parsed.
func NewVulnerabilityProvider(reader grypeDB.VulnerabilityStoreReader) (*VulnerabilityProvider, error) {
	rawNamespaces, readErr := reader.GetVulnerabilityNamespaces()
	if readErr != nil {
		return nil, fmt.Errorf("unable to get namespaces from store: %w", readErr)
	}

	idx, parseErr := namespace.FromStrings(rawNamespaces)
	if parseErr != nil {
		return nil, fmt.Errorf("unable to parse namespaces from store: %w", parseErr)
	}

	provider := &VulnerabilityProvider{
		reader:         reader,
		namespaceIndex: idx,
	}
	return provider, nil
}

// Get returns every vulnerability record stored under the given namespace with the given
// vulnerability ID.
//
// Getting a vulnerability record by ID doesn't necessarily return a single record, since
// records are duplicated by the set of fixes they have, so callers should expect zero or
// more results. The returned slice is nil when the store yields no records. The first
// record that fails to inflate aborts the call with an error.
//
// Note: inside this method the `namespace` parameter shadows the imported namespace package.
func (pr *VulnerabilityProvider) Get(id, namespace string) ([]vulnerability.Vulnerability, error) {
	vulns, err := pr.reader.GetVulnerability(namespace, id)
	if err != nil {
		// the looked-up key is a vulnerability ID, not a package name; label it as such
		return nil, fmt.Errorf("provider failed to fetch namespace=%q id=%q: %w", namespace, id, err)
	}

	var results []vulnerability.Vulnerability
	for _, vuln := range vulns {
		vulnObj, err := vulnerability.NewVulnerability(vuln)
		if err != nil {
			return nil, fmt.Errorf("provider failed to inflate vulnerability record (namespace=%q id=%q): %w", vuln.Namespace, vuln.ID, err)
		}

		results = append(results, *vulnObj)
	}
	return results, nil
}

// GetByDistro returns the vulnerabilities recorded for package p under every namespace the
// index associates with distro d. Each namespace's resolver decides which package names
// derived from p are searched.
//
// A nil distro yields (nil, nil). When no namespace covers d, a debug message is logged and
// a nil slice is returned with a nil error. Otherwise the result is non-nil, though it may be
// empty. The first search or inflation failure aborts the whole lookup.
func (pr *VulnerabilityProvider) GetByDistro(d *distro.Distro, p pkg.Package) ([]vulnerability.Vulnerability, error) {
	if d == nil {
		return nil, nil
	}

	distroNamespaces := pr.namespaceIndex.NamespacesForDistro(d)
	if len(distroNamespaces) == 0 {
		log.Debugf("no vulnerability namespaces found in grype database for distro=%s package=%s", d.String(), p.Name)
		return nil, nil
	}

	matches := []vulnerability.Vulnerability{}
	for _, ns := range distroNamespaces {
		candidateNames := ns.Resolver().Resolve(p)
		for _, name := range candidateNames {
			nsName := ns.String()
			records, err := pr.reader.SearchForVulnerabilities(nsName, name)
			if err != nil {
				return nil, fmt.Errorf("provider failed to search for vulnerabilities (namespace=%q pkg=%q): %w", nsName, name, err)
			}

			for _, record := range records {
				inflated, err := vulnerability.NewVulnerability(record)
				if err != nil {
					return nil, fmt.Errorf("provider failed to inflate vulnerability record (namespace=%q id=%q distro=%q): %w", record.Namespace, record.ID, d, err)
				}
				matches = append(matches, *inflated)
			}
		}
	}

	return matches, nil
}

// GetByLanguage returns the vulnerabilities recorded for package p under every namespace the
// index associates with language l. Each namespace's resolver decides which package names
// derived from p are searched.
//
// When no namespace covers l, a debug message is logged and a nil slice is returned with a
// nil error. Otherwise the result is non-nil, though it may be empty. The first search or
// inflation failure aborts the whole lookup.
func (pr *VulnerabilityProvider) GetByLanguage(l syftPkg.Language, p pkg.Package) ([]vulnerability.Vulnerability, error) {
	langNamespaces := pr.namespaceIndex.NamespacesForLanguage(l)
	if len(langNamespaces) == 0 {
		log.Debugf("no vulnerability namespaces found in grype database for language=%s package=%s", l, p.Name)
		return nil, nil
	}

	matches := []vulnerability.Vulnerability{}
	for _, ns := range langNamespaces {
		candidateNames := ns.Resolver().Resolve(p)
		for _, name := range candidateNames {
			nsName := ns.String()
			records, err := pr.reader.SearchForVulnerabilities(nsName, name)
			if err != nil {
				return nil, fmt.Errorf("provider failed to fetch namespace=%q pkg=%q: %w", nsName, name, err)
			}

			for _, record := range records {
				inflated, err := vulnerability.NewVulnerability(record)
				if err != nil {
					return nil, fmt.Errorf("provider failed to inflate vulnerability record (namespace=%q id=%q language=%q): %w", record.Namespace, record.ID, l, err)
				}
				matches = append(matches, *inflated)
			}
		}
	}

	return matches, nil
}

// GetByCPE returns the vulnerabilities, across every CPE namespace in the store, whose CPEs
// match requestCPE when version is ignored. Version comparison is handled downstream.
//
// It returns (nil, nil) when the store has no CPE namespaces. It returns an error when
// requestCPE has no concrete product (wfn.Any or wfn.NA), because the product is the search
// key. Each returned vulnerability's CPEs field is narrowed to only the record CPEs that
// matched the request. The first store, CPE-parsing or inflation failure aborts the lookup.
func (pr *VulnerabilityProvider) GetByCPE(requestCPE syftPkg.CPE) ([]vulnerability.Vulnerability, error) {
	namespaces := pr.namespaceIndex.CPENamespaces()

	if len(namespaces) == 0 {
		log.Debugf("no vulnerability namespaces found for arbitrary CPEs in grype database")
		return nil, nil
	}

	if requestCPE.Product == wfn.Any || requestCPE.Product == wfn.NA {
		return nil, fmt.Errorf("product name is required")
	}

	// once a search actually runs, the result is non-nil (possibly empty), unlike the
	// (nil, nil) early return above
	vulns := make([]vulnerability.Vulnerability, 0)

	for _, ns := range namespaces {
		allPkgVulns, err := pr.reader.SearchForVulnerabilities(ns.String(), ns.Resolver().Normalize(requestCPE.Product))
		if err != nil {
			return nil, fmt.Errorf("provider failed to fetch namespace=%q product=%q: %w", ns, requestCPE.Product, err)
		}

		// normalize the whole request CPE with this namespace's resolver (as done for the
		// product above); if the normalized string no longer parses as a CPE, fall back to
		// matching with the original request CPE rather than failing the lookup
		normalizedRequestCPE, err := syftPkg.NewCPE(ns.Resolver().Normalize(requestCPE.BindToFmtString()))
		if err != nil {
			normalizedRequestCPE = requestCPE
		}

		for _, vuln := range allPkgVulns {
			vulnCPEs, err := cpe.NewSlice(vuln.CPEs...)
			if err != nil {
				// wrap like every other failure path so the offending record is identifiable
				return nil, fmt.Errorf("provider failed to parse CPEs of vulnerability record (namespace=%q id=%q): %w", vuln.Namespace, vuln.ID, err)
			}

			// compare the request CPE to the potential matches (excluding version, which is handled downstream)
			candidateMatchCpes := cpe.MatchWithoutVersion(normalizedRequestCPE, vulnCPEs)
			if len(candidateMatchCpes) == 0 {
				continue
			}

			vulnObj, err := vulnerability.NewVulnerability(vuln)
			if err != nil {
				return nil, fmt.Errorf("provider failed to inflate vulnerability record (namespace=%q id=%q cpe=%q): %w", vuln.Namespace, vuln.ID, requestCPE.BindToFmtString(), err)
			}

			// report only the record CPEs that actually matched the request
			vulnObj.CPEs = candidateMatchCpes

			vulns = append(vulns, *vulnObj)
		}
	}

	return vulns, nil
}
