/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_SHIM_TF_OP_SHIM_H_
#define TENSORFLOW_LITE_KERNELS_SHIM_TF_OP_SHIM_H_

#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/shape.h"
#include "tensorflow/lite/kernels/shim/tf_tensor_view.h"

namespace tflite {
namespace shim {

// Init-time context for the TF runtime. Adapts a
// ::tensorflow::OpKernelConstruction to the shim's InitContext interface so an
// op implementation can read its attributes while the kernel is constructed.
class TfInitContext : public InitContext<TfInitContext> {
 public:
  // Wraps `context`. Presumably the pointer is just stored in `context_`; no
  // destructor is declared, so it is never deleted here and the caller must
  // keep it alive for the lifetime of this object.
  explicit TfInitContext(const ::tensorflow::OpKernelConstruction* context);

  // Looks up the attribute called `attr_name`. Failures are reported through
  // the returned absl::StatusOr rather than by aborting.
  absl::StatusOr<AttrValue> GetAttr(const std::string& attr_name) const;

 private:
  // Wrapped TF construction context (not owned by this class).
  const ::tensorflow::OpKernelConstruction* context_;
};

// Invoke-time context for the TF runtime. Adapts a
// ::tensorflow::OpKernelContext to the shim's InvokeContext interface so an op
// implementation can read inputs and obtain outputs during Compute().
class TfInvokeContext : public InvokeContext<TfInvokeContext> {
 public:
  // Wraps `context` (not owned; must outlive this object).
  explicit TfInvokeContext(::tensorflow::OpKernelContext* context);

  // Returns a read-only view of input number `idx`, or an error status.
  ConstTensorViewOr GetInput(int idx) const;

  // Returns a writable view of output number `idx` with the requested `shape`,
  // or an error status. Presumably this allocates the output in TF — confirm
  // in the .cc.
  TensorViewOr GetOutput(int idx, const Shape& shape) const;

 private:
  // Wrapped TF kernel context (not owned by this class).
  ::tensorflow::OpKernelContext* context_;
};

// Shape-inference context for the TF runtime. Adapts a
// ::tensorflow::shape_inference::InferenceContext to the shim's
// ShapeInferenceContext interface.
class TfShapeInferenceContext
    : public ShapeInferenceContext<TfShapeInferenceContext> {
 public:
  // Wraps `context` (not owned; must outlive this object).
  explicit TfShapeInferenceContext(
      ::tensorflow::shape_inference::InferenceContext* context);

  // Returns the (possibly partially known) shape of input number `idx`.
  ShapeOr GetInputShape(int idx) const;

  // Records `shape` as the inferred shape of output number `idx`.
  absl::Status SetOutputShape(int idx, const Shape& shape);

  // Returns the value of input number `idx` as a read-only tensor view, if it
  // is available at shape-inference time; otherwise returns an error status.
  ConstTensorViewOr GetInputTensor(int idx) const;

 private:
  // Wrapped TF inference context (not owned by this class).
  ::tensorflow::shape_inference::InferenceContext* context_;
};

// Converts an absl::Status to a tensorflow::Status.
::tensorflow::Status FromAbslStatus(const ::absl::Status& s);
// Converts a tensorflow::Status to an absl::Status.
::absl::Status ToAbslStatus(const ::tensorflow::Status& s);

// Adaptor between an op implementation (an OpKernelShim subclass) and the TF
// runtime. `Impl` is a class template over Runtime; this adaptor instantiates
// it as Impl<Runtime::kTf> and forwards the tensorflow::OpKernel hooks
// (construction, Compute, shape inference) to it.
template <template <Runtime> typename Impl>
class TfOpKernel : public ::tensorflow::OpKernel {
 public:
  using ImplType = Impl<Runtime::kTf>;

  // Creates the op implementation and runs its Init() against the kernel's
  // construction context. Init()'s status is recorded on `c` via SetStatus.
  explicit TfOpKernel(::tensorflow::OpKernelConstruction* c)
      : OpKernel(c), impl_(std::make_unique<ImplType>()) {
    TfInitContext ctx(c);
    c->SetStatus(FromAbslStatus(impl_->Init(&ctx)));
  }

  // The main computation of the op. Delegates to the implementation's
  // Invoke(); a non-OK status is propagated to `c` via OP_REQUIRES_OK.
  void Compute(::tensorflow::OpKernelContext* c) override {
    TfInvokeContext ctx(c);
    OP_REQUIRES_OK(c, FromAbslStatus(impl_->Invoke(&ctx)));
  }

  // Shape inference for the op. First checks that every input whose
  // declaration carries a shape has the declared rank (ValidateInputRanks),
  // then delegates to ImplType::ShapeInference.
  static tensorflow::Status ShapeInference(
      ::tensorflow::shape_inference::InferenceContext* c) {
    TF_RETURN_IF_ERROR(ValidateInputRanks(c));
    TfShapeInferenceContext ctx(c);
    return FromAbslStatus(ImplType::ShapeInference(&ctx));
  }

  // The operation name, taken from ImplType::kOpName.
  static const char* OpName() { return ImplType::kOpName; }

 protected:
  // Returns an error if an input with a declared shape does not have the
  // declared rank; inputs declared without a shape are not checked.
  static tensorflow::Status ValidateInputRanks(
      ::tensorflow::shape_inference::InferenceContext* c);

