package services

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/rancher/rke/hosts"
	"github.com/sirupsen/logrus"
	etcdclientv2 "go.etcd.io/etcd/client/v2"
	etcdclientv3 "go.etcd.io/etcd/client/v3"
	"google.golang.org/grpc"
)

// getEtcdClientV2 builds an etcd v2 API client pointed at
// https://<etcdHost.InternalAddress>:2379.
//
// Connections go through the dialer returned by getEtcdDialer (the supplied
// localConnDialerFactory, or hosts.LocalConnFactory when it is nil) and
// present the given PEM cert/key pair. ctx is accepted for API symmetry but is
// not used here.
func getEtcdClientV2(ctx context.Context, etcdHost *hosts.Host, localConnDialerFactory hosts.DialerFactory, cert, key []byte) (etcdclientv2.Client, error) {
	dial, dialErr := getEtcdDialer(localConnDialerFactory, etcdHost)
	if dialErr != nil {
		return nil, fmt.Errorf("failed to create a dialer for host [%s]: %v", etcdHost.Address, dialErr)
	}

	clientTLS, tlsErr := getEtcdTLSConfig(cert, key)
	if tlsErr != nil {
		return nil, tlsErr
	}

	endpoint := "https://" + etcdHost.InternalAddress + ":2379"
	transport := &http.Transport{
		Dial:                dial,
		TLSClientConfig:     clientTLS,
		TLSHandshakeTimeout: 10 * time.Second,
	}

	return etcdclientv2.New(etcdclientv2.Config{
		Endpoints: []string{endpoint},
		Transport: transport,
	})
}

// getEtcdClientV3 builds an etcd v3 (gRPC) client pointed at
// https://<etcdHost.InternalAddress>:2379.
//
// gRPC connections are opened through the dialer returned by getEtcdDialer,
// adapted to a context-aware dialer by wrapper, and present the given PEM
// cert/key pair. The initial dial is bounded by a 5 second DialTimeout. ctx is
// accepted for API symmetry but is not used here.
func getEtcdClientV3(ctx context.Context, etcdHost *hosts.Host, localConnDialerFactory hosts.DialerFactory, cert, key []byte) (*etcdclientv3.Client, error) {
	dial, dialErr := getEtcdDialer(localConnDialerFactory, etcdHost)
	if dialErr != nil {
		return nil, fmt.Errorf("failed to create a dialer for host [%s]: %v", etcdHost.Address, dialErr)
	}

	clientTLS, tlsErr := getEtcdTLSConfig(cert, key)
	if tlsErr != nil {
		return nil, tlsErr
	}

	return etcdclientv3.New(etcdclientv3.Config{
		Endpoints:   []string{"https://" + etcdHost.InternalAddress + ":2379"},
		TLS:         clientTLS,
		DialOptions: []grpc.DialOption{grpc.WithContextDialer(wrapper(dial))},
		DialTimeout: 5 * time.Second,
	})
}

// wrapper adapts a (network, address) dialer to the (ctx, address) signature
// used by grpc.WithContextDialer. Every dial is made on the "tcp" network, and
// the context argument is ignored: cancellation is not propagated to f.
func wrapper(f func(network, address string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
	dialTCP := func(_ context.Context, addr string) (net.Conn, error) {
		return f("tcp", addr)
	}
	return dialTCP
}

// isEtcdHealthy polls url (presumably host's etcd /health endpoint — the URL is
// supplied by the caller) until it reports healthy="true".
//
// Requests are sent through the dialer from getEtcdDialer and present the given
// PEM cert/key pair. Up to maxAttempts requests are made, with retryInterval
// between them.
//
// Return values:
//   - nil once a response reports healthy="true".
//   - The dialer or TLS setup error, immediately and without retrying.
//   - The last request/decode error, if the final attempt failed.
//   - Otherwise, an error carrying the last reported health value.
func isEtcdHealthy(localConnDialerFactory hosts.DialerFactory, host *hosts.Host, cert, key []byte, url string) error {
	// Given a max election timeout of 50000ms (50s), a max re-election of 77
	// seconds was seen; 18 attempts * 5 seconds allows ~90 seconds of
	// re-election. requestTimeout bounds each individual request so that a peer
	// which accepts the connection but never answers cannot stall the loop.
	const (
		maxAttempts    = 18
		retryInterval  = 5 * time.Second
		requestTimeout = 30 * time.Second
	)

	logrus.Debugf("[etcd] check etcd cluster health on host [%s]", host.Address)

	// The dialer and TLS config do not change between attempts, so build them
	// (and the client) once.
	dialer, err := getEtcdDialer(localConnDialerFactory, host)
	if err != nil {
		return err
	}
	tlsConfig, err := getEtcdTLSConfig(cert, key)
	if err != nil {
		return fmt.Errorf("[etcd] failed to create etcd tls config for host [%s]: %v", host.Address, err)
	}
	transport := &http.Transport{
		Dial:                dialer,
		TLSClientConfig:     tlsConfig,
		TLSHandshakeTimeout: 10 * time.Second,
	}
	// Release pooled keep-alive connections once polling is finished.
	defer transport.CloseIdleConnections()
	hc := http.Client{
		Transport: transport,
		Timeout:   requestTimeout,
	}

	var finalErr error
	var healthy string
	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Wait between attempts, but not after the last one.
		if attempt > 0 {
			time.Sleep(retryInterval)
		}
		healthy, finalErr = getHealthEtcd(hc, host, url)
		if finalErr != nil {
			logrus.Debugf("[etcd] failed to check health for etcd host [%s]: %v", host.Address, finalErr)
			continue
		}
		// Logged at Info so the user can follow progress while etcd settles.
		logrus.Infof("[etcd] etcd host [%s] reported healthy=%s", host.Address, healthy)
		if healthy == "true" {
			return nil
		}
	}
	if finalErr != nil {
		return fmt.Errorf("[etcd] host [%s] failed to check etcd health: %v", host.Address, finalErr)
	}
	return fmt.Errorf("[etcd] host [%s] reported healthy=%s", host.Address, healthy)
}

