#   Copyright 2024 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#   MIT License
#
#   Copyright (c) 2021-2022 aesara-devs
#
#   Permission is hereby granted, free of charge, to any person obtaining a copy
#   of this software and associated documentation files (the "Software"), to deal
#   in the Software without restriction, including without limitation the rights
#   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#   copies of the Software, and to permit persons to whom the Software is
#   furnished to do so, subject to the following conditions:
#
#   The above copyright notice and this permission notice shall be included in all
#   copies or substantial portions of the Software.
#
#   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#   SOFTWARE.
import itertools

import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest

from pytensor import Mode
from pytensor.raise_op import assert_op
from pytensor.scan.utils import ScanArgs
from scipy import stats

from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.scan import (
    construct_scan,
    convert_outer_out_to_in,
    get_random_outer_outputs,
)
from pymc.testing import assert_no_rvs


def create_inner_out_logp(value_map):
    """Create log-likelihood inner-outputs, one per entry of `value_map`.

    This is intended to be used with `get_random_outer_outputs`; in this
    module it is passed as the ``inner_out_fn`` argument of
    `convert_outer_out_to_in`.

    Parameters
    ----------
    value_map
        Mapping from the original inner-output random variables to the new
        inner-input (value) variables that replace them.

    Returns
    -------
    list
        One log-probability graph per ``value_map`` item, in iteration order.
        A graph is named ``logp(<value name>)`` when its value variable has a
        name.
    """
    res = []
    for old_inner_out_var, new_inner_in_var in value_map.items():
        # Named ``inner_logp`` so it does not shadow the module-level ``logp`` import.
        inner_logp = _logprob_helper(old_inner_out_var, new_inner_in_var)
        if new_inner_in_var.name:
            inner_logp.name = f"logp({new_inner_in_var.name})"
        res.append(inner_logp)

    return res


def test_convert_outer_out_to_in_sit_sot():
    """Test a single replacement with `convert_outer_out_to_in`.

    This should be a single SIT-SOT replacement.

    A model `Scan` (a lagged "deterministic" ``mu`` plus a normal response
    ``Y``) is converted into a log-probability `Scan` with
    `convert_outer_out_to_in`. Its values are compared against a hand-written
    log-probability `Scan` evaluated at a sample drawn from the model.
    """

    rng_state = np.random.default_rng(123)
    rng_tt = pytensor.shared(rng_state, name="rng", borrow=True)
    rng_tt.tag.is_rng = True
    # NOTE(review): the default update is the shared RNG itself, so compiled
    # functions presumably do not advance it between calls — confirm intent.
    rng_tt.default_update = rng_tt

    #
    # We create a `Scan` representing a time-series model with normally
    # distributed responses that are dependent on lagged values of both the
    # response `RandomVariable` and a lagged "deterministic" that also depends
    # on the lagged response values.
    #
    def input_step_fn(mu_tm1, y_tm1, rng):
        mu_tm1.name = "mu_tm1"
        y_tm1.name = "y_tm1"
        mu = mu_tm1 + y_tm1 + 1
        mu.name = "mu_t"
        return mu, pt.random.normal(mu, 1.0, rng=rng, name="Y_t")

    (mu_tt, Y_rv), _ = pytensor.scan(
        fn=input_step_fn,
        outputs_info=[
            {
                "initial": pt.as_tensor_variable(0.0, dtype=pytensor.config.floatX),
                "taps": [-1],
            },
            {
                "initial": pt.as_tensor_variable(0.0, dtype=pytensor.config.floatX),
                "taps": [-1],
            },
        ],
        non_sequences=[rng_tt],
        n_steps=10,
    )

    mu_tt.name = "mu_tt"
    mu_tt.owner.inputs[0].name = "mu_all"
    Y_rv.name = "Y_rv"
    # Assumed to be the full `Scan` output that `Y_rv` is taken from
    # (TODO confirm). Its owner is the model `Scan` node used below.
    Y_all = Y_rv.owner.inputs[0]
    Y_all.name = "Y_all"

    # Argument structure of the model `Scan` node.
    input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)

    # TODO FIXME: Everything below needs to be replaced with explicit asserts
    # on the values in `input_scan_args`

    #
    # Sample from the model and create another `Scan` that computes the
    # log-likelihood of the model at the sampled point.
    #
    Y_obs = pt.as_tensor_variable(Y_rv.eval())
    Y_obs.name = "Y_obs"

    # Mirrors `input_step_fn`, but scores the observed ``y_t`` instead of
    # sampling it. The sequence taps [0, -1] supply ``y_t`` and ``y_tm1``.
    # The local ``logp`` shadows the module-level import only inside this function.
    def output_step_fn(y_t, y_tm1, mu_tm1):
        mu_tm1.name = "mu_tm1"
        y_tm1.name = "y_tm1"
        mu = mu_tm1 + y_tm1 + 1
        mu.name = "mu_t"
        logp = _logprob_helper(pt.random.normal(mu, 1.0), y_t)
        logp.name = "logp"
        return mu, logp

    (mu_tt, Y_logp), _ = pytensor.scan(
        fn=output_step_fn,
        sequences=[{"input": Y_obs, "taps": [0, -1]}],
        outputs_info=[
            {
                "initial": pt.as_tensor_variable(0.0, dtype=pytensor.config.floatX),
                "taps": [-1],
            },
            {},
        ],
    )

    Y_logp.name = "Y_logp"
    mu_tt.name = "mu_tt"

    #
    # Get the model output variable that corresponds to the response
    # `RandomVariable`
    #
    # Only the outer-output index and variable are used; ``io_var`` is unused.
    oo_idx, oo_var, io_var = get_random_outer_outputs(input_scan_args)[0]

    #
    # Convert the original model `Scan` into another `Scan` that's equivalent
    # to the log-likelihood `Scan` given above.
    # In other words, automatically construct the log-likelihood `Scan` based
    # on the model `Scan`.
    #
    value_map = {Y_all: Y_obs}
    test_scan_args = convert_outer_out_to_in(
        input_scan_args,
        [oo_var],
        value_map,
        inner_out_fn=create_inner_out_logp,
    )

    # ``updates`` is not needed for this comparison.
    scan_out, updates = construct_scan(test_scan_args)

    #
    # Evaluate the manually and automatically constructed log-likelihoods and
    # compare.
    #
    res = scan_out[oo_idx].eval()
    exp_res = Y_logp.eval()

    # Exact (not approximate) equality is required between the two graphs.
    assert np.array_equal(res, exp_res)


