import unittest
import numpy as np
from spreg.sur_utils import sur_dictxy, sur_dictZ
from spreg.sur_lag import SURlagIV
import libpysal
import geopandas as gpd
from libpysal.common import RTOL

def dict_compare(actual, desired, rtol, atol=1e-7):
    """Assert that two dicts of array-likes agree key-for-key within tolerance.

    Parameters
    ----------
    actual : dict
        Computed values; in this module the keys are equation indices.
    desired : dict
        Expected values stored under the same keys as ``actual``.
    rtol : float
        Relative tolerance forwarded to ``numpy.testing.assert_allclose``.
    atol : float, optional
        Absolute tolerance forwarded to ``numpy.testing.assert_allclose``.

    Raises
    ------
    AssertionError
        If the two dicts do not have the same keys, or if any pair of
        values differs by more than the given tolerances.
    """
    # Compare the key sets up front: looping over ``actual`` alone would let
    # an empty or truncated result pass without a single value being checked.
    if set(actual) != set(desired):
        raise AssertionError(
            "dict keys differ: actual=%r, desired=%r"
            % (sorted(actual, key=repr), sorted(desired, key=repr))
        )
    for key, value in actual.items():
        np.testing.assert_allclose(value, desired[key], rtol=rtol, atol=atol)

class Test_SURlagIV(unittest.TestCase):
    # Regression tests for SURlagIV: each test fits a model on the NAT
    # example data and compares the fitted attributes with stored reference
    # values using the shared RTOL tolerance.

    def setUp(self):
        # Load the NCOVR example, read NAT.shp into a GeoDataFrame and build
        # Queen contiguity weights from it.
        nat = libpysal.examples.load_example('NCOVR')
        self.db = gpd.read_file(nat.get_path("NAT.shp"))
        self.w = libpysal.weights.Queen.from_dataframe(self.db)
        # 'r' selects the row-standardised transform of the weights.
        self.w.transform = 'r'

    def test_3SLS(self):  # 2 equations, same K in each
        y_var0 = ["HR80", "HR90"]
        x_var0 = [["PS80", "UE80"], ["PS90", "UE90"]]
        bigy0, bigX0, bigyvars0, bigXvars0 = sur_dictxy(self.db, y_var0, x_var0)
        reg = SURlagIV(
            bigy0,
            bigX0,
            w=self.w,
            name_bigy=bigyvars0,
            name_bigX=bigXvars0,
            name_ds="NAT",
            name_w="nat_queen",
        )

        # Coefficient vectors, one (k, 1) column per equation key.
        dict_compare(
            reg.b3SLS,
            {
                0: np.array([[4.79766641], [0.66900706], [0.45430715], [-0.13665465]]),
                1: np.array([[2.27972563], [0.99252289], [0.52280565], [0.06909469]]),
            },
            RTOL,
        )
        # Inference table, one row per coefficient
        # (presumably std. error, test statistic, p-value -- TODO confirm).
        dict_compare(
            reg.tsls_inf,
            {
                0: np.array(
                    [
                        [4.55824001e00, 1.05252606e00, 2.92558259e-01],
                        [3.54744447e-01, 1.88588453e00, 5.93105171e-02],
                        [7.79071951e-02, 5.83138887e00, 5.49679157e-09],
                        [6.74318852e-01, -2.02655838e-01, 8.39404043e-01],
                    ]
                ),
                1: np.array(
                    [
                        [3.90351092e-01, 5.84019280e00, 5.21404469e-09],
                        [1.21674079e-01, 8.15722547e00, 3.42808098e-16],
                        [4.47686969e-02, 1.16779288e01, 1.65273681e-31],
                        [7.99640809e-02, 8.64071585e-01, 3.87548567e-01],
                    ]
                ),
            },
            RTOL,
        )
        np.testing.assert_allclose(
            reg.corr, np.array([[1.0, 0.525751], [0.525751, 1.0]]), rtol=RTOL
        )
        # One 3-tuple per coefficient
        # (presumably statistic, degrees of freedom, p-value -- TODO confirm).
        np.testing.assert_allclose(
            reg.surchow,
            [
                (0.3178787640240518, 1, 0.57288522734425285),
                (1.0261877219299562, 1, 0.31105574708021311),
                (0.76852435750330428, 1, 0.38067394159083323),
                (0.099802260814129934, 1, 0.75206705793155604),
            ],
            rtol=RTOL,
        )

    def test_3SLS_3eq(self):  # Three equations, no endogenous
        y_var1 = ["HR60", "HR70", "HR80"]
        x_var1 = [["RD60", "PS60"], ["RD70", "PS70", "UE70"], ["RD80", "PS80"]]
        bigy1, bigX1, bigyvars1, bigXvars1 = sur_dictxy(self.db, y_var1, x_var1)

        # First fit: w_lags left at the SURlagIV default.
        reg = SURlagIV(
            bigy1,
            bigX1,
            w=self.w,
            name_bigy=bigyvars1,
            name_bigX=bigXvars1,
            name_ds="NAT",
            name_w="nat_queen",
        )

        dict_compare(
            reg.b2SLS,
            {
                0: np.array([[2.42754085], [1.48928052], [0.33812558], [0.45567848]]),
                1: np.array(
                    [
                        [4.83887747],
                        [2.86272903],
                        [0.96950417],
                        [-0.12928124],
                        [0.33328525],
                    ]
                ),
                2: np.array([[6.69407561], [3.81449588], [1.44603996], [0.03355501]]),
            },
            RTOL,
        )
        dict_compare(
            reg.b3SLS,
            {
                0: np.array([[2.1646724], [1.31916307], [0.3398716], [0.51336281]]),
                1: np.array(
                    [
                        [4.87587006],
                        [2.68927603],
                        [0.94945336],
                        [-0.145607],
                        [0.33901794],
                    ]
                ),
                2: np.array([[6.48848271], [3.53936913], [1.34731149], [0.06309451]]),
            },
            RTOL,
        )
        dict_compare(
            reg.tsls_inf,
            {
                0: np.array(
                    [
                        [3.51568531e-01, 6.15718476e00, 7.40494437e-10],
                        [1.86875349e-01, 7.05905340e00, 1.67640650e-12],
                        [9.04557549e-02, 3.75732426e00, 1.71739894e-04],
                        [7.48661202e-02, 6.85707782e00, 7.02833502e-12],
                    ]
                ),
                1: np.array(
                    [
                        [4.72342840e-01, 1.03227352e01, 5.56158073e-25],
                        [2.12539934e-01, 1.26530388e01, 1.07629786e-36],
                        [1.21325632e-01, 7.82566179e00, 5.04993280e-15],
                        [4.61662438e-02, -3.15397123e00, 1.61064963e-03],
                        [5.41804741e-02, 6.25719766e00, 3.91956530e-10],
                    ]
                ),
                2: np.array(
                    [
                        [3.36526688e-001, 1.92807374e001, 7.79572152e-083],
                        [1.59012676e-001, 2.22584087e001, 9.35079396e-110],
                        [1.08370073e-001, 1.24325052e001, 1.74091603e-035],
                        [4.61776859e-002, 1.36634202e000, 1.71831639e-001],
                    ]
                ),
            },
            RTOL,
        )

        # Second fit on the same data, now with w_lags=2.
        reg = SURlagIV(
            bigy1,
            bigX1,
            w=self.w,
            w_lags=2,
            name_bigy=bigyvars1,
            name_bigX=bigXvars1,
            name_ds="NAT",
            name_w="nat_queen",
        )

        dict_compare(
            reg.b3SLS,
            {
                0: np.array([[1.77468937], [1.14510457], [0.30768813], [0.5989414]]),
                1: np.array(
                    [
                        [4.26823484],
                        [2.43651351],
                        [0.8683601],
                        [-0.12672555],
                        [0.4208373],
                    ]
                ),
                2: np.array([[6.02334209], [3.38056146], [1.30003556], [0.12992573]]),
            },
            RTOL,
        )
        dict_compare(
            reg.tsls_inf,
            {
                0: np.array(
                    [
                        [3.27608281e-01, 5.41710779e00, 6.05708284e-08],
                        [1.76245578e-01, 6.49721025e00, 8.18230736e-11],
                        [8.95068772e-02, 3.43759205e00, 5.86911195e-04],
                        [6.94610221e-02, 8.62269771e00, 6.53949186e-18],
                    ]
                ),
                1: np.array(
                    [
                        [4.52225005e-01, 9.43829906e00, 3.78879655e-21],
                        [2.03807701e-01, 1.19549629e01, 6.11608551e-33],
                        [1.19004906e-01, 7.29684281e00, 2.94598624e-13],
                        [4.57552474e-02, -2.76963964e00, 5.61183429e-03],
                        [5.13101239e-02, 8.20183745e00, 2.36740266e-16],
                    ]
                ),
                2: np.array(
                    [
                        [3.27580342e-001, 1.83873735e001, 1.65820984e-075],
                        [1.55771577e-001, 2.17020429e001, 1.96247435e-104],
                        [1.06817752e-001, 1.21705946e001, 4.45822889e-034],
                        [4.48871540e-002, 2.89449691e000, 3.79766647e-003],
                    ]
                ),
            },
            RTOL,
        )

    def test_3SLS_3eq_end(self):  # Three equations, two endogenous, three instruments
        y_var2 = ["HR60", "HR70", "HR80"]
        x_var2 = [["RD60", "PS60"], ["RD70", "PS70", "MA70"], ["RD80", "PS80"]]
        yend_var2 = [["UE60", "DV60"], ["UE70", "DV70"], ["UE80", "DV80"]]
        q_var2 = [
            ["FH60", "FP59", "GI59"],
            ["FH70", "FP69", "GI69"],
            ["FH80", "FP79", "GI79"],
        ]
        bigy2, bigX2, bigyvars2, bigXvars2 = sur_dictxy(self.db, y_var2, x_var2)
        bigyend2, bigyendvars2 = sur_dictZ(self.db, yend_var2)
        bigq2, bigqvars2 = sur_dictZ(self.db, q_var2)
        reg = SURlagIV(
            bigy2,
            bigX2,
            bigyend2,
            bigq2,
            w=self.w,
            name_bigy=bigyvars2,
            name_bigX=bigXvars2,
            name_bigyend=bigyendvars2,
            name_bigq=bigqvars2,
            spat_diag=True,
            name_ds="NAT",
            name_w="nat_queen",
        )

        dict_compare(
            reg.b2SLS,
            {
                0: np.array(
                    [
                        [-2.36265226],
                        [1.69785946],
                        [0.65777251],
                        [-0.07519173],
                        [2.15755822],
                        [0.69200015],
                    ]
                ),
                1: np.array(
                    [
                        [8.13716008],
                        [3.28583832],
                        [0.90311859],
                        [-0.21702098],
                        [-1.04365606],
                        [2.8597322],
                        [0.39935589],
                    ]
                ),
                2: np.array(
                    [
                        [-5.8117312],
                        [3.49934818],
                        [0.56523782],
                        [0.09653315],
                        [2.31166815],
                        [0.20602185],
                    ]
                ),
            },
            RTOL,
        )
        dict_compare(
            reg.b3SLS,
            {
                0: np.array(
                    [
                        [-2.33115839],
                        [1.43097732],
                        [0.57312948],
                        [0.03474891],
                        [1.78825098],
                        [0.7145636],
                    ]
                ),
                1: np.array(
                    [
                        [8.34932294],
                        [3.28396774],
                        [0.95119978],
                        [-0.19323687],
                        [-1.1750583],
                        [2.75925141],
                        [0.38544424],
                    ]
                ),
                2: np.array(
                    [
                        [-5.2395274],
                        [3.38941755],
                        [0.55897901],
                        [0.08212108],
                        [2.19387428],
                        [0.21582944],
                    ]
                ),
            },
            RTOL,
        )
        dict_compare(
            reg.tsls_inf,
            {
                0: np.array(
                    [
                        [7.31246733e-01, -3.18792315e00, 1.43298614e-03],
                        [2.07089585e-01, 6.90994348e00, 4.84846854e-12],
                        [1.15296751e-01, 4.97090750e00, 6.66402399e-07],
                        [8.75272616e-02, 3.97006755e-01, 6.91362479e-01],
                        [3.10638495e-01, 5.75669472e00, 8.57768262e-09],
                        [5.40333500e-02, 1.32244919e01, 6.33639937e-40],
                    ]
                ),
                1: np.array(
                    [
                        [1.71703190e00, 4.86264870e00, 1.15825305e-06],
                        [2.79253520e-01, 1.17598079e01, 6.28772226e-32],
                        [1.27575632e-01, 7.45596763e00, 8.92106480e-14],
                        [3.31742265e-02, -5.82490950e00, 5.71435564e-09],
                        [2.19785746e-01, -5.34638083e00, 8.97303096e-08],
                        [3.29882178e-01, 8.36435430e00, 6.04450321e-17],
                        [5.54968909e-02, 6.94533032e00, 3.77575814e-12],
                    ]
                ),
                2: np.array(
                    [
                        [9.77398092e-01, -5.36068920e00, 8.29050465e-08],
                        [1.67632600e-01, 2.02193222e01, 6.61862485e-91],
                        [1.24321379e-01, 4.49624202e00, 6.91650078e-06],
                        [6.94834624e-02, 1.18187957e00, 2.37253491e-01],
                        [1.68013780e-01, 1.30577045e01, 5.74336064e-39],
                        [4.16751208e-02, 5.17885587e00, 2.23250870e-07],
                    ]
                ),
            },
            RTOL,
        )
        # Pass the shared RTOL like every other check in this class instead of
        # silently falling back to numpy's default rtol (1e-7).
        np.testing.assert_allclose(
            reg.joinrho, (215.897034, 3, 1.54744730e-46), rtol=RTOL
        )

    def test_3SLS_3eq_2or(self):  # Second order spatial lags, no instrument lags
        y_var2 = ["HR60", "HR70", "HR80"]
        x_var2 = [["RD60", "PS60"], ["RD70", "PS70", "MA70"], ["RD80", "PS80"]]
        yend_var2 = [["UE60", "DV60"], ["UE70", "DV70"], ["UE80", "DV80"]]
        q_var2 = [
            ["FH60", "FP59", "GI59"],
            ["FH70", "FP69", "GI69"],
            ["FH80", "FP79", "GI79"],
        ]

        bigy2, bigX2, bigyvars2, bigXvars2 = sur_dictxy(self.db, y_var2, x_var2)
        bigyend2, bigyendvars2 = sur_dictZ(self.db, yend_var2)
        bigq2, bigqvars2 = sur_dictZ(self.db, q_var2)
        # w_lags=2 with lag_q=False matches the scenario named on the def line.
        reg = SURlagIV(
            bigy2,
            bigX2,
            bigyend2,
            bigq2,
            w=self.w,
            w_lags=2,
            lag_q=False,
            name_bigy=bigyvars2,
            name_bigX=bigXvars2,
            name_bigyend=bigyendvars2,
            name_bigq=bigqvars2,
            name_ds="NAT",
            name_w="nat_queen",
        )

        dict_compare(
            reg.b3SLS,
            {
                0: np.array(
                    [
                        [-2.40071969],
                        [1.2933015],
                        [0.53165876],
                        [0.04883189],
                        [1.6663233],
                        [0.76473297],
                    ]
                ),
                1: np.array(
                    [
                        [7.24987963],
                        [2.96110365],
                        [0.86322179],
                        [-0.17847268],
                        [-1.1332928],
                        [2.69573919],
                        [0.48295237],
                    ]
                ),
                2: np.array(
                    [
                        [-7.55692635],
                        [3.17561152],
                        [0.37487877],
                        [0.1816544],
                        [2.45768258],
                        [0.27716717],
                    ]
                ),
            },
            RTOL,
        )
        dict_compare(
            reg.tsls_inf,
            {
                0: np.array(
                    [
                        [7.28635609e-01, -3.29481522e00, 9.84864177e-04],
                        [2.44756930e-01, 5.28402406e00, 1.26376643e-07],
                        [1.26021571e-01, 4.21879172e00, 2.45615028e-05],
                        [1.03323393e-01, 4.72612122e-01, 6.36489932e-01],
                        [3.48694501e-01, 4.77874843e00, 1.76389726e-06],
                        [6.10435763e-02, 1.25276568e01, 5.26966810e-36],
                    ]
                ),
                1: np.array(
                    [
                        [1.76286536e00, 4.11255436e00, 3.91305295e-05],
                        [2.78649343e-01, 1.06266306e01, 2.24061686e-26],
                        [1.28607242e-01, 6.71207766e00, 1.91872523e-11],
                        [3.21721548e-02, -5.54742685e00, 2.89904383e-08],
                        [2.09773378e-01, -5.40246249e00, 6.57322045e-08],
                        [3.06806758e-01, 8.78644007e00, 1.54373978e-18],
                        [5.88231798e-02, 8.21023915e00, 2.20748374e-16],
                    ]
                ),
                2: np.array(
                    [
                        [1.10429601e00, -6.84320712e00, 7.74395589e-12],
                        [1.81002635e-01, 1.75445597e01, 6.54581911e-69],
                        [1.33983129e-01, 2.79795505e00, 5.14272697e-03],
                        [7.56814009e-02, 2.40025154e00, 1.63838090e-02],
                        [1.83365858e-01, 1.34031635e01, 5.79398038e-41],
                        [4.61324726e-02, 6.00807101e00, 1.87743612e-09],
                    ]
                ),
            },
            RTOL,
        )


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