// getHealthEtcd GETs url with hc and returns the "Health" field of the JSON
// response body. encoding/json matches field names case-insensitively, so a
// lowercase "health" key also populates it.
//
// The HTTP status code is not inspected; only the decoded body matters. On any
// error (request, body read, or JSON decode) it returns "" and an error
// mentioning host.Address.
func getHealthEtcd(hc http.Client, host *hosts.Host, url string) (string, error) {
	resp, err := hc.Get(url)
	if err != nil {
		return "", fmt.Errorf("failed to get /health for host [%s]: %w", host.Address, err)
	}
	// Close on every path, including a failed read, so the connection is
	// released back to (or closed by) the transport.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response of /health for host [%s]: %w", host.Address, err)
	}

	var healthy struct{ Health string }
	if err := json.Unmarshal(body, &healthy); err != nil {
		return "", fmt.Errorf("failed to unmarshal response of /health for host [%s]: %w", host.Address, err)
	}
	return healthy.Health, nil
}

// GetEtcdInitialCluster builds the comma-separated etcd initial-cluster value.
// It has one "etcd-<HostnameOverride>=https://<InternalAddress>:2380" member
// entry per host, in input order. It returns "" for an empty host list.
func GetEtcdInitialCluster(hosts []*hosts.Host) string {
	members := make([]string, 0, len(hosts))
	for _, host := range hosts {
		members = append(members, fmt.Sprintf("etcd-%s=https://%s:2380", host.HostnameOverride, host.InternalAddress))
	}
	return strings.Join(members, ",")
}

// getEtcdDialer returns the dial function used to reach etcd on etcdHost. It
// uses localConnDialerFactory, or hosts.LocalConnFactory when that is nil.
//
// Side effect: it sets etcdHost.LocalConnPort to 2379 (the etcd client port)
// on the caller's host value before invoking the factory.
func getEtcdDialer(localConnDialerFactory hosts.DialerFactory, etcdHost *hosts.Host) (func(network, address string) (net.Conn, error), error) {
	etcdHost.LocalConnPort = 2379

	factory := localConnDialerFactory
	if factory == nil {
		factory = hosts.LocalConnFactory
	}
	return factory(etcdHost)
}

// GetEtcdConnString returns a comma-separated list of etcd client URLs
// (https://<InternalAddress>:2379), one per host.
//
// If any host's InternalAddress equals hostAddress, that endpoint is listed
// once, first. The remaining hosts follow in input order, and any other host
// with the same address is not repeated. It returns "" when there are no hosts.
func GetEtcdConnString(hosts []*hosts.Host, hostAddress string) string {
	var peers []string
	sawSelf := false
	for _, h := range hosts {
		if h.InternalAddress == hostAddress {
			sawSelf = true
		} else {
			peers = append(peers, "https://"+h.InternalAddress+":2379")
		}
	}
	if !sawSelf {
		return strings.Join(peers, ",")
	}
	ordered := append([]string{"https://" + hostAddress + ":2379"}, peers...)
	return strings.Join(ordered, ",")
}

// getEtcdTLSConfig builds a client tls.Config that presents the PEM-encoded
// certificate/key pair. It returns the tls.X509KeyPair error unchanged if the
// pair cannot be parsed.
func getEtcdTLSConfig(certificate, key []byte) (*tls.Config, error) {
	pair, err := tls.X509KeyPair(certificate, key)
	if err != nil {
		return nil, err
	}
	// NOTE(review): InsecureSkipVerify disables verification of the server's
	// certificate chain and host name; only the client certificate is used for
	// authentication. Confirm this is intentional (e.g. no CA bundle available
	// here) before relying on it for anything beyond cluster-internal checks.
	return &tls.Config{
		Certificates:       []tls.Certificate{pair},
		InsecureSkipVerify: true,
	}, nil
}
