package output

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"time"

	"github.com/Rhymond/go-money"
	"github.com/mitchellh/go-homedir"
	"github.com/pkg/errors"
	"github.com/shopspring/decimal"
	"golang.org/x/mod/semver"

	log "github.com/sirupsen/logrus"

	"github.com/infracost/infracost/internal/clierror"
	"github.com/infracost/infracost/internal/config"
	"github.com/infracost/infracost/internal/schema"
)

var (
	// GitHubMaxMessageSize is the MaxMessageSize, in bytes, passed to
	// ToMarkdown when rendering the "github-comment" format.
	GitHubMaxMessageSize = 256 * 1024

	// minOutputVersion and maxOutputVersion are the inclusive bounds of the
	// Infracost JSON "version" values that Load accepts.
	minOutputVersion = "0.2"
	maxOutputVersion = "0.2"
)

// ReportInput pairs a decoded Infracost JSON report with descriptive metadata
// about where it came from.
type ReportInput struct {
	// Metadata holds free-form key/value pairs describing the input.
	// LoadPaths sets the "filename" key to the path the report was read from.
	Metadata map[string]string
	// Root is the decoded Infracost JSON report.
	Root     Root
}

// Load reads the Infracost JSON file at path p and decodes it into a Root.
// Load naively validates the body by checking that its version attribute lies
// within the supported range minOutputVersion ≤ x ≤ maxOutputVersion.
//
// Errors are returned when the file does not exist (with a hint on how to
// generate it), cannot be read, is not valid JSON, or has an unsupported
// version. On error the returned Root may be partially populated and should
// not be used.
func Load(p string) (Root, error) {
	var out Root

	// Read first and classify the error, instead of calling os.Stat beforehand:
	// a separate Stat is racy (the file can change between the two calls) and
	// ignored every Stat failure other than "not exist".
	data, err := os.ReadFile(p)
	if errors.Is(err, os.ErrNotExist) {
		return out, errors.New("Infracost JSON file does not exist, generate it by running the following command then try again:\ninfracost breakdown --path /code --format json --out-file infracost-base.json")
	}
	if err != nil {
		return out, fmt.Errorf("error reading Infracost JSON file %w", err)
	}

	if err := json.Unmarshal(data, &out); err != nil {
		return out, fmt.Errorf("invalid Infracost JSON file %w, generate it by running the following command then try again:\ninfracost breakdown --path /code --format json --out-file infracost-base.json", err)
	}

	if !checkOutputVersion(out.Version) {
		return out, fmt.Errorf("invalid Infracost JSON file version. Supported versions are %s ≤ x ≤ %s", minOutputVersion, maxOutputVersion)
	}

	return out, nil
}

// LoadPaths resolves every entry in paths to one or more Infracost JSON files
// and loads each with Load. It returns one ReportInput per file, in resolution
// order, with Metadata["filename"] set to the resolved path.
//
// Each entry may be either:
//   - a JSON array of paths, e.g. --path='["/path/one", "/path/two"]', which is
//     convenient in GitHub Actions and other CI environments; or
//   - a single path.
//
// Every individual path has a leading "~" expanded to the home directory and
// is then treated as a glob pattern. If the pattern matches nothing (or is
// malformed), the expanded path is passed to Load literally, so a missing file
// is reported by Load. The first failure aborts the call.
func LoadPaths(paths []string) ([]ReportInput, error) {
	var inputFiles []string

	for _, path := range paths {
		var nestedPaths []string
		if err := json.Unmarshal([]byte(path), &nestedPaths); err != nil {
			// Not a JSON array, so the entry itself is the only path.
			nestedPaths = []string{path}
		}

		for _, p := range nestedPaths {
			expanded, err := homedir.Expand(p)
			if err != nil {
				return nil, errors.Wrap(err, "Failed to expand path")
			}

			// filepath.Glob only fails for a malformed pattern; in that case, as
			// when nothing matches, fall back to using the path literally.
			matches, _ := filepath.Glob(expanded)
			if len(matches) > 0 {
				inputFiles = append(inputFiles, matches...)
				continue
			}

			// Use the expanded path: the OS does not understand "~", so passing
			// the raw path would make Load fail for "~/..." inputs whose names
			// happen to contain glob metacharacters.
			inputFiles = append(inputFiles, expanded)
		}
	}

	inputs := make([]ReportInput, 0, len(inputFiles))

	for _, f := range inputFiles {
		r, err := Load(f)
		if err != nil {
			return nil, fmt.Errorf("could not load input file %s err: %w", f, err)
		}

		inputs = append(inputs, ReportInput{
			Metadata: map[string]string{
				"filename": f,
			},
			Root: r,
		})
	}

	return inputs, nil
}