def test_convert_outer_out_to_in_mit_sot():
    """Test a single replacement with `convert_outer_out_to_in`.

    This should be a single MIT-SOT replacement.

    The model `Scan` has a single normal output that depends on its two
    previous values (taps -1 and -2). The automatically converted
    log-probability `Scan` is compared against a hand-written one.
    """

    rng_state = np.random.default_rng(1234)
    rng_tt = pytensor.shared(rng_state, name="rng", borrow=True)
    rng_tt.tag.is_rng = True
    # NOTE(review): the default update is the shared RNG itself, so compiled
    # functions presumably do not advance it between calls — confirm intent.
    rng_tt.default_update = rng_tt

    #
    # This is a very simple model with only one output, but multiple
    # taps/lags.
    #
    def input_step_fn(y_tm1, y_tm2, rng):
        y_tm1.name = "y_tm1"
        y_tm2.name = "y_tm2"
        return pt.random.normal(y_tm1 + y_tm2, 1.0, rng=rng, name="Y_t")

    Y_rv, _ = pytensor.scan(
        fn=input_step_fn,
        outputs_info=[
            {"initial": pt.as_tensor_variable(np.r_[-1.0, 0.0]), "taps": [-1, -2]},
        ],
        non_sequences=[rng_tt],
        n_steps=10,
    )

    Y_rv.name = "Y_rv"
    # Assumed to be the full `Scan` output that `Y_rv` is taken from
    # (TODO confirm). Its owner is the model `Scan` node used below.
    Y_all = Y_rv.owner.inputs[0]
    Y_all.name = "Y_all"

    # A sample from the model, used as the observed values.
    Y_obs = pt.as_tensor_variable(Y_rv.eval())
    Y_obs.name = "Y_obs"

    # Argument structure of the model `Scan` node.
    input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)

    # TODO FIXME: Everything below needs to be replaced with explicit asserts
    # on the values in `input_scan_args`

    #
    # The corresponding log-likelihood
    #
    # Sequence taps [0, -1, -2] supply ``y_t``, ``y_tm1`` and ``y_tm2``.
    # The local ``logp`` shadows the module-level import only inside this function.
    def output_step_fn(y_t, y_tm1, y_tm2):
        y_t.name = "y_t"
        y_tm1.name = "y_tm1"
        y_tm2.name = "y_tm2"
        logp = _logprob_helper(pt.random.normal(y_tm1 + y_tm2, 1.0), y_t)
        logp.name = "logp(y_t)"
        return logp

    Y_logp, _ = pytensor.scan(
        fn=output_step_fn,
        sequences=[{"input": Y_obs, "taps": [0, -1, -2]}],
        outputs_info=[{}],
    )

    #
    # Get the model output variable that corresponds to the response
    # `RandomVariable`
    #
    # Only the outer-output index and variable are used; ``io_var`` is unused.
    oo_idx, oo_var, io_var = get_random_outer_outputs(input_scan_args)[0]

    #
    # Convert the original model `Scan` into another `Scan` that's equivalent
    # to the log-likelihood `Scan` given above.
    # In other words, automatically construct the log-likelihood `Scan` based
    # on the model `Scan`.

    value_map = {Y_all: Y_obs}
    test_scan_args = convert_outer_out_to_in(
        input_scan_args,
        [oo_var],
        value_map,
        inner_out_fn=create_inner_out_logp,
    )

    # ``updates`` is not needed for this comparison.
    scan_out, updates = construct_scan(test_scan_args)

    #
    # Evaluate the manually and automatically constructed log-likelihoods and
    # compare.
    #
    res = scan_out[oo_idx].eval()
    exp_res = Y_logp.eval()

    # Exact (not approximate) equality is required between the two graphs.
    assert np.array_equal(res, exp_res)


