# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Chain Tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import mock

import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python.bijectors import bijector_test_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util


class ShapeChanging(tfb.Bijector):
  """Stub bijector that exists only to carry custom min_event_ndims.

  It implements no forward/inverse computation of its own. Tests use it to
  build chains whose parts declare arbitrary (possibly nested) forward and
  inverse min_event_ndims, and then check how `Chain` combines them.
  """

  def __init__(self, forward_min_event_ndims=0, inverse_min_event_ndims=3):
    ndims_kwargs = dict(
        forward_min_event_ndims=forward_min_event_ndims,
        inverse_min_event_ndims=inverse_min_event_ndims)
    super(ShapeChanging, self).__init__(
        validate_args=False, name="shape_changer", **ndims_kwargs)


@test_util.test_all_tf_execution_regimes
class ChainBijectorTest(test_util.TestCase):
  """Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation."""

  def testBijector(self):
    """Chain((Exp, Softplus)) computes exp(softplus(x)) = 1 + exp(x)."""
    chain = tfb.Chain((tfb.Exp(), tfb.Softplus()))
    self.assertStartsWith(chain.name, "chain_of_exp_of_softplus")
    x = np.asarray([[[1., 2.],
                     [2., 3.]]])
    self.assertAllClose(1. + np.exp(x), self.evaluate(chain.forward(x)))
    self.assertAllClose(np.log(x - 1.), self.evaluate(chain.inverse(x)))
    # d/dx (1 + exp(x)) = exp(x), so log|J| is x per element; with
    # event_ndims=1 the Jacobian terms are summed over the last axis.
    self.assertAllClose(
        -np.sum(np.log(x - 1.), axis=2),
        self.evaluate(chain.inverse_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        np.sum(x, axis=2),
        self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))

  def testBijectorIdentity(self):
    """An empty Chain behaves as the identity with zero log-det-Jacobian."""
    chain = tfb.Chain()
    self.assertStartsWith(chain.name, "identity")
    x = np.asarray([[[1., 2.],
                     [2., 3.]]])
    self.assertAllClose(x, self.evaluate(chain.forward(x)))
    self.assertAllClose(x, self.evaluate(chain.inverse(x)))
    self.assertAllClose(
        0., self.evaluate(chain.inverse_log_det_jacobian(x, event_ndims=1)))
    self.assertAllClose(
        0., self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))

  def testNestedDtype(self):
    """Integer inputs are handled via the float64 dtype of a nested Scale."""
    chain = tfb.Chain([
        tfb.Identity(),
        tfb.Scale(tf.constant(2., tf.float64)),
        tfb.Identity()
    ])

    # Plain numpy expectation: no need to build (and evaluate) a TF tensor
    # just to hold the expected values.
    self.assertAllClose(np.array([2., 4., 6.], dtype=np.float64),
                        self.evaluate(chain.forward([1, 2, 3])))

  def testScalarCongruency(self):
    """Checks the chain's forward, inverse and log-det-Jacobian agree."""
    chain = tfb.Chain((tfb.Exp(), tfb.Softplus()))
    bijector_test_util.assert_scalar_congruency(
        chain, lower_x=1e-3, upper_x=1.5, rtol=0.05, eval_func=self.evaluate)

  def testShapeGetters(self):
    """Static and dynamic event-shape getters compose across the chain."""
    chain = tfb.Chain([
        tfb.SoftmaxCentered(validate_args=True),
        tfb.SoftmaxCentered(validate_args=True),
    ])
    # Per this test's expectations, the two SoftmaxCentered bijectors
    # together map event shape [1] to [3].
    x = tf.TensorShape([1])
    y = tf.TensorShape([2 + 1])
    self.assertAllEqual(y, chain.forward_event_shape(x))
    self.assertAllEqual(
        tensorshape_util.as_list(y),
        self.evaluate(
            chain.forward_event_shape_tensor(tensorshape_util.as_list(x))))
    self.assertAllEqual(x, chain.inverse_event_shape(y))
    self.assertAllEqual(
        tensorshape_util.as_list(x),
        self.evaluate(
            chain.inverse_event_shape_tensor(tensorshape_util.as_list(y))))

  def testMinEventNdimsChain(self):
    """Without rank changes, the chain needs the largest min_event_ndims."""
    chain = tfb.Chain([tfb.Exp(), tfb.Exp(), tfb.Exp()])
    self.assertEqual(0, chain.forward_min_event_ndims)
    self.assertEqual(0, chain.inverse_min_event_ndims)

    chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]),
                       tfb.ScaleMatvecDiag(scale_diag=[1., 1.]),
                       tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

    chain = tfb.Chain([tfb.Exp(), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

    chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), tfb.Exp()])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

    chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]),
                       tfb.Exp(),
                       tfb.Softplus(),
                       tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

  def testMinEventNdimsShapeChangingAddDims(self):
    """Parts whose inverse needs more event dims than their forward."""
    chain = tfb.Chain([ShapeChanging()])
    self.assertEqual(0, chain.forward_min_event_ndims)
    self.assertEqual(3, chain.inverse_min_event_ndims)

    chain = tfb.Chain([ShapeChanging(),
                       tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
    self.assertEqual(1, chain.forward_min_event_ndims)
    self.assertEqual(4, chain.inverse_min_event_ndims)

    chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]),
                       ShapeChanging()])
    self.assertEqual(0, chain.forward_min_event_ndims)
    self.assertEqual(3, chain.inverse_min_event_ndims)

    chain = tfb.Chain([ShapeChanging(), ShapeChanging()])
    self.assertEqual(0, chain.forward_min_event_ndims)
    self.assertEqual(6, chain.inverse_min_event_ndims)

  def testMinEventNdimsShapeChangingRemoveDims(self):
    """Parts whose forward needs more event dims than their inverse."""
    chain = tfb.Chain([ShapeChanging(3, 0)])
    self.assertEqual(3, chain.forward_min_event_ndims)
    self.assertEqual(0, chain.inverse_min_event_ndims)

    chain = tfb.Chain([ShapeChanging(3, 0),
                       tfb.ScaleMatvecDiag(scale_diag=[1., 1.])])
    self.assertEqual(3, chain.forward_min_event_ndims)
    self.assertEqual(0, chain.inverse_min_event_ndims)

    chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]),
                       ShapeChanging(3, 0)])
    self.assertEqual(4, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

    chain = tfb.Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)])
    self.assertEqual(6, chain.forward_min_event_ndims)
    self.assertEqual(0, chain.inverse_min_event_ndims)

  def testMinEventNdimsShapeChangingAddRemoveDims(self):
    """Mixes parts that add and remove event dims within one chain."""
    chain = tfb.Chain(
        [ShapeChanging(2, 1),
         ShapeChanging(3, 0),
         ShapeChanging(1, 2)])
    self.assertEqual(4, chain.forward_min_event_ndims)
    self.assertEqual(1, chain.inverse_min_event_ndims)

  def testMinEventNdimsWithJointMap(self):
    """Structured (nested) min_event_ndims through JointMap/split/concat."""
    jm_0 = tfb.JointMap([ShapeChanging(1, 1), ShapeChanging(3, 1)])
    split = ShapeChanging(1, [1, 1])
    concat = ShapeChanging([1, 1], 1)
    jm_1 = tfb.JointMap([ShapeChanging(1, 0), ShapeChanging(1, 1)])

    self.assertFalse(jm_0.has_static_min_event_ndims)
    self.assertFalse(jm_1.has_static_min_event_ndims)
    self.assertTrue(split.has_static_min_event_ndims)
    self.assertTrue(concat.has_static_min_event_ndims)

    # Decidable. Although neither JointMap has static min_event_ndims on its
    # own, the statically-known `split` and `concat` between them make the
    # chain's min_event_ndims statically determinable.
    chain = tfb.Chain([jm_0, split, concat, jm_1])
    self.assertTrue(chain.has_static_min_event_ndims)
    self.assertAllEqualNested([4, 3], chain.forward_min_event_ndims)
    self.assertAllEqualNested([3, 1], chain.inverse_min_event_ndims)

    # Undecidable. Neither JointMap has static min_event_ndims, and nothing
    # else in the chain pins them down.
    chain = tfb.Chain([jm_0, jm_1])
    self.assertFalse(chain.has_static_min_event_ndims)
    self.assertAllEqualNested([None, None], chain.forward_min_event_ndims)
    self.assertAllEqualNested([None, None], chain.inverse_min_event_ndims)

  def testChainExpAffine(self):
    """Chain([Exp, ScaleMatvecDiag]) computes exp(scale_diag * x)."""
    scale_diag = np.array([1., 2., 3.], dtype=np.float32)
    chain = tfb.Chain([tfb.Exp(), tfb.ScaleMatvecDiag(scale_diag=scale_diag)])
    x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
    y = [1., 4., 27.]
    self.assertAllClose(y, self.evaluate(chain.forward(x)))
    self.assertAllClose(x, self.evaluate(chain.inverse(y)))
    # log|det diag(1, 2, 3)| = log(6), plus the Exp term at the scaled input.
    self.assertAllClose(
        np.log(6, dtype=np.float32) + np.sum(scale_diag * x),
        self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))

    self.assertAllClose(
        -np.log(6, dtype=np.float32) - np.sum(scale_diag * x),
        self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))

  def testChainAffineExp(self):
    """Chain([ScaleMatvecDiag, Exp]) computes scale_diag * exp(x)."""
    scale_diag = np.array([1., 2., 3.], dtype=np.float32)
    chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=scale_diag), tfb.Exp()])
    x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
    y = [1., 4., 9.]
    self.assertAllClose(y, self.evaluate(chain.forward(x)))
    self.assertAllClose(x, self.evaluate(chain.inverse(y)))
    # log|det diag(1, 2, 3)| = log(6), plus the Exp term at the raw input.
    self.assertAllClose(
        np.log(6, dtype=np.float32) + np.sum(x),
        self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))

    self.assertAllClose(
        -np.log(6, dtype=np.float32) - np.sum(x),
        self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))

  def testChainIldjWithPlaceholder(self):
    """ILDJ works on an input whose static shape is entirely unknown."""
    chain = tfb.Chain((tfb.Exp(), tfb.Exp()))
    samples = tf1.placeholder_with_default(
        np.zeros([2, 10], np.float32), shape=None)
    ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
    self.assertIsNotNone(ildj)
    self.evaluate(ildj)

  def testChainDynamicToStatic(self):
    """Regression test: shapes switching between static and dynamic."""
    if tf.executing_eagerly():
      # This test relies on tensors whose static shape is unknown, which
      # only arises when building graphs. Report the eager variant as
      # skipped instead of letting it silently pass without running.
      self.skipTest("Requires graph mode (statically-unknown shapes).")

    def xform_dynamic(x):
      return tf1.placeholder_with_default(x, shape=None)

    def xform_static(x):
      tensorshape_util.set_shape(x, [1])
      return x

    def ldj(_):
      return tf.constant(1.)

    # The issue was that the sample's shape was going in-and-out of being fully
    # specified, causing internal consistency issues inside the bijector.
    chain = tfb.Chain([
        tfb.Inline(
            inverse_fn=xform_dynamic,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_dynamic),
        tfb.Inline(
            inverse_fn=xform_static,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_static),
        tfb.Inline(
            inverse_fn=xform_dynamic,
            forward_min_event_ndims=0,
            forward_log_det_jacobian_fn=ldj,
            forward_fn=xform_dynamic)
    ])

    ildj = chain.inverse_log_det_jacobian(
        tf.zeros((2, 3), dtype=tf.float32), event_ndims=1)

    # The shape of `ildj` is known statically to be scalar; its value is
    # not statically known.
    self.assertTrue(tensorshape_util.is_fully_defined(ildj.shape))

    # `ldj_reduce_shape` uses `prefer_static` to get input shapes. That means
    # that we respect statically-known shape information where present.
    # In this case, the manually-assigned static shape is incorrect: with the
    # true event size of 3, each of the three unit ldjs would contribute -3
    # (total -9). The bijector that sees the bogus static shape [1] instead
    # contributes only -1, giving -3 - 3 - 1 = -7.
    self.assertEqual(self.evaluate(ildj), -7.)

    # Ditto.
    fldj = chain.forward_log_det_jacobian([0.], event_ndims=0)
    self.assertTrue(tensorshape_util.is_fully_defined(fldj.shape))
    self.assertEqual(self.evaluate(fldj), 3.)

  def testDofChangeError(self):
    """An event-size increase before a later bijector is flagged."""
    exp = tfb.Exp()
    smc = tfb.SoftmaxCentered()

    # Increase in event-size is the last step. No problems here.
    safe_bij = tfb.Chain([smc, exp], validate_args=True)
    self.evaluate(safe_bij.forward_log_det_jacobian([1., 2., 3.], 1))

    # Increase in event-size before Exp.
    raise_bij = tfb.Chain([exp, smc], validate_args=True)
    with self.assertRaisesRegex((ValueError, tf.errors.InvalidArgumentError),
                                r".+degrees of freedom.+"):
      self.evaluate(raise_bij.forward_log_det_jacobian([1., 2., 3.], 1))

    # When validate_args is False, warns instead of raising.
    warn_bij = tfb.Chain([exp, smc], validate_args=False)
    with mock.patch.object(tf, "print", return_value=tf.no_op()) as mock_print:
      self.evaluate(warn_bij.forward_log_det_jacobian([1., 2., 3.], 1))
      # Fail with a clear message, rather than a TypeError from unpacking a
      # `None` call_args, if the warning was never printed.
      self.assertTrue(mock_print.called,
                      "Expected a degrees-of-freedom warning to be printed.")
      print_args, _ = mock_print.call_args
      self.assertRegex(print_args[0], r"WARNING:.+degrees of freedom")

    # When validate_event_size is False, neither warns nor raises.
    ignore_bij = tfb.Chain([exp, smc], validate_event_size=False)
    with mock.patch.object(tf, "print", return_value=tf.no_op()) as mock_print:
      self.evaluate(ignore_bij.forward_log_det_jacobian([1., 2., 3.], 1))
      self.assertFalse(mock_print.called,
                       "No warning expected when validate_event_size=False.")


if __name__ == "__main__":
  # Run the test cases defined in this module with TensorFlow's test runner.
  tf.test.main()