// CompareTo builds an output Root that diffs current against prior, where
// prior is the baseline snapshot. Projects are matched by LabelWithMetadata().
//
// For each project in current:
//   - With no matching prior project, it is treated as newly created: its whole
//     resource list is the diff (a 100% increase) and it has no past resources.
//   - With a match and no errors on either side, the prior project's resources
//     become its past resources, its PastPolicySha is the prior PolicySha, and
//     the diff is computed with schema.CalculateDiff.
//   - With a match where only the prior project has errors, its Resources and
//     Diff are cleared and HasDiff is false.
//   - With a match where the current project has errors, it is handled like a
//     new project.
//
// In every matched case the prior project's errors are appended to the current
// project's errors with a "Diff baseline error: " prefix.
//
// Prior projects with no match in current are treated as removed: their
// resources become past resources and they have no current resources.
//
// Per-project summaries, and the top-level Summary, FullSummary and Currency,
// are taken from current. An error is returned if prior contains two projects
// with the same label, or if ToOutputFormat fails.
func CompareTo(c *config.Config, current, prior Root) (Root, error) {
	// Index prior projects by label. A duplicate label would make the match
	// ambiguous, so it is rejected.
	priorProjects := make(map[string]*schema.Project)
	for _, p := range prior.Projects {
		if _, ok := priorProjects[p.LabelWithMetadata()]; ok {
			return Root{}, fmt.Errorf("Invalid --compare-to Infracost JSON, found duplicate project name %s", p.LabelWithMetadata())
		}

		priorProjects[p.LabelWithMetadata()] = p.ToSchemaProject()
	}

	var schemaProjects schema.Projects
	for _, p := range current.Projects {
		// Start by treating the project as new: everything is in the diff and
		// there is no past state. The branches below refine this when a prior
		// project matches.
		scp := p.ToSchemaProject()
		scp.Diff = scp.Resources
		scp.PastResources = nil
		scp.Metadata.PastPolicySha = ""
		scp.HasDiff = true

		if v, ok := priorProjects[p.LabelWithMetadata()]; ok {
			// Both runs succeeded: diff against the prior run's resources.
			if !p.Metadata.HasErrors() && !v.Metadata.HasErrors() {
				scp.PastResources = v.Resources
				scp.Metadata.PastPolicySha = v.Metadata.PolicySha
				scp.Diff = schema.CalculateDiff(scp.PastResources, scp.Resources)
			}

			if !p.Metadata.HasErrors() && v.Metadata.HasErrors() {
				// The prior project has errors but the current one does not.
				// The prior errors are copied over to the current project below,
				// and the current project's costs are removed as well.
				scp.Resources = nil
				scp.Diff = nil
				scp.HasDiff = false
			}

			// Copy the prior run's errors onto the current project, prefixed
			// so they are distinguishable from the current run's own errors.
			for _, pastE := range v.Metadata.Errors {
				pastE.Message = "Diff baseline error: " + pastE.Message
				scp.Metadata.Errors = append(scp.Metadata.Errors, pastE)
			}

			// Mark the prior project as matched so the loop below does not
			// treat it as removed.
			delete(priorProjects, p.LabelWithMetadata())
		}

		schemaProjects = append(schemaProjects, scp)
	}

	// Whatever is left in priorProjects has no counterpart in current, so it
	// was removed: the prior resources become past resources and the current
	// side is empty.
	// NOTE(review): unlike the matched case above, PastPolicySha is not set from
	// the prior PolicySha here before PolicySha is cleared — confirm intended.
	for _, scp := range priorProjects {
		scp.PastResources = scp.Resources
		scp.Resources = nil
		scp.Metadata.PolicySha = ""
		scp.HasDiff = true
		scp.Diff = schema.CalculateDiff(scp.PastResources, scp.Resources)

		schemaProjects = append(schemaProjects, scp)
	}

	// Map iteration order above is random; sorting presumably gives a stable
	// project order (depends on schema.Projects' Less implementation).
	sort.Sort(schemaProjects)

	out, err := ToOutputFormat(c, schemaProjects)
	if err != nil {
		return out, err
	}

	// Preserve the summaries from the original (current) run: overwrite each
	// output project's summaries with those of the current project with the
	// same label, then copy the top-level summary and currency.
	currentProjects := make(map[string]Project)
	for _, p := range current.Projects {
		currentProjects[p.LabelWithMetadata()] = p
	}
	for i := range out.Projects {
		if v, ok := currentProjects[out.Projects[i].LabelWithMetadata()]; ok {
			out.Projects[i].Summary = v.Summary
			out.Projects[i].fullSummary = v.fullSummary
		}
	}

	out.Summary = current.Summary
	out.FullSummary = current.FullSummary
	out.Currency = current.Currency
	return out, nil
}