  // The op implementation, created in the constructor.
  // NOTE(review): this is deleted through an OpKernelShim<Impl, Runtime::kTf>*,
  // which relies on OpKernelShim declaring a virtual destructor — confirm in
  // op_kernel.h.
  std::unique_ptr<OpKernelShim<Impl, Runtime::kTf>> impl_;
};

// The shim's Shape sentinels for "unknown dimension" and "unknown rank" must
// equal TF's InferenceContext sentinels. Presumably the TF shape-inference glue
// passes these values across unchanged, so a mismatch must fail the build.
static_assert(Shape::kUnknownDim ==
                  ::tensorflow::shape_inference::InferenceContext::kUnknownDim,
              "The values must match.");
static_assert(Shape::kUnknownRank ==
                  ::tensorflow::shape_inference::InferenceContext::kUnknownRank,
              "The values must match.");

// Builds the OpDef wrapper used to register the op with the TF runtime.
// Inputs, outputs and attributes come from Kernel::ImplType's declarations;
// the shape function is Kernel::ShapeInference and the docstring is
// Kernel::ImplType::kDoc.
template <typename Kernel>
::tensorflow::register_op::OpDefBuilderWrapper CreateOpDefBuilderWrapper() {
  ::tensorflow::register_op::OpDefBuilderWrapper ret(
      Kernel::ImplType::kOpName);
  // The builder methods mutate `ret` in place (the SetShapeFn/Doc chain below
  // discards their return value too), so their result is not assigned back.
  for (const auto& input : Kernel::ImplType::Inputs())
    ret.Input(std::string(input.name_type));
  for (const auto& output : Kernel::ImplType::Outputs())
    ret.Output(std::string(output.name_type));
  for (const auto& attr : Kernel::ImplType::Attrs()) ret.Attr(attr);
  ret.SetShapeFn(Kernel::ShapeInference).Doc(Kernel::ImplType::kDoc);
  return ret;
}

// Maps Runtime::kTf to the TF-backed context types defined above. Presumably
// OpKernelShim consults this trait to choose its context types.
template <>
struct ContextTypeForRuntime<Runtime::kTf> {
  using ShapeInference = TfShapeInferenceContext;  // Used by ShapeInference().
  using Invoke = TfInvokeContext;                  // Used by Compute().
  using Init = TfInitContext;                      // Used at construction.
};

// Macros for registering a shim op with the TF runtime. They are adapted from
// the op-registration macros in TF's op.h. The difference is that the
// registration streams the OpDefBuilderWrapper produced by
// ::tflite::shim::CreateOpDefBuilderWrapper<op_kernel_cls>() instead of a
// hand-written op builder.
//
// REGISTER_OP_SHIM_IMPL(ctr, name, op_kernel_cls): implementation detail.
// Defines a file-local static `register_op##ctr`, initialized via
// TF_INIT_ON_STARTUP_IF(SHOULD_REGISTER_OP(name)). Presumably the registration
// is skipped when SHOULD_REGISTER_OP(name) is false — see op.h.
// NOTE: no comments may be placed inside the macro body; a `//` on a line
// ending in `\` would swallow the continued line.
#define REGISTER_OP_SHIM_IMPL(ctr, name, op_kernel_cls)           \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr \
      TF_ATTRIBUTE_UNUSED =                                       \
          TF_INIT_ON_STARTUP_IF(SHOULD_REGISTER_OP(name))         \
          << ::tflite::shim::CreateOpDefBuilderWrapper<op_kernel_cls>()

// REGISTER_TF_OP_SHIM(name, op_kernel_cls): public entry point. Annotates the
// registration with "tf:op" and uses TF_NEW_ID_FOR_INIT to supply the `ctr`
// token. Presumably that token is a fresh per-use id (e.g. __COUNTER__), so
// several ops can be registered in one translation unit.
#define REGISTER_TF_OP_SHIM(name, op_kernel_cls) \
  TF_ATTRIBUTE_ANNOTATE("tf:op")                 \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_SHIM_IMPL, name, op_kernel_cls)

////////////////////////////////////////////
///////////////////////////// Implementation

// Checks that every input whose TensorDeclaration carries a shape has the
// declared rank (the number of entries in that shape). Inputs declared without
// a shape are not checked. Returns the first error reported by WithRank.
template <template <Runtime> typename Impl>
tensorflow::Status TfOpKernel<Impl>::ValidateInputRanks(
    ::tensorflow::shape_inference::InferenceContext* c) {
  // Sentinel meaning "no rank declared for this input". It is used
  // consistently below, so only its distinctness from real ranks matters.
  static constexpr int kRankNotDeclared =
      ::tensorflow::shape_inference::InferenceContext::kUnknownRank;
  // Declared ranks are computed once per Impl instantiation. The vector is
  // heap-allocated and never freed, presumably to avoid an exit-time
  // destructor for this function-local static.
  static const std::vector<int>* const input_ranks = []() {
    const auto input_decls = ImplType::Inputs();
    auto* ranks = new std::vector<int>;
    ranks->reserve(input_decls.size());
    for (const TensorDeclaration& input_decl : input_decls) {
      ranks->push_back(input_decl.shape.has_value()
                           ? static_cast<int>(input_decl.shape->size())
                           : kRankNotDeclared);
    }
    return ranks;
  }();
  ::tensorflow::shape_inference::ShapeHandle input_shape;
  // InferenceContext::input() is indexed by int, so iterate with an int bound
  // instead of comparing a signed index with an unsigned size().
  const int num_ranks = static_cast<int>(input_ranks->size());
  for (int i = 0; i < num_ranks; ++i) {
    const int input_rank_i = (*input_ranks)[i];
    if (input_rank_i == kRankNotDeclared) continue;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), input_rank_i, &input_shape));
  }
  return ::tensorflow::Status::OK();
}

}  // namespace shim
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_SHIM_TF_OP_SHIM_H_
