// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TFX_BSL_CC_CODERS_EXAMPLE_CODER_H_
#define TFX_BSL_CC_CODERS_EXAMPLE_CODER_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tfx_bsl/cc/util/status.h"

// Forward declarations of the Arrow types named in this header. They are only
// used here behind std::shared_ptr or as a reference, which is why this header
// does not include the Arrow headers themselves.
// NOTE(review): arrow::Field is not referenced anywhere in this header --
// confirm whether this declaration is still needed.
namespace arrow {
class Field;
class RecordBatch;
class Schema;
class StructType;
}  // namespace arrow

namespace tfx_bsl {
// ExamplesToRecordBatchDecoder converts a vector of Example protos to an Arrow
// RecordBatch.
//
// If a schema is provided then the record batch will contain only the fields
// from the schema, in the same order as the Schema.  The data type of the
// schema to determine the field types, with INT, BYTES and FLOAT fields in the
// schema corresponding to the Arrow data types large_list[int64],
// large_list[large_binary] and large_list[float32].
//
// If a schema is not provided then the data type will be inferred, and chosen
// from list_type[int64], list_type[binary_type] and list_type[float32].  In the
// case where no data type can be inferred the arrow null type will be inferred.
//
// If the decoder detects a non-conformant Example, it will return an error.

class FeatureDecoder;
class FeatureListDecoder;
class ExamplesToRecordBatchDecoder {
 public:
  static Status Make(
      absl::optional<absl::string_view> serialized_schema,
      std::unique_ptr<ExamplesToRecordBatchDecoder>* result);
  ~ExamplesToRecordBatchDecoder();

  ExamplesToRecordBatchDecoder(const ExamplesToRecordBatchDecoder&) = delete;
  ExamplesToRecordBatchDecoder& operator=(const ExamplesToRecordBatchDecoder&) =
      delete;

  // Decodes a batch of serialized examples into a RecordBatch.
  Status DecodeBatch(const std::vector<absl::string_view>& serialized_examples,
                     std::shared_ptr<arrow::RecordBatch>* record_batch) const;

  // Returns the schema of record batches that would be generated by
  // DecodeBatch, if a TFMD schema was provided at construction time, otherwise
  // returns nullptr.
  std::shared_ptr<arrow::Schema> ArrowSchema() const;

 private:
  ExamplesToRecordBatchDecoder(
      std::shared_ptr<arrow::Schema> arrow_schema,
      std::unique_ptr<const absl::flat_hash_map<
          std::string, std::unique_ptr<FeatureDecoder>>>
          feature_decoders);
  Status DecodeFeatureDecodersAvailable(
      const std::vector<absl::string_view>& serialized_examples,
      std::shared_ptr<arrow::RecordBatch>* record_batch) const;
  Status DecodeFeatureDecodersUnavailable(
      const std::vector<absl::string_view>& serialized_examples,
      std::shared_ptr<arrow::RecordBatch>* record_batch) const;

 private:
  const std::shared_ptr<arrow::Schema> arrow_schema_;
  const std::unique_ptr<
      const absl::flat_hash_map<std::string, std::unique_ptr<FeatureDecoder>>>
      feature_decoders_;
};

// Converts a RecordBatch to a list of examples.
//
// The fields of the RecordBatch must have types list[dtype] or
// large_list[dtype], where dtype is one of float32, int64, binary, or
// large_binary.
//
// The result is delivered through `serialized_examples`; going by the name,
// each element is a serialized Example proto (presumably one per row of
// `record_batch` -- confirm against the implementation). Success or failure is
// reported through the returned Status.
Status RecordBatchToExamples(const arrow::RecordBatch& record_batch,
                             std::vector<std::string>* serialized_examples);

// SequenceExamplesToRecordBatchDecoder converts a vector of SequenceExample
// protos to an Arrow RecordBatch.
//
// If a schema is provided, the record batch will contain only the fields from
// the schema. The context fields will be in the same order as in the schema.
// The sequence fields will be arranged in a struct array within a single column
// of the record batch. The data type of the schema determines the field types,
// with INT, BYTES and FLOAT fields in the schema corresponding to the Arrow
// data types large_list[int64], large_list[large_binary], and
// large_list[float32] for context features; and large_list[large_list[int64]],
// large_list[large_list[large_binary]], and large_list[large_list[float32]] for
// sequence features.
//
// If a schema is not provided, then the data type will be inferred and chosen
// from list_type[int64], list_type[binary_type], and list_type[float32] for
// context features; and list_type[list_type[int64]],
// list_type[list_type[binary_type]], and list_type[list_type[float32]] for
// sequence features. When no data type can be inferred, the arrow null type
// will be used for context features, and list_type<null> will be used for
// sequence features.
//
// If the decoder detects a non-conformant SequenceExample, it will return an
// error.
class SequenceExamplesToRecordBatchDecoder {
 public:
  // Creates a decoder that can be used to convert a vector of SequenceExample
  // protos to an Arrow RecordBatch (with the DecodeBatch method). See the class
  // help for detailed information about how the decoder does this conversion
  // with or without a schema provided.
  static Status Make(
      const absl::optional<absl::string_view>& serialized_schema,
      const std::string& sequence_feature_column_name,
      std::unique_ptr<SequenceExamplesToRecordBatchDecoder>* result);
  ~SequenceExamplesToRecordBatchDecoder();