// Combine merges several Infracost JSON reports into a single Root.
//
// Fields are combined as follows:
//   - Projects and summaries are concatenated in input order, and summaries
//     are merged with MergeSummaries.
//   - Each optional total cost (hourly/monthly, current/past/diff) is summed
//     over the inputs that set it, and stays nil if none did.
//   - Tag and FinOps policies are merged with mergeTagPolicies and
//     mergeFinOpsPolicies.
//   - TimeGenerated is the latest of the inputs.
//   - Metadata and CloudURL are taken from the last input.
//
// All inputs must share one currency (an empty currency counts as USD);
// otherwise an error is returned. If the inputs come from different VCS
// repositories, the fully combined Root is still returned, together with a
// clierror warning that lists the repositories.
func Combine(inputs []ReportInput) (Root, error) {
	var combined Root

	// addOptional adds v to the running total acc. A nil v leaves acc as is, so
	// a total stays nil until at least one input provides a value.
	addOptional := func(acc, v *decimal.Decimal) *decimal.Decimal {
		if v == nil {
			return acc
		}
		sum := decimal.Zero
		if acc != nil {
			sum = *acc
		}
		return decimalPtr(sum.Add(*v))
	}

	var latestGeneratedAt time.Time
	var totalHourlyCost, totalMonthlyCost *decimal.Decimal
	var pastTotalHourlyCost, pastTotalMonthlyCost *decimal.Decimal
	var diffTotalHourlyCost, diffTotalMonthlyCost *decimal.Decimal

	projects := make([]Project, 0)
	summaries := make([]*Summary, 0, len(inputs))
	var tagPolicies []TagPolicy
	var finOpsPolicies []FinOpsPolicy
	currency := ""

	var metadata Metadata
	var invalidMetadata bool
	var repoURLs strings.Builder
	for i, input := range inputs {
		var err error
		currency, err = checkCurrency(currency, input.Root.Currency)
		if err != nil {
			return combined, err
		}

		projects = append(projects, input.Root.Projects...)
		summaries = append(summaries, input.Root.Summary)

		if input.Root.TimeGenerated.After(latestGeneratedAt) {
			latestGeneratedAt = input.Root.TimeGenerated
		}

		tagPolicies = append(tagPolicies, input.Root.TagPolicies...)
		finOpsPolicies = append(finOpsPolicies, input.Root.FinOpsPolicies...)

		totalHourlyCost = addOptional(totalHourlyCost, input.Root.TotalHourlyCost)
		totalMonthlyCost = addOptional(totalMonthlyCost, input.Root.TotalMonthlyCost)
		pastTotalHourlyCost = addOptional(pastTotalHourlyCost, input.Root.PastTotalHourlyCost)
		pastTotalMonthlyCost = addOptional(pastTotalMonthlyCost, input.Root.PastTotalMonthlyCost)
		diffTotalHourlyCost = addOptional(diffTotalHourlyCost, input.Root.DiffTotalHourlyCost)
		diffTotalMonthlyCost = addOptional(diffTotalMonthlyCost, input.Root.DiffTotalMonthlyCost)

		// Comparing each input with the previous one is enough to detect any
		// difference across the whole list.
		if i != 0 && metadata.VCSRepositoryURL != input.Root.Metadata.VCSRepositoryURL {
			invalidMetadata = true
		}

		metadata = input.Root.Metadata
		fmt.Fprintf(&repoURLs, "%q, ", input.Root.Metadata.VCSRepositoryURL)
	}

	combined.Version = outputVersion
	combined.Currency = currency
	combined.Projects = projects
	combined.TotalHourlyCost = totalHourlyCost
	combined.TotalMonthlyCost = totalMonthlyCost
	combined.PastTotalHourlyCost = pastTotalHourlyCost
	combined.PastTotalMonthlyCost = pastTotalMonthlyCost
	combined.DiffTotalHourlyCost = diffTotalHourlyCost
	combined.DiffTotalMonthlyCost = diffTotalMonthlyCost
	combined.TimeGenerated = latestGeneratedAt
	combined.Summary = MergeSummaries(summaries)
	combined.Metadata = metadata
	combined.TagPolicies = mergeTagPolicies(tagPolicies)
	combined.FinOpsPolicies = mergeFinOpsPolicies(finOpsPolicies)
	if len(inputs) > 0 {
		combined.CloudURL = inputs[len(inputs)-1].Root.CloudURL
	}

	if invalidMetadata {
		return combined, clierror.NewWarningF(
			"combining Infracost JSON for different VCS repositories %s. Using %s as the top-level repository in the outputted JSON",
			strings.TrimRight(repoURLs.String(), ", "),
			metadata.VCSRepositoryURL,
		)
	}

	return combined, nil
}