@pytest.mark.parametrize(
    "require_inner_rewrites",
    [
        False,
        True,
    ],
)
def test_scan_joint_logprob(require_inner_rewrites):
    """Compare the joint logp of a mixture-like `Scan` against a hand-built reference.

    Each step draws a categorical state ``S_t`` with probabilities ``Gamma[0]``,
    where ``Gamma`` is Dirichlet distributed. It then draws a normal response
    ``Y_t`` with mean ``mus_t[S_t]``. When ``require_inner_rewrites`` is true,
    the state selection is applied by indexing the normal draw rather than its
    mean. Per the parameter name, this form presumably needs inner-graph
    rewrites before its logp can be derived.
    """
    srng = pt.random.RandomStream()

    N_tt = pt.iscalar("N")
    N_val = 10
    N_tt.tag.test_value = N_val

    M_tt = pt.iscalar("M")
    M_val = 2
    M_tt.tag.test_value = M_val

    mus_tt = pt.matrix("mus_t")

    mus_val = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
        pytensor.config.floatX
    )
    mus_tt.tag.test_value = mus_val

    sigmas_tt = pt.ones((N_tt,))
    Gamma_rv = srng.dirichlet(pt.ones((M_tt, M_tt)), name="Gamma")

    Gamma_vv = Gamma_rv.clone()
    Gamma_vv.name = "Gamma_vv"

    Gamma_val = np.array([[0.5, 0.5], [0.5, 0.5]])
    Gamma_rv.tag.test_value = Gamma_val

    # Model step: sample the state, then the response conditioned on it.
    def scan_fn(mus_t, sigma_t, Gamma_t):
        S_t = srng.categorical(Gamma_t[0], name="S_t")

        if require_inner_rewrites:
            Y_t = srng.normal(mus_t, sigma_t, name="Y_t")[S_t]
        else:
            Y_t = srng.normal(mus_t[S_t], sigma_t, name="Y_t")

        return Y_t, S_t

    (Y_rv, S_rv), _ = pytensor.scan(
        fn=scan_fn,
        sequences=[mus_tt, sigmas_tt],
        non_sequences=[Gamma_rv],
        outputs_info=[{}, {}],
        strict=True,
        name="scan_rv",
    )
    Y_rv.name = "Y"
    S_rv.name = "S"

    y_vv = Y_rv.clone()
    y_vv.name = "y"

    s_vv = S_rv.clone()
    s_vv.name = "s"

    y_logp = conditional_logp({Y_rv: y_vv, S_rv: s_vv, Gamma_rv: Gamma_vv})
    y_logp_combined = pt.sum([pt.sum(factor) for factor in y_logp.values()])

    y_val = np.arange(10)
    s_val = np.array([0, 1, 0, 1, 1, 0, 0, 0, 1, 1])

    test_point = {
        y_vv: y_val,
        s_vv: s_val,
        M_tt: M_val,
        N_tt: N_val,
        mus_tt: mus_val,
        Gamma_vv: Gamma_val,
    }

    y_logp_fn = pytensor.function(list(test_point.keys()), y_logp_combined)

    # The compiled joint logp must not contain any random variables.
    assert_no_rvs(y_logp_fn.maker.fgraph.outputs[0])

    # Construct the joint log-probability by hand so we can compare it with
    # `y_logp`. The step function has its own name so it does not shadow the
    # model's ``scan_fn`` defined above.
    def ref_scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t):
        S_t = pt.random.categorical(Gamma_t[0], name="S_t")
        Y_t = pt.random.normal(mus_t[S_t_val], sigma_t, name="Y_t")
        Y_t_logp, S_t_logp = _logprob_helper(Y_t, Y_t_val), _logprob_helper(S_t, S_t_val)
        Y_t_logp.name = "log(Y_t=y_t)"
        S_t_logp.name = "log(S_t=s_t)"
        return Y_t_logp, S_t_logp

    (Y_rv_logp, S_rv_logp), _ = pytensor.scan(
        fn=ref_scan_fn,
        sequences=[mus_tt, sigmas_tt, y_vv, s_vv],
        non_sequences=[Gamma_vv],
        outputs_info=[{}, {}],
        strict=True,
        name="scan_rv",
    )
    Y_rv_logp.name = "logp(Y=y)"
    S_rv_logp.name = "logp(S=s)"

    Gamma_logp = _logprob_helper(Gamma_rv, Gamma_vv)

    y_logp_ref = Y_rv_logp.sum() + S_rv_logp.sum() + Gamma_logp.sum()

    assert_no_rvs(y_logp_ref)

    y_logp_val = y_logp_combined.eval(test_point)

    y_logp_ref_val = y_logp_ref.eval(test_point)

    assert np.allclose(y_logp_val, y_logp_ref_val)


