// Copyright The OpenTelemetry Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package otelsarama

import (
	"context"
	"fmt"
	"testing"

	"github.com/Shopify/sarama"
	"github.com/Shopify/sarama/mocks"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/oteltest"
	"go.opentelemetry.io/otel/propagation"
	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
	"go.opentelemetry.io/otel/trace"
)

// topic is the Kafka topic shared by the tests and benchmarks below.
const topic = "test-topic"

// TestWrapPartitionConsumer opens a partition on a mock consumer, wraps the
// resulting partition consumer with WrapPartitionConsumer, and delegates the
// span assertions to consumeAndCheck.
func TestWrapPartitionConsumer(t *testing.T) {
	// Collect spans through an in-memory test tracer provider.
	recorder := new(oteltest.SpanRecorder)
	tp := oteltest.NewTracerProvider(oteltest.WithSpanRecorder(recorder))

	// Register the expectation on the mock consumer, then open that partition.
	mockConsumer := mocks.NewConsumer(t, sarama.NewConfig())
	mockPC := mockConsumer.ExpectConsumePartition(topic, 0, 0)

	pc, err := mockConsumer.ConsumePartition(topic, 0, 0)
	require.NoError(t, err)

	// Instrument the already-open partition consumer.
	wrapped := WrapPartitionConsumer(pc,
		WithTracerProvider(tp),
		WithPropagators(propagation.TraceContext{}),
	)

	consumeAndCheck(t, tp.Tracer(defaultTracerName), recorder.Completed, mockPC, wrapped)
}

// TestWrapConsumer wraps a mock consumer with WrapConsumer before opening a
// partition, so the partition consumer comes from the instrumented consumer,
// and delegates the span assertions to consumeAndCheck.
func TestWrapConsumer(t *testing.T) {
	// Collect spans through an in-memory test tracer provider.
	recorder := new(oteltest.SpanRecorder)
	tp := oteltest.NewTracerProvider(oteltest.WithSpanRecorder(recorder))

	// Register the expectation on the mock consumer.
	mockConsumer := mocks.NewConsumer(t, sarama.NewConfig())
	mockPC := mockConsumer.ExpectConsumePartition(topic, 0, 0)

	// Instrument the consumer itself, then open the partition through it.
	wrapped := WrapConsumer(mockConsumer,
		WithTracerProvider(tp),
		WithPropagators(propagation.TraceContext{}),
	)

	pc, err := wrapped.ConsumePartition(topic, 0, 0)
	require.NoError(t, err)

	consumeAndCheck(t, tp.Tracer(defaultTracerName), recorder.Completed, mockPC, pc)
}

func consumeAndCheck(t *testing.T, mt trace.Tracer, complFn func() []*oteltest.Span, mockPartitionConsumer *mocks.PartitionConsumer, partitionConsumer sarama.PartitionConsumer) {
	// Create message with span context
	ctx, _ := mt.Start(context.Background(), "")
	message := sarama.ConsumerMessage{Key: []byte("foo")}
	propagators := propagation.TraceContext{}
	propagators.Inject(ctx, NewConsumerMessageCarrier(&message))

	// Produce message
	mockPartitionConsumer.YieldMessage(&message)
	mockPartitionConsumer.YieldMessage(&sarama.ConsumerMessage{Key: []byte("foo2")})

	// Consume messages
	msgList := make([]*sarama.ConsumerMessage, 2)
	msgList[0] = <-partitionConsumer.Messages()
	msgList[1] = <-partitionConsumer.Messages()
	require.NoError(t, partitionConsumer.Close())
	// Wait for the channel to be closed
	<-partitionConsumer.Messages()

	// Check spans length
	spans := complFn()
	assert.Len(t, spans, 2)

	expectedList := []struct {
		attributeList []attribute.KeyValue
		parentSpanID  trace.SpanID
		kind          trace.SpanKind
		msgKey        []byte
	}{
		{
			attributeList: []attribute.KeyValue{
				semconv.MessagingSystemKey.String("kafka"),
				semconv.MessagingDestinationKindTopic,
				semconv.MessagingDestinationKey.String("test-topic"),
				semconv.MessagingOperationReceive,
				semconv.MessagingMessageIDKey.String("1"),
				kafkaPartitionKey.Int64(0),
			},
			parentSpanID: trace.SpanContextFromContext(ctx).SpanID(),
			kind:         trace.SpanKindConsumer,
			msgKey:       []byte("foo"),
		},
		{
			attributeList: []attribute.KeyValue{
				semconv.MessagingSystemKey.String("kafka"),
				semconv.MessagingDestinationKindTopic,
				semconv.MessagingDestinationKey.String("test-topic"),
				semconv.MessagingOperationReceive,
				semconv.MessagingMessageIDKey.String("2"),
				kafkaPartitionKey.Int64(0),
			},
			kind:   trace.SpanKindConsumer,
			msgKey: []byte("foo2"),
		},
	}

	for i, expected := range expectedList {
		t.Run(fmt.Sprint("index", i), func(t *testing.T) {
			span := spans[i]

			assert.Equal(t, expected.parentSpanID, span.ParentSpanID())

			var sc trace.SpanContext
			if i == 0 {
				sc = trace.SpanContextFromContext(propagators.Extract(context.Background(), NewConsumerMessageCarrier(msgList[i])))
			} else {
				sc = trace.SpanContextFromContext(propagators.Extract(context.Background(), NewConsumerMessageCarrier(msgList[i])))
				sc = sc.WithRemote(false)
			}
			assert.Equal(t, sc, span.SpanContext())

			assert.Equal(t, "kafka.consume", span.Name())
			assert.Equal(t, expected.kind, span.SpanKind())
			assert.Equal(t, expected.msgKey, msgList[i].Key)
			for _, k := range expected.attributeList {
				assert.Equal(t, k.Value, span.Attributes()[k.Key], k.Key)
			}
		})
	}
}