// mergeTagPolicies merges tag policies that share a Name into one entry per
// name, in order of each name's first appearance. For a merged policy,
// PrComment and BlockPr are true if any occurrence set them, Resources are
// concatenated in input order, and every other field comes from the last
// occurrence. The input slice and its policies' Resources are not modified.
func mergeTagPolicies(tagPolicies []TagPolicy) []TagPolicy {
	// Gather and merge tag policies by name.
	tpMap := make(map[string]TagPolicy, len(tagPolicies))
	for _, tp := range tagPolicies {
		if existingTp, ok := tpMap[tp.Name]; ok {
			tp.PrComment = existingTp.PrComment || tp.PrComment
			tp.BlockPr = existingTp.BlockPr || tp.BlockPr
			// Cap the existing slice at its length so append always allocates a
			// new backing array. Otherwise it could write into spare capacity
			// still shared with the caller's input policies.
			existing := existingTp.Resources
			tp.Resources = append(existing[:len(existing):len(existing)], tp.Resources...)
		}
		tpMap[tp.Name] = tp
	}

	tpMerged := make([]TagPolicy, 0, len(tpMap))
	// Walk the original slice rather than the map so that output order follows
	// first appearance. Deleting each name ensures it is emitted only once.
	for _, tp := range tagPolicies {
		if mergedTp, ok := tpMap[tp.Name]; ok {
			tpMerged = append(tpMerged, mergedTp)
			delete(tpMap, tp.Name)
		}
	}

	return tpMerged
}

// mergeFinOpsPolicies merges FinOps policies that share a PolicyID into one
// entry per ID, in order of each ID's first appearance. For a merged policy,
// PrComment and BlockPr are true if any occurrence set them, Resources are
// concatenated in input order, and every other field comes from the last
// occurrence. The input slice and its policies' Resources are not modified.
func mergeFinOpsPolicies(finOpsPolicies []FinOpsPolicy) []FinOpsPolicy {
	// Gather and merge FinOps policies by policy ID.
	fpMap := make(map[string]FinOpsPolicy, len(finOpsPolicies))
	for _, fp := range finOpsPolicies {
		if existingFp, ok := fpMap[fp.PolicyID]; ok {
			fp.PrComment = existingFp.PrComment || fp.PrComment
			fp.BlockPr = existingFp.BlockPr || fp.BlockPr
			// Cap the existing slice at its length so append always allocates a
			// new backing array. Otherwise it could write into spare capacity
			// still shared with the caller's input policies.
			existing := existingFp.Resources
			fp.Resources = append(existing[:len(existing):len(existing)], fp.Resources...)
		}
		fpMap[fp.PolicyID] = fp
	}

	fpMerged := make([]FinOpsPolicy, 0, len(fpMap))
	// Walk the original finOpsPolicies slice rather than the map so that output
	// order follows first appearance. Deleting each ID ensures it is emitted
	// only once.
	for _, fp := range finOpsPolicies {
		if mergedFp, ok := fpMap[fp.PolicyID]; ok {
			fpMerged = append(fpMerged, mergedFp)
			delete(fpMap, fp.PolicyID)
		}
	}

	return fpMerged
}

// checkCurrency verifies that a report's currency agrees with the currency
// accumulated so far while combining reports.
//
// inputCurrency is the currency seen so far ("" before the first report), and
// fileCurrency is the current report's currency, where "" is treated as USD.
// It returns the currency to carry forward, or an error if the two differ.
func checkCurrency(inputCurrency, fileCurrency string) (string, error) {
	effective := fileCurrency
	if effective == "" {
		effective = "USD" // an empty currency defaults to USD
	}

	switch inputCurrency {
	case "", effective:
		// Either this is the first report, or it agrees with the earlier ones.
		return effective, nil
	default:
		return "", fmt.Errorf("Invalid Infracost JSON file currency mismatch.  Can't combine %s and %s", inputCurrency, effective)
	}
}