@pytest.mark.parametrize("remove_asserts", (True, False))
def test_mode_is_kept(remove_asserts):
    """Check that the `mode` given to `pytensor.scan` carries over to the derived logp.

    The scan step wraps its previous value in ``assert_op(x, x > 0)``. When the
    mode includes ``local_remove_all_assert``, the compiled logp must evaluate
    for negative values. Otherwise the assertion must raise.
    """
    mode = Mode().including("local_remove_all_assert") if remove_asserts else None
    x, _ = pytensor.scan(
        fn=lambda x: pt.random.normal(assert_op(x, x > 0)),
        outputs_info=[pt.ones(())],
        n_steps=10,
        mode=mode,
    )
    x.name = "x"
    x_vv = x.clone()
    x_logp = pytensor.function([x_vv], pt.sum(logp(x, x_vv)))

    # All values are negative, so they violate the ``x > 0`` assertion in the step.
    x_test_val = np.full((10,), -1)
    if remove_asserts:
        # A plain truthiness check would also accept NaN or -inf; require a finite logp.
        assert np.isfinite(x_logp(x=x_test_val))
    else:
        with pytest.raises(AssertionError):
            x_logp(x=x_test_val)


def test_scan_non_pure_rv_output():
    """Derive the logp of a scan whose output is a RV shifted by the previous state.

    This is a Gaussian random walk. Observing the values 1, 2, ..., 10 from a
    start of 0 means every increment equals 1.
    """
    n_steps = 10
    walk, _ = pytensor.scan(
        fn=lambda prev: pt.random.normal() + prev,
        outputs_info=[pt.zeros(())],
        n_steps=n_steps,
        name="grw",
    )

    walk_value = walk.clone()
    walk_logp = logp(walk, walk_value)
    assert_no_rvs(walk_logp)

    walk_value_test = 1 + np.arange(n_steps)
    expected = stats.norm.logpdf(np.ones(n_steps))
    actual = walk_logp.eval({walk_value: walk_value_test})
    np.testing.assert_array_almost_equal(actual, expected)


def test_scan_over_seqs():
    """Test logprob inference for scans that map over sequences.

    ``ys[t] ~ Normal(xs[t], 1)``, so the logp of ``ys`` given ``xs`` is the
    elementwise normal logpdf of ``ys`` centred at ``xs``.
    """
    rng = np.random.default_rng(543)
    n_steps = 10

    xs = pt.random.normal(size=(n_steps,), name="xs")
    ys, _ = pytensor.scan(
        fn=lambda x: pt.random.normal(x), sequences=[xs], outputs_info=[None], name="ys"
    )

    # Each value variable is a clone of its own random variable. The original
    # code mistakenly cloned ``ys`` for ``xs``.
    xs_vv = xs.clone()
    ys_vv = ys.clone()
    ys_logp = conditional_logp({xs: xs_vv, ys: ys_vv})[ys_vv]

    assert_no_rvs(ys_logp)

    xs_test = rng.normal(size=(n_steps,))
    ys_test = rng.normal(size=(n_steps,))
    np.testing.assert_array_almost_equal(
        ys_logp.eval({xs_vv: xs_test, ys_vv: ys_test}),
        stats.norm.logpdf(ys_test, xs_test),
    )