// TestConsumerConsumePartitionWithError checks that, through a consumer
// wrapped by WrapConsumer, the first ConsumePartition call for a registered
// topic/partition succeeds and a second call for the same topic/partition
// returns an error.
func TestConsumerConsumePartitionWithError(t *testing.T) {
	mockConsumer := mocks.NewConsumer(t, sarama.NewConfig())
	mockConsumer.ExpectConsumePartition(topic, 0, 0)

	wrapped := WrapConsumer(mockConsumer)

	// The first call matches the registered expectation.
	_, firstErr := wrapped.ConsumePartition(topic, 0, 0)
	assert.NoError(t, firstErr)

	// Consuming the same partition again is expected to fail.
	_, secondErr := wrapped.ConsumePartition(topic, 0, 0)
	assert.Error(t, secondErr)
}

// BenchmarkWrapPartitionConsumer measures yielding and receiving one message
// through a partition consumer wrapped with WrapPartitionConsumer. Compare it
// with BenchmarkMockPartitionConsumer, which measures the uninstrumented mock.
func BenchmarkWrapPartitionConsumer(b *testing.B) {
	// Mock provider
	provider := oteltest.NewTracerProvider()

	mockPartitionConsumer, partitionConsumer := createMockPartitionConsumer(b)

	partitionConsumer = WrapPartitionConsumer(partitionConsumer, WithTracerProvider(provider))
	message := sarama.ConsumerMessage{Key: []byte("foo")}

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		mockPartitionConsumer.YieldMessage(&message)
		<-partitionConsumer.Messages()
	}

	// Tear down outside the timed region. The benchmark function is invoked
	// several times with growing b.N, and each invocation creates a new
	// consumer. Without this, every run would leave its consumer open.
	b.StopTimer()
	require.NoError(b, partitionConsumer.Close())
	// Wait for the channel to be closed, as consumeAndCheck does.
	<-partitionConsumer.Messages()
}

// BenchmarkMockPartitionConsumer is the uninstrumented baseline for
// BenchmarkWrapPartitionConsumer. It yields and receives one message per
// iteration directly through the mock partition consumer.
func BenchmarkMockPartitionConsumer(b *testing.B) {
	mockPartitionConsumer, partitionConsumer := createMockPartitionConsumer(b)

	message := sarama.ConsumerMessage{Key: []byte("foo")}

	b.ReportAllocs()
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		mockPartitionConsumer.YieldMessage(&message)
		<-partitionConsumer.Messages()
	}

	// Same teardown as BenchmarkWrapPartitionConsumer, kept outside the timed
	// region so the two benchmarks stay comparable.
	b.StopTimer()
	require.NoError(b, partitionConsumer.Close())
}

// createMockPartitionConsumer creates a fresh mock consumer and registers a
// ConsumePartition expectation for topic, partition 0, offset 0. It then
// opens that partition. It returns the mock controller, used to yield
// messages, together with the opened sarama.PartitionConsumer.
func createMockPartitionConsumer(b *testing.B) (*mocks.PartitionConsumer, sarama.PartitionConsumer) {
	mockConsumer := mocks.NewConsumer(b, sarama.NewConfig())
	controller := mockConsumer.ExpectConsumePartition(topic, 0, 0)

	pc, err := mockConsumer.ConsumePartition(topic, 0, 0)
	require.NoError(b, err)

	return controller, pc
}