func checkOutputVersion(v string) bool {
	if !strings.HasPrefix(v, "v") {
		v = "v" + v
	}
	return semver.Compare(v, "v"+minOutputVersion) >= 0 && semver.Compare(v, "v"+maxOutputVersion) <= 0
}

// FormatOutput renders r in the requested format and returns the bytes.
//
// Recognised formats are "json", "html", "diff", "github-comment",
// "gitlab-comment", "azure-repos-comment", "bitbucket-comment",
// "bitbucket-comment-summary" and "slack-message". Any other value (including
// "") produces the default table output.
//
// If opts.CurrencyFormat is non-empty, it is registered via addCurrencyFormat
// before rendering. NOTE(review): that registers a currency with the go-money
// package, which appears to be process-wide state — confirm this is intended.
func FormatOutput(format string, r Root, opts Options) ([]byte, error) {
	if opts.CurrencyFormat != "" {
		addCurrencyFormat(opts.CurrencyFormat)
	}

	// markdown renders r as Markdown with the given options and returns just
	// the message body, for the comment formats below.
	markdown := func(mdOpts MarkdownOptions) ([]byte, error) {
		out, err := ToMarkdown(r, opts, mdOpts)
		return out.Msg, err
	}

	var b []byte
	var err error

	switch format {
	case "json":
		b, err = ToJSON(r, opts)
	case "html":
		b, err = ToHTML(r, opts)
	case "diff":
		b, err = ToDiff(r, opts)
	case "github-comment":
		b, err = markdown(MarkdownOptions{MaxMessageSize: GitHubMaxMessageSize})
	case "gitlab-comment", "azure-repos-comment":
		b, err = markdown(MarkdownOptions{})
	case "bitbucket-comment":
		b, err = markdown(MarkdownOptions{BasicSyntax: true})
	case "bitbucket-comment-summary":
		b, err = markdown(MarkdownOptions{BasicSyntax: true, OmitDetails: true})
	case "slack-message":
		b, err = ToSlackMessage(r, opts)
	default:
		b, err = ToTable(r, opts)
	}

	if err != nil {
		return nil, fmt.Errorf("error generating %s output %w", format, err)
	}

	return b, nil
}

// currencyFormatRegexp parses currency formats such as "USD: $1,234.56" or
// "EUR: 1.234,56€". Submatches:
//  1. the currency code (three characters),
//  2. text before the sample number,
//  3. the thousands separator,
//  4. the optional decimal separator,
//  5. the optional fraction digits,
//  6. text after the sample number.
//
// It is compiled once at package initialisation rather than on every call.
var currencyFormatRegexp = regexp.MustCompile(`^(.{3}): (.*)1(,|\.)234(,|\.)?([0-9]*)?(.*)$`)

// addCurrencyFormat registers a custom currency with go-money from a format
// string of the form "<CODE>: <prefix>1<thousands>234[<decimal><digits>]<suffix>",
// e.g. "USD: $1,234.56" or "EUR: 1.234,56 €".
//
// The currency symbol (grapheme) is taken from the prefix, or from the suffix
// when the prefix contains no symbol. The number of whitespace bytes trimmed
// around the symbol is reproduced between symbol and number in the template.
// The number of fraction digits in the sample sets the precision. Formats that
// do not match are logged as a warning and otherwise ignored.
func addCurrencyFormat(currencyFormat string) {
	m := currencyFormatRegexp.FindStringSubmatch(currencyFormat)
	if len(m) == 0 {
		log.Warningf("Invalid currency format: %s", currencyFormat)
		return
	}

	code := m[1]

	// The templates use "$" as the symbol placeholder and "1" as the amount
	// placeholder. Prefer a symbol written before the number ("$1,234").
	symbolWithSpace := m[2]
	symbol := strings.TrimSpace(symbolWithSpace)
	template := "$" + strings.Repeat(" ", len(symbolWithSpace)-len(symbol)) + "1"

	// Fall back to a symbol written after the number ("1.234 €"). This also
	// covers a whitespace-only prefix, which would otherwise yield an empty
	// symbol and drop a suffix symbol entirely.
	if symbol == "" {
		symbolWithSpace = m[6]
		symbol = strings.TrimSpace(symbolWithSpace)
		template = "1" + strings.Repeat(" ", len(symbolWithSpace)-len(symbol)) + "$"
	}

	thousandSep := m[3]
	// Named decimalSep (not "decimal") so the imported decimal package is not
	// shadowed inside this function.
	decimalSep := m[4]
	fraction := len(m[5])

	money.AddCurrency(code, symbol, template, decimalSep, thousandSep, fraction)
}
