/*
    Copyright (C) 2020 Accurics, Inc.

	Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

		http://www.apache.org/licenses/LICENSE-2.0

	Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

package vulnerability

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strings"

	"github.com/Azure/azure-sdk-for-go/profiles/preview/preview/containerregistry/runtime/containerregistry"
	"github.com/Azure/azure-sdk-for-go/services/resourcegraph/mgmt/2021-03-01/resourcegraph"
	"github.com/Azure/go-autorest/autorest"
	"github.com/Azure/go-autorest/autorest/azure"
	"github.com/Azure/go-autorest/autorest/azure/auth"
	"github.com/accurics/terrascan/pkg/iac-providers/output"
	"go.uber.org/zap"
)

// Constants describing how ACR registries are recognized and authenticated.
const (
	// azureURL is the registry host suffix used to recognize ACR images and
	// to derive the registry name from a registry host.
	azureURL = ".azurecr.io"

	// azureRegistryPassword is the name of the environment variable whose
	// value is used as the basic-auth password for registry tag lookups.
	azureRegistryPassword = "AZURE_ACR_PASSWORD"
)

// ACR is the container registry implementation for Azure Container Registry
// images. Every Azure SDK call it makes goes through its scanner field.
type ACR struct {
	scanner acrScanner // Azure SDK operations; init installs scanner{}
}

// scanner is the default acrScanner implementation; each of its methods is a
// thin wrapper around an Azure SDK call.
type scanner struct{}

func init() {
	RegisterContainerRegistry("acr", &ACR{
		scanner: scanner{},
	})
}

// acrScanner collects the external Azure SDK operations used by ACR, so that
// an alternative implementation can be injected in place of scanner.
type acrScanner interface {
	// newAuthorizerFromFile builds an authorizer via auth.NewAuthorizerFromFile.
	newAuthorizerFromFile() (autorest.Authorizer, error)
	// newResourcegraphClient returns a Resource Graph client.
	newResourcegraphClient() resourcegraph.BaseClient
	// getResources runs a Resource Graph query with the given client.
	getResources(ctx context.Context, request resourcegraph.QueryRequest, client resourcegraph.BaseClient) (resourcegraph.QueryResponse, error)
	// newBasicAuthorizer returns a basic-auth authorizer for the credentials.
	newBasicAuthorizer(username string, password string) *autorest.BasicAuthorizer
	// newTagClient returns a registry tag client for loginURL.
	newTagClient(loginURL string) containerregistry.TagClient
	// getAttributes fetches the attributes of tag reference in repository name.
	getAttributes(ctx context.Context, name string, reference string, client containerregistry.TagClient) (containerregistry.TagAttributes, error)
}

// checkRegistry reports whether image lives in an Azure Container Registry,
// that is, whether the domain of the image reference ends with azureURL.
func (a *ACR) checkRegistry(image string) bool {
	return strings.HasSuffix(GetDomain(image), azureURL)
}

// newTagClient creates a containerregistry.TagClient for the registry served
// at loginURL.
func (scanner) newTagClient(loginURL string) containerregistry.TagClient {
	client := containerregistry.NewTagClient(loginURL)
	return client
}

// newBasicAuthorizer wraps username and password in an autorest basic
// authorizer.
func (scanner) newBasicAuthorizer(username, password string) *autorest.BasicAuthorizer {
	authorizer := autorest.NewBasicAuthorizer(username, password)
	return authorizer
}

// newAuthorizerFromFile builds an Azure authorizer with auth.NewAuthorizerFromFile
// for the public-cloud Resource Manager endpoint. A failure is logged and the
// error is returned with a nil authorizer.
func (scanner) newAuthorizerFromFile() (autorest.Authorizer, error) {
	endpoint := azure.PublicCloud.ResourceManagerEndpoint
	authorizer, err := auth.NewAuthorizerFromFile(endpoint)
	if err == nil {
		return authorizer, nil
	}
	zap.S().Errorf("failed to authorize : %v ", err)
	return nil, err
}

// newResourcegraphClient returns a Resource Graph client created with
// resourcegraph.New; the caller is responsible for setting its Authorizer.
func (scanner) newResourcegraphClient() resourcegraph.BaseClient {
	client := resourcegraph.New()
	return client
}

// getResources executes the Resource Graph query described by request with
// client and returns the raw query response.
func (scanner) getResources(ctx context.Context, request resourcegraph.QueryRequest, client resourcegraph.BaseClient) (resourcegraph.QueryResponse, error) {
	response, err := client.Resources(ctx, request)
	return response, err
}

// getAttributes fetches the attributes (which may include the manifest
// digest) of tag reference in repository name, using client.
func (scanner) getAttributes(ctx context.Context, name string, reference string, client containerregistry.TagClient) (containerregistry.TagAttributes, error) {
	attributes, err := client.GetAttributes(ctx, name, reference)
	return attributes, err
}

// getVulnerabilities scans container.Image through ScanImage and converts each
// finding into an output.Vulnerability. Findings that end up without a
// VulnerabilityID are skipped. A scan error is logged and yields a nil slice.
// options is not used by the ACR registry.
func (a *ACR) getVulnerabilities(container output.ContainerDetails, options map[string]interface{}) (vulnerabilities []output.Vulnerability) {
	findings, err := a.ScanImage(container.Image)
	if err != nil {
		zap.S().Errorf("error finding vulnerabilities for image %s : %v", container.Image, err)
		return nil
	}

	for i := range findings {
		var v output.Vulnerability
		v.PrepareFromACRImageScan(findings[i])
		if v.VulnerabilityID == "" {
			continue
		}
		vulnerabilities = append(vulnerabilities, v)
	}
	return vulnerabilities
}

// ScanImage queries Azure Resource Graph for the security sub-assessments
// recorded for image in ACR and decodes them into output.ACRResponse values.
//
// The image reference is first resolved to its repository, registry host and
// digest (see getACRImageDetails). The Resource Graph client is authorized with
// the scanner's file-based authorizer. On any error an empty slice is returned
// together with the error.
func (a *ACR) ScanImage(image string) ([]output.ACRResponse, error) {
	results := []output.ACRResponse{}
	resourcegraphClient := a.scanner.newResourcegraphClient()

	authorizer, err := a.scanner.newAuthorizerFromFile()
	if err != nil {
		zap.S().Errorf("failed to authorize for image %s ", image)
		return results, err
	}
	resourcegraphClient.Authorizer = authorizer

	imageDetails, isValidImage := a.getACRImageDetails(image, authorizer)
	if !isValidImage {
		zap.S().Errorf(invalidImageReferenceMsg, image)
		return results, fmt.Errorf(invalidImageReferenceMsg, image)
	}

	requestOptions := resourcegraph.QueryRequestOptions{
		ResultFormat: resourcegraph.ResultFormatObjectArray,
	}

	// The image fields come from the scanned IaC and are embedded inside
	// double-quoted KQL string literals, so they are escaped to keep them
	// from terminating the literal and altering the query.
	query := `securityresources | where type == "microsoft.security/assessments"
	| summarize by assessmentKey=name 
	| join kind=inner (
		securityresources
		 | where type == "microsoft.security/assessments/subassessments"
		 | extend assessmentKey = extract(".*assessments/(.+?)/.*",1,  id)
	 ) on assessmentKey
	| project parse_json(properties)
	| where properties.additionalData.repositoryName =="` + escapeKQLString(imageDetails.Repository) +
		`" and properties.additionalData.registryHost == "` + escapeKQLString(imageDetails.Registry) +
		`" and properties.additionalData.imageDigest == "` + escapeKQLString(imageDetails.Digest) + `"`

	request := resourcegraph.QueryRequest{
		Query:   &query,
		Options: &requestOptions,
	}

	// Run the query and get the results
	response, queryErr := a.scanner.getResources(context.Background(), request, resourcegraphClient)
	if queryErr != nil {
		zap.S().Errorf(errorScanningMsg, image, queryErr)
		return results, queryErr
	}

	// response.Data is loosely typed; round-trip it through JSON to decode it
	// into the typed response structs.
	jsonData, err := json.Marshal(response.Data)
	if err != nil {
		zap.S().Errorf("error marshaling image %s scan results: %v", image, err)
		return results, fmt.Errorf("error marshaling image %s scan results: %w", image, err)
	}
	if err := json.Unmarshal(jsonData, &results); err != nil {
		zap.S().Errorf("error unmarshaling image %s scan results: %v", image, err)
		// results may be partially decoded at this point; do not leak it.
		return []output.ACRResponse{}, fmt.Errorf("error unmarshaling image %s scan results: %w", image, err)
	}
	return results, nil
}

// escapeKQLString escapes backslashes and double quotes in s so that it can be
// placed between double quotes in a KQL query as a single string literal.
func escapeKQLString(s string) string {
	return strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(s)
}

// getACRImageDetails splits image into registry, repository, tag and digest
// using GetImageDetails. A missing tag is replaced with defaultTagValue. When
// the reference has no digest, one is looked up from the registry for the tag
// via getImageDigestFromTag.
//
// validImage is false when neither a tag nor a repository could be parsed, or
// when a required digest lookup returned nothing.
//
// NOTE(review): authrizer is not used here; the digest lookup authenticates
// with basic auth instead.
func (a *ACR) getACRImageDetails(image string, authrizer autorest.Authorizer) (imageDetails ImageDetails, validImage bool) {
	imageDetails = GetImageDetails(image, imageDetails)
	if imageDetails.Repository == "" && imageDetails.Tag == "" {
		return imageDetails, false
	}

	if imageDetails.Tag == "" {
		imageDetails.Tag = defaultTagValue
	}

	// A digest in the reference is used as-is; otherwise resolve the tag.
	if imageDetails.Digest != "" {
		return imageDetails, true
	}

	loginURL := fmt.Sprintf("https://%s", imageDetails.Registry)
	resolved := a.getImageDigestFromTag(image, imageDetails, loginURL)
	if resolved == "" {
		return imageDetails, false
	}
	imageDetails.Digest = resolved
	return imageDetails, true
}

// getImageDigestFromTag looks up the digest of imageDetails.Repository at
// imageDetails.Tag through the registry tag API served at loginURL.
//
// The registry name (imageDetails.Registry with the azureURL suffix removed)
// is used as the basic-auth username, and the value of the environment
// variable named by azureRegistryPassword as the password. An empty string is
// returned when the registry host is not "<name>" + azureURL, when the lookup
// fails (the error is logged), or when the registry reports no digest.
func (a *ACR) getImageDigestFromTag(image string, imageDetails ImageDetails, loginURL string) string {
	// Debugf, not Debug: Debug does not interpret format verbs.
	zap.S().Debugf("fetching digest for image %s", image)

	// The username is the registry name, which only exists for hosts that
	// end with the ACR suffix and have a non-empty prefix.
	if !strings.HasSuffix(imageDetails.Registry, azureURL) {
		return ""
	}
	username := strings.TrimSuffix(imageDetails.Registry, azureURL)
	if username == "" {
		return ""
	}
	password := os.Getenv(azureRegistryPassword)

	basicAuthorizer := a.scanner.newBasicAuthorizer(username, password)
	client := a.scanner.newTagClient(loginURL)
	client.Authorizer = basicAuthorizer

	result, err := a.scanner.getAttributes(context.Background(), imageDetails.Repository, imageDetails.Tag, client)
	if err != nil {
		zap.S().Errorf("error getting image  %s digest %v", image, err)
		return ""
	}
	if result.Attributes == nil || result.Attributes.Digest == nil {
		return ""
	}
	return *result.Attributes.Digest
}