  SequenceExamplesToRecordBatchDecoder(
      const SequenceExamplesToRecordBatchDecoder&) = delete;
  SequenceExamplesToRecordBatchDecoder& operator=(
      const SequenceExamplesToRecordBatchDecoder&) = delete;

  // Decodes a vector of SequenceExample protos to an Arrow RecordBatch.
  Status DecodeBatch(
      const std::vector<absl::string_view>& serialized_sequence_examples,
      std::shared_ptr<arrow::RecordBatch>* record_batch) const;

  // Returns the schema of record batches that would be generated by
  // DecodeBatch, if a TFMD schema was provided at construction time, otherwise
  // returns nullptr.
  std::shared_ptr<arrow::Schema> ArrowSchema() const;

 private:
  SequenceExamplesToRecordBatchDecoder(
      const std::string& sequence_feature_column_name,
      std::shared_ptr<arrow::Schema> arrow_schema,
      std::shared_ptr<arrow::StructType> sequence_features_struct_type,
      std::unique_ptr<const absl::flat_hash_map<
          std::string, std::unique_ptr<FeatureDecoder>>>
          context_feature_decoders,
      std::unique_ptr<const absl::flat_hash_map<
          std::string, std::unique_ptr<FeatureListDecoder>>>
          sequence_feature_decoders);

  // Decodes a vector of sequence examples to an Arrow RecordBatch where
  // feature list decoders have been created based on the schema passed to
  // the decoder. All sequence features in the schema must be specified as
  // children of a top-level struct feature that has the name specified by
  // 'sequence_feature_column_name'. Only those features in the schema will be
  // included in the resulting RecordBatch.
  Status DecodeFeatureListDecodersAvailable(
      const std::vector<absl::string_view>& serialized_sequence_examples,
      std::shared_ptr<arrow::RecordBatch>* record_batch) const;
  // Decodes a vector of sequence examples to an Arrow RecordBatch where no
  // schema was provided from which to create feature list decoders. In this
  // case, a feature list decoder will be created for each new feature
  // encountered when iterating through the input sequence examples.
  Status DecodeFeatureListDecodersUnavailable(
      const std::vector<absl::string_view>& serialized_sequence_examples,
      std::shared_ptr<arrow::RecordBatch>* record_batch) const;

 private:
  // The name of the column used to hold all of the sequence features, which are
  // arranged in a StructArray. If a schema is provided, the top-level struct
  // feature that contains all sequence features as children must use this name.
  // No other features should use this name.
  const std::string sequence_feature_column_name_;
  const std::shared_ptr<arrow::Schema> arrow_schema_;
  // The type of the StructArray used to hold the sequence features.
  const std::shared_ptr<arrow::StructType> sequence_features_struct_type_;
  const std::unique_ptr<
      const absl::flat_hash_map<std::string, std::unique_ptr<FeatureDecoder>>>
      context_feature_decoders_;
  const std::unique_ptr<const absl::flat_hash_map<
      std::string, std::unique_ptr<FeatureListDecoder>>>
      sequence_feature_decoders_;
};
}  // namespace tfx_bsl

#endif  // TFX_BSL_CC_CODERS_EXAMPLE_CODER_H_