def test_scan_carried_deterministic_state():
    """Test logp of scans with carried states downstream of measured variables.

    A moving average model with 2 lags is used for testing.

    The recurrent state ``eps`` is computed from the measured output ``y``.
    The derived logp of ``y`` is checked against a NumPy reference that replays
    the same ``eps`` recursion on the observed values.
    """
    rng = np.random.default_rng(490)
    steps = 99

    rho = pt.vector("rho", shape=(2,))
    sigma = pt.scalar("sigma")

    # Taps [-2, -1] of the first output arrive as ``eps_tm2, eps_tm1``.
    def ma2_step(eps_tm2, eps_tm1, rho, sigma):
        mu = eps_tm1 * rho[0] + eps_tm2 * rho[1]
        y = pt.random.normal(mu, sigma)
        eps = y - mu
        # Presumably maps the normal RV node's input RNG to its updated RNG
        # output, so the RNG advances across steps — confirm.
        update = {y.owner.inputs[0]: y.owner.outputs[0]}
        return (eps, y), update

    # Only the second output (the measured ``y`` sequence) is kept; the
    # returned updates are unused.
    [_, ma2], ma2_updates = pytensor.scan(
        fn=ma2_step,
        outputs_info=[{"initial": pt.arange(2, dtype="float64"), "taps": range(-2, 0)}, None],
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
        name="ma2",
    )

    # NumPy reference: per-step normal logpdf using the same MA(2) recursion.
    # The starting lags mirror the initial taps ``pt.arange(2)`` = [0, 1].
    def ref_logp(values, rho, sigma):
        epsilon_tm2 = 0
        epsilon_tm1 = 1
        step_logps = np.zeros_like(values)
        for t, value in enumerate(values):
            mu = epsilon_tm1 * rho[0] + epsilon_tm2 * rho[1]
            step_logps[t] = stats.norm.logpdf(value, mu, sigma)
            epsilon_tm2 = epsilon_tm1
            epsilon_tm1 = value - mu
        return step_logps

    ma2_vv = ma2.clone()
    logp_expr = logp(ma2, ma2_vv)
    assert_no_rvs(logp_expr)

    ma2_test = rng.normal(size=(steps,))
    rho_test = np.array([0.3, 0.7])
    sigma_test = 0.9

    np.testing.assert_array_almost_equal(
        logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}),
        ref_logp(ma2_test, rho_test, sigma_test),
    )


def test_scan_multiple_output_types():
    """Test we can derive the logp for a scan that contains recurring and non-recurring measurable outputs."""
    # Step arguments:
    # - ``x_mu`` comes from the sequence.
    # - ``y_tm1`` is the previous value of ``ys``.
    # - ``z_tm2, z_tm1`` come from taps [-2, -1] of ``zs``.
    # ``xs`` is non-recurring (outputs_info None).
    [xs, ys, zs], _ = pytensor.scan(
        fn=lambda x_mu, y_tm1, z_tm2, z_tm1: (
            pt.random.normal(x_mu),
            pt.random.normal(y_tm1),
            pt.random.normal(z_tm1) + z_tm2,
        ),
        sequences=[pt.arange(10)],
        outputs_info=[
            None,
            pt.zeros(()),
            {"initial": pt.ones(2), "taps": [-2, -1]},
        ],
    )

    xs.name = "xs"
    xs_value = xs.clone()
    ys.name = "ys"
    ys_value = ys.clone()
    zs.name = "zs"
    zs_value = zs.clone()

    logp_dict = conditional_logp({xs: xs_value, ys: ys_value, zs: zs_value})
    xs_logp = logp_dict[xs_value]
    ys_logp = logp_dict[ys_value]
    zs_logp = logp_dict[zs_value]

    assert_no_rvs([xs_logp, ys_logp, zs_logp])
    fn = pytensor.function(
        [xs_value, ys_value, zs_value],
        [xs_logp, ys_logp, zs_logp],
    )

    rng = np.random.default_rng(577)
    # The same observed values are used for all three outputs.
    test_value = rng.uniform(size=(10,))
    (xs_logp_eval, ys_logp_eval, zs_logp_eval) = fn(test_value, test_value, test_value)
    # xs: mean is the sequence value 0..9.
    np.testing.assert_allclose(xs_logp_eval, stats.norm.logpdf(test_value, np.arange(10)))
    # ys: mean is the previous value, starting from the initial 0.
    np.testing.assert_allclose(ys_logp_eval, stats.norm.logpdf(test_value, [0, *test_value[:-1]]))
    # zs: mean is z_tm2 + z_tm1. The two initial taps are both 1, so the
    # pairwise sums over [1, 1, *previous values] give each step's mean.
    np.testing.assert_allclose(
        zs_logp_eval,
        stats.norm.logpdf(
            test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])]
        ),
    )
