package db

import (
	"fmt"
	"sort"

	"github.com/facebookincubator/nvdtools/wfn"

	"github.com/anchore/grype/grype/cpe"
	grypeDB "github.com/anchore/grype/grype/db/v3"
	"github.com/anchore/grype/grype/distro"
	"github.com/anchore/grype/grype/pkg"
	"github.com/anchore/grype/grype/vulnerability"
	syftPkg "github.com/anchore/syft/syft/pkg"
)

var (
	// Compile-time assertion that *VulnerabilityProvider implements
	// vulnerability.Provider.
	_ vulnerability.Provider = (*VulnerabilityProvider)(nil)
)

// VulnerabilityProvider serves vulnerability lookups (by distro, by language,
// and by CPE) from a grype v3 vulnerability store reader.
type VulnerabilityProvider struct {
	// reader is the store queried by every lookup method.
	reader grypeDB.VulnerabilityStoreReader
}

// NewVulnerabilityProvider returns a VulnerabilityProvider that answers all of
// its lookups from the given store reader.
func NewVulnerabilityProvider(reader grypeDB.VulnerabilityStoreReader) *VulnerabilityProvider {
	provider := VulnerabilityProvider{reader: reader}
	return &provider
}

// GetByDistro returns every vulnerability the store records for package p in
// the namespace derived from distro d (via grypeDB.NamespaceForDistro).
//
// A nil distro yields (nil, nil). When the store has no records, the returned
// slice is nil. An error is returned if the store lookup fails or if any stored
// record cannot be converted into a vulnerability.Vulnerability.
func (pr *VulnerabilityProvider) GetByDistro(d *distro.Distro, p pkg.Package) ([]vulnerability.Vulnerability, error) {
	if d == nil {
		return nil, nil
	}

	ns := grypeDB.NamespaceForDistro(d)
	records, fetchErr := pr.reader.GetVulnerability(ns, p.Name)
	if fetchErr != nil {
		return nil, fmt.Errorf("provider failed to fetch namespace='%s' pkg='%s': %w", ns, p.Name, fetchErr)
	}

	var results []vulnerability.Vulnerability
	for i := range records {
		converted, convErr := vulnerability.NewVulnerability(records[i])
		if convErr != nil {
			return nil, fmt.Errorf("provider failed to parse distro='%s': %w", d, convErr)
		}
		results = append(results, *converted)
	}
	return results, nil
}

// GetByLanguage returns the vulnerabilities the store records for package p
// across every namespace associated with language l.
//
// Each namespace supplies its own namer. The namer may map p to several
// candidate package names; each name is looked up independently and all results
// are concatenated. Namespaces are visited in sorted order so the result order
// is reproducible across runs; Go map iteration order is randomized.
//
// An error is returned if no namespaces are registered for l, if a store lookup
// fails, or if a stored record cannot be converted. On success the returned
// slice is non-nil, even when empty.
func (pr *VulnerabilityProvider) GetByLanguage(l syftPkg.Language, p pkg.Package) ([]vulnerability.Vulnerability, error) {
	vulns := make([]vulnerability.Vulnerability, 0)

	namersByNamespace := grypeDB.NamespacePackageNamersForLanguage(l)
	// len covers both a nil and an empty map; either means there is nothing to query.
	if len(namersByNamespace) == 0 {
		return nil, fmt.Errorf("no store namespaces found for language '%s'", l)
	}

	// Ranging over the map directly would make the output order non-deterministic.
	namespaces := make([]string, 0, len(namersByNamespace))
	for namespace := range namersByNamespace {
		namespaces = append(namespaces, namespace)
	}
	sort.Strings(namespaces)

	for _, namespace := range namespaces {
		namer := namersByNamespace[namespace]
		for _, name := range namer(p) {
			allPkgVulns, err := pr.reader.GetVulnerability(namespace, name)
			if err != nil {
				return nil, fmt.Errorf("provider failed to fetch namespace='%s' pkg='%s': %w", namespace, name, err)
			}

			for _, vuln := range allPkgVulns {
				vulnObj, err := vulnerability.NewVulnerability(vuln)
				if err != nil {
					return nil, fmt.Errorf("provider failed to parse language='%s': %w", l, err)
				}

				vulns = append(vulns, *vulnObj)
			}
		}
	}

	return vulns, nil
}

// GetByCPE returns vulnerabilities from the CPE-indexed store namespaces whose
// recorded CPEs match requestCPE, as decided by cpe.MatchWithoutVersion.
//
// Version is excluded from the comparison; it is handled downstream. On each
// returned vulnerability, the CPEs field is replaced with only the recorded CPEs
// that matched.
//
// The store is queried by product, so requestCPE.Product must not be wfn.Any or
// wfn.NA. An error is returned in these cases:
//   - no CPE-indexed namespaces exist
//   - the product is missing
//   - a store lookup fails
//   - a stored record's CPEs or contents cannot be parsed
//
// On success the returned slice is non-nil, even when empty.
func (pr *VulnerabilityProvider) GetByCPE(requestCPE syftPkg.CPE) ([]vulnerability.Vulnerability, error) {
	vulns := make([]vulnerability.Vulnerability, 0)

	namespaces := grypeDB.NamespacesIndexedByCPE()
	// len covers both a nil and an empty slice; either means there is nothing to query.
	if len(namespaces) == 0 {
		return nil, fmt.Errorf("no store namespaces found for arbitrary CPEs")
	}

	if requestCPE.Product == wfn.Any || requestCPE.Product == wfn.NA {
		return nil, fmt.Errorf("product name is required")
	}

	for _, namespace := range namespaces {
		allPkgVulns, err := pr.reader.GetVulnerability(namespace, requestCPE.Product)
		if err != nil {
			return nil, fmt.Errorf("provider failed to fetch namespace='%s' product='%s': %w", namespace, requestCPE.Product, err)
		}

		for _, vuln := range allPkgVulns {
			vulnCPEs, err := cpe.NewSlice(vuln.CPEs...)
			if err != nil {
				// Include which stored record is malformed, consistent with the other errors here.
				return nil, fmt.Errorf("provider failed to parse CPEs for vulnerability='%s' namespace='%s': %w", vuln.ID, namespace, err)
			}

			// compare the request CPE to the potential matches (excluding version, which is handled downstream)
			candidateMatchCPEs := cpe.MatchWithoutVersion(requestCPE, vulnCPEs)
			if len(candidateMatchCPEs) == 0 {
				continue
			}

			vulnObj, err := vulnerability.NewVulnerability(vuln)
			if err != nil {
				return nil, fmt.Errorf("provider failed to parse cpe='%s': %w", requestCPE.BindToFmtString(), err)
			}

			vulnObj.CPEs = candidateMatchCPEs

			vulns = append(vulns, *vulnObj)
		}
	}

	return vulns, nil
}
