import numpy as np
import pytest

from pandas import DataFrame, Float64Index, Index, Int64Index, RangeIndex, Series
import pandas._testing as tm


def gen_obj(klass, index):
    if klass is Series:
        obj = Series(np.arange(len(index)), index=index)
    else:
        obj = DataFrame(
            np.random.randn(len(index), len(index)), index=index, columns=index
        )
    return obj


class TestFloatIndexers:
    def check(self, result, original, indexer, getitem):
        """
        comparator for results
        we need to take care if we are indexing on a
        Series or a frame
        """
        if isinstance(original, Series):
            expected = original.iloc[indexer]
        else:
            if getitem:
                expected = original.iloc[:, indexer]
            else:
                expected = original.iloc[indexer]

        tm.assert_almost_equal(result, expected)

    @pytest.mark.parametrize(
        "index_func",
        [
            tm.makeStringIndex,
            tm.makeUnicodeIndex,
            tm.makeCategoricalIndex,
            tm.makeDateIndex,
            tm.makeTimedeltaIndex,
            tm.makePeriodIndex,
        ],
    )
    def test_scalar_non_numeric(self, index_func, frame_or_series):

        # GH 4892
        # float_indexers should raise exceptions
        # on appropriate Index types & accessors

        i = index_func(5)
        s = gen_obj(frame_or_series, i)

        # getting
        with pytest.raises(KeyError, match="^3.0$"):
            s[3.0]

        with pytest.raises(KeyError, match="^3.0$"):
            s.loc[3.0]

        # contains
        assert 3.0 not in s

        # setting with an indexer
        if s.index.inferred_type in ["categorical"]:
            # Value or Type Error
            pass
        elif s.index.inferred_type in ["datetime64", "timedelta64", "period"]:

            # FIXME: dont leave commented-out
            # these should prob work
            # and are inconsistent between series/dataframe ATM
            # for idxr in [lambda x: x]:
            #    s2 = s.copy()
            #
            #    with pytest.raises(TypeError):
            #        idxr(s2)[3.0] = 0
            pass

        else:

            s2 = s.copy()
            s2.loc[3.0] = 10
            assert s2.index.is_object()

            s2 = s.copy()
            s2[3.0] = 0
            assert s2.index.is_object()

    @pytest.mark.parametrize(
        "index_func",
        [
            tm.makeStringIndex,
            tm.makeUnicodeIndex,
            tm.makeCategoricalIndex,
            tm.makeDateIndex,
            tm.makeTimedeltaIndex,
            tm.makePeriodIndex,
        ],
    )
    def test_scalar_non_numeric_series_fallback(self, index_func):
        # fallsback to position selection, series only
        i = index_func(5)
        s = Series(np.arange(len(i)), index=i)
        s[3]
        with pytest.raises(KeyError, match="^3.0$"):
            s[3.0]

    def test_scalar_with_mixed(self):

        s2 = Series([1, 2, 3], index=["a", "b", "c"])
        s3 = Series([1, 2, 3], index=["a", "b", 1.5])

        # lookup in a pure string index with an invalid indexer

        with pytest.raises(KeyError, match="^1.0$"):
            s2[1.0]

        with pytest.raises(KeyError, match=r"^1\.0$"):
            s2.loc[1.0]

        result = s2.loc["b"]
        expected = 2
        assert result == expected

        # mixed index so we have label
        # indexing
        with pytest.raises(KeyError, match="^1.0$"):
            s3[1.0]

        result = s3[1]
        expected = 2
        assert result == expected

        with pytest.raises(KeyError, match=r"^1\.0$"):
            s3.loc[1.0]

        result = s3.loc[1.5]
        expected = 3
        assert result == expected

    @pytest.mark.parametrize(
        "idxr,getitem", [(lambda x: x.loc, False), (lambda x: x, True)]
    )
    @pytest.mark.parametrize("index_func", [tm.makeIntIndex, tm.makeRangeIndex])
    def test_scalar_integer(self, index_func, frame_or_series, idxr, getitem):

        # test how scalar float indexers work on int indexes

        # integer index
        i = index_func(5)
        obj = gen_obj(frame_or_series, i)

        # coerce to equal int

        result = idxr(obj)[3.0]
        self.check(result, obj, 3, getitem)

        if isinstance(obj, Series):

            def compare(x, y):
                assert x == y

            expected = 100
        else:
            compare = tm.assert_series_equal
            if getitem:
                expected = Series(100, index=range(len(obj)), name=3)
            else:
                expected = Series(100.0, index=range(len(obj)), name=3)

        s2 = obj.copy()
        idxr(s2)[3.0] = 100

        result = idxr(s2)[3.0]
        compare(result, expected)

        result = idxr(s2)[3]
        compare(result, expected)

    @pytest.mark.parametrize("index_func", [tm.makeIntIndex, tm.makeRangeIndex])
    def test_scalar_integer_contains_float(self, index_func, frame_or_series):
        # contains
        # integer index
        index = index_func(5)
        obj = gen_obj(frame_or_series, index)

        # coerce to equal int
        assert 3.0 in obj

    def test_scalar_float(self, frame_or_series):

        # scalar float indexers work on a float index
        index = Index(np.arange(5.0))
        s = gen_obj(frame_or_series, index)

        # assert all operations except for iloc are ok
        indexer = index[3]
        for idxr, getitem in [(lambda x: x.loc, False), (lambda x: x, True)]:

            # getting
            result = idxr(s)[indexer]
            self.check(result, s, 3, getitem)

            # setting
            s2 = s.copy()

            result = idxr(s2)[indexer]
            self.check(result, s, 3, getitem)

            # random float is a KeyError
            with pytest.raises(KeyError, match=r"^3\.5$"):
                idxr(s)[3.5]

        # contains
        assert 3.0 in s

        # iloc succeeds with an integer
        expected = s.iloc[3]
        s2 = s.copy()

        s2.iloc[3] = expected
        result = s2.iloc[3]
        self.check(result, s, 3, False)

    @pytest.mark.parametrize(
        "index_func",
        [
            tm.makeStringIndex,
            tm.makeUnicodeIndex,
            tm.makeDateIndex,
            tm.makeTimedeltaIndex,
            tm.makePeriodIndex,
        ],
    )
    @pytest.mark.parametrize("idx", [slice(3.0, 4), slice(3, 4.0), slice(3.0, 4.0)])
    def test_slice_non_numeric(self, index_func, idx, frame_or_series):

        # GH 4892
        # float_indexers should raise exceptions
        # on appropriate Index types & accessors

        index = index_func(5)
        s = gen_obj(frame_or_series, index)

        # getitem
        msg = (
            "cannot do positional indexing "
            fr"on {type(index).__name__} with these indexers \[(3|4)\.0\] of "
            "type float"
        )
        with pytest.raises(TypeError, match=msg):
            s.iloc[idx]

        msg = (
            "cannot do (slice|positional) indexing "
            fr"on {type(index).__name__} with these indexers "
            r"\[(3|4)(\.0)?\] "
            r"of type (float|int)"
        )
        for idxr in [lambda x: x.loc, lambda x: x.iloc, lambda x: x]:
            with pytest.raises(TypeError, match=msg):
                idxr(s)[idx]

        # setitem
        msg = "slice indices must be integers or None or have an __index__ method"
        with pytest.raises(TypeError, match=msg):
            s.iloc[idx] = 0

        msg = (
            "cannot do (slice|positional) indexing "
            fr"on {type(index).__name__} with these indexers "
            r"\[(3|4)(\.0)?\] "
            r"of type (float|int)"
        )
        for idxr in [lambda x: x.loc, lambda x: x]:
            with pytest.raises(TypeError, match=msg):
                idxr(s)[idx] = 0

    def test_slice_integer(self):

        # same as above, but for Integer based indexes
        # these coerce to a like integer
        # oob indicates if we are out of bounds
        # of positional indexing
        for index, oob in [
            (Int64Index(range(5)), False),
            (RangeIndex(5), False),
            (Int64Index(range(5)) + 10, True),
        ]:

            # s is an in-range index
            s = Series(range(5), index=index)

            # getitem
            for idx in [slice(3.0, 4), slice(3, 4.0), slice(3.0, 4.0)]:

                result = s.loc[idx]

                # these are all label indexing
                # except getitem which is positional
                # empty
                if oob:
                    indexer = slice(0, 0)
                else:
                    indexer = slice(3, 5)
                self.check(result, s, indexer, False)

            # getitem out-of-bounds
            for idx in [slice(-6, 6), slice(-6.0, 6.0)]:

                result = s.loc[idx]

                # these are all label indexing
                # except getitem which is positional
                # empty
                if oob:
                    indexer = slice(0, 0)
                else:
                    indexer = slice(-6, 6)
                self.check(result, s, indexer, False)

            # positional indexing
            msg = (
                "cannot do slice indexing "
                fr"on {type(index).__name__} with these indexers \[-6\.0\] of "
                "type float"
            )
            with pytest.raises(TypeError, match=msg):
                s[slice(-6.0, 6.0)]

            # getitem odd floats
            for idx, res1 in [
                (slice(2.5, 4), slice(3, 5)),
                (slice(2, 3.5), slice(2, 4)),
                (slice(2.5, 3.5), slice(3, 4)),
            ]:

                result = s.loc[idx]
                if oob:
                    res = slice(0, 0)
                else:
                    res = res1

                self.check(result, s, res, False)

                # positional indexing
                msg = (
                    "cannot do slice indexing "
                    fr"on {type(index).__name__} with these indexers \[(2|3)\.5\] of "
                    "type float"
                )
                with pytest.raises(TypeError, match=msg):
                    s[idx]

    @pytest.mark.parametrize("idx", [slice(2, 4.0), slice(2.0, 4), slice(2.0, 4.0)])
    def test_integer_positional_indexing(self, idx):
        """make sure that we are raising on positional indexing
        w.r.t. an integer index
        """
        s = Series(range(2, 6), index=range(2, 6))

        result = s[2:4]
        expected = s.iloc[2:4]
        tm.assert_series_equal(result, expected)

        klass = RangeIndex
        msg = (
            "cannot do (slice|positional) indexing "
            fr"on {klass.__name__} with these indexers \[(2|4)\.0\] of "
            "type float"
        )
        with pytest.raises(TypeError, match=msg):
            s[idx]
        with pytest.raises(TypeError, match=msg):
            s.iloc[idx]

    @pytest.mark.parametrize("index_func", [tm.makeIntIndex, tm.makeRangeIndex])
    def test_slice_integer_frame_getitem(self, index_func):

        # similar to above, but on the getitem dim (of a DataFrame)
        index = index_func(5)

        s = DataFrame(np.random.randn(5, 2), index=index)

        # getitem
        for idx in [slice(0.0, 1), slice(0, 1.0), slice(0.0, 1.0)]:

            result = s.loc[idx]
            indexer = slice(0, 2)
            self.check(result, s, indexer, False)

            # positional indexing
            msg = (
                "cannot do slice indexing "
                fr"on {type(index).__name__} with these indexers \[(0|1)\.0\] of "
                "type float"
            )
            with pytest.raises(TypeError, match=msg):
                s[idx]

        # getitem out-of-bounds
        for idx in [slice(-10, 10), slice(-10.0, 10.0)]:

            result = s.loc[idx]
            self.check(result, s, slice(-10, 10), True)

        # positional indexing
        msg = (
            "cannot do slice indexing "
            fr"on {type(index).__name__} with these indexers \[-10\.0\] of "
            "type float"
        )
        with pytest.raises(TypeError, match=msg):
            s[slice(-10.0, 10.0)]

        # getitem odd floats
        for idx, res in [
            (slice(0.5, 1), slice(1, 2)),
            (slice(0, 0.5), slice(0, 1)),
            (slice(0.5, 1.5), slice(1, 2)),
        ]:

            result = s.loc[idx]
            self.check(result, s, res, False)

            # positional indexing
            msg = (
                "cannot do slice indexing "
                fr"on {type(index).__name__} with these indexers \[0\.5\] of "
                "type float"
            )
            with pytest.raises(TypeError, match=msg):
                s[idx]

    @pytest.mark.parametrize("idx", [slice(3.0, 4), slice(3, 4.0), slice(3.0, 4.0)])
    @pytest.mark.parametrize("index_func", [tm.makeIntIndex, tm.makeRangeIndex])
    def test_float_slice_getitem_with_integer_index_raises(self, idx, index_func):

        # similar to above, but on the getitem dim (of a DataFrame)
        index = index_func(5)

        s = DataFrame(np.random.randn(5, 2), index=index)

        # setitem
        sc = s.copy()
        sc.loc[idx] = 0
        result = sc.loc[idx].values.ravel()
        assert (result == 0).all()

        # positional indexing
        msg = (
            "cannot do slice indexing "
            fr"on {type(index).__name__} with these indexers \[(3|4)\.0\] of "
            "type float"
        )
        with pytest.raises(TypeError, match=msg):
            s[idx] = 0

        with pytest.raises(TypeError, match=msg):
            s[idx]

    @pytest.mark.parametrize("idx", [slice(3.0, 4), slice(3, 4.0), slice(3.0, 4.0)])
    def test_slice_float(self, idx, frame_or_series):

        # same as above, but for floats
        index = Index(np.arange(5.0)) + 0.1
        s = gen_obj(frame_or_series, index)

        expected = s.iloc[3:4]
        for idxr in [lambda x: x.loc, lambda x: x]:

            # getitem
            result = idxr(s)[idx]
            assert isinstance(result, type(s))
            tm.assert_equal(result, expected)

            # setitem
            s2 = s.copy()
            idxr(s2)[idx] = 0
            result = idxr(s2)[idx].values.ravel()
            assert (result == 0).all()

    def test_floating_index_doc_example(self):

        index = Index([1.5, 2, 3, 4.5, 5])
        s = Series(range(5), index=index)
        assert s[3] == 2
        assert s.loc[3] == 2
        assert s.loc[3] == 2
        assert s.iloc[3] == 3

    def test_floating_misc(self):

        # related 236
        # scalar/slicing of a float index
        s = Series(np.arange(5), index=np.arange(5) * 2.5, dtype=np.int64)

        # label based slicing
        result1 = s[1.0:3.0]
        result2 = s.loc[1.0:3.0]
        result3 = s.loc[1.0:3.0]
        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, result3)

        # exact indexing when found
        result1 = s[5.0]
        result2 = s.loc[5.0]
        result3 = s.loc[5.0]
        assert result1 == result2
        assert result1 == result3

        result1 = s[5]
        result2 = s.loc[5]
        result3 = s.loc[5]
        assert result1 == result2
        assert result1 == result3

        assert s[5.0] == s[5]

        # value not found (and no fallbacking at all)

        # scalar integers
        with pytest.raises(KeyError, match=r"^4$"):
            s.loc[4]
        with pytest.raises(KeyError, match=r"^4$"):
            s.loc[4]
        with pytest.raises(KeyError, match=r"^4$"):
            s[4]

        # fancy floats/integers create the correct entry (as nan)
        # fancy tests
        expected = Series([2, 0], index=Float64Index([5.0, 0.0]))
        for fancy_idx in [[5.0, 0.0], np.array([5.0, 0.0])]:  # float
            tm.assert_series_equal(s[fancy_idx], expected)
            tm.assert_series_equal(s.loc[fancy_idx], expected)
            tm.assert_series_equal(s.loc[fancy_idx], expected)

        expected = Series([2, 0], index=Index([5, 0], dtype="int64"))
        for fancy_idx in [[5, 0], np.array([5, 0])]:  # int
            tm.assert_series_equal(s[fancy_idx], expected)
            tm.assert_series_equal(s.loc[fancy_idx], expected)
            tm.assert_series_equal(s.loc[fancy_idx], expected)

        # all should return the same as we are slicing 'the same'
        result1 = s.loc[2:5]
        result2 = s.loc[2.0:5.0]
        result3 = s.loc[2.0:5]
        result4 = s.loc[2.1:5]
        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, result3)
        tm.assert_series_equal(result1, result4)

        # previously this did fallback indexing
        result1 = s[2:5]
        result2 = s[2.0:5.0]
        result3 = s[2.0:5]
        result4 = s[2.1:5]
        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, result3)
        tm.assert_series_equal(result1, result4)

        result1 = s.loc[2:5]
        result2 = s.loc[2.0:5.0]
        result3 = s.loc[2.0:5]
        result4 = s.loc[2.1:5]
        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, result3)
        tm.assert_series_equal(result1, result4)

        # combined test
        result1 = s.loc[2:5]
        result2 = s.loc[2:5]
        result3 = s[2:5]

        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, result3)

        # list selection
        result1 = s[[0.0, 5, 10]]
        result2 = s.loc[[0.0, 5, 10]]
        result3 = s.loc[[0.0, 5, 10]]
        result4 = s.iloc[[0, 2, 4]]
        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, result3)
        tm.assert_series_equal(result1, result4)

        with pytest.raises(KeyError, match="with any missing labels"):
            s[[1.6, 5, 10]]
        with pytest.raises(KeyError, match="with any missing labels"):
            s.loc[[1.6, 5, 10]]

        with pytest.raises(KeyError, match="with any missing labels"):
            s[[0, 1, 2]]
        with pytest.raises(KeyError, match="with any missing labels"):
            s.loc[[0, 1, 2]]

        result1 = s.loc[[2.5, 5]]
        result2 = s.loc[[2.5, 5]]
        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, Series([1, 2], index=[2.5, 5.0]))

        result1 = s[[2.5]]
        result2 = s.loc[[2.5]]
        result3 = s.loc[[2.5]]
        tm.assert_series_equal(result1, result2)
        tm.assert_series_equal(result1, result3)
        tm.assert_series_equal(result1, Series([1], index=[2.5]))

    def test_floating_tuples(self):
        # see gh-13509
        s = Series([(1, 1), (2, 2), (3, 3)], index=[0.0, 0.1, 0.2], name="foo")

        result = s[0.0]
        assert result == (1, 1)

        expected = Series([(1, 1), (2, 2)], index=[0.0, 0.0], name="foo")
        s = Series([(1, 1), (2, 2), (3, 3)], index=[0.0, 0.0, 0.2], name="foo")

        result = s[0.0]
        tm.assert_series_equal(result, expected)

    def test_float64index_slicing_bug(self):
        # GH 5557, related to slicing a float index
        ser = {
            256: 2321.0,
            1: 78.0,
            2: 2716.0,
            3: 0.0,
            4: 369.0,
            5: 0.0,
            6: 269.0,
            7: 0.0,
            8: 0.0,
            9: 0.0,
            10: 3536.0,
            11: 0.0,
            12: 24.0,
            13: 0.0,
            14: 931.0,
            15: 0.0,
            16: 101.0,
            17: 78.0,
            18: 9643.0,
            19: 0.0,
            20: 0.0,
            21: 0.0,
            22: 63761.0,
            23: 0.0,
            24: 446.0,
            25: 0.0,
            26: 34773.0,
            27: 0.0,
            28: 729.0,
            29: 78.0,
            30: 0.0,
            31: 0.0,
            32: 3374.0,
            33: 0.0,
            34: 1391.0,
            35: 0.0,
            36: 361.0,
            37: 0.0,
            38: 61808.0,
            39: 0.0,
            40: 0.0,
            41: 0.0,
            42: 6677.0,
            43: 0.0,
            44: 802.0,
            45: 0.0,
            46: 2691.0,
            47: 0.0,
            48: 3582.0,
            49: 0.0,
            50: 734.0,
            51: 0.0,
            52: 627.0,
            53: 70.0,
            54: 2584.0,
            55: 0.0,
            56: 324.0,
            57: 0.0,
            58: 605.0,
            59: 0.0,
            60: 0.0,
            61: 0.0,
            62: 3989.0,
            63: 10.0,
            64: 42.0,
            65: 0.0,
            66: 904.0,
            67: 0.0,
            68: 88.0,
            69: 70.0,
            70: 8172.0,
            71: 0.0,
            72: 0.0,
            73: 0.0,
            74: 64902.0,
            75: 0.0,
            76: 347.0,
            77: 0.0,
            78: 36605.0,
            79: 0.0,
            80: 379.0,
            81: 70.0,
            82: 0.0,
            83: 0.0,
            84: 3001.0,
            85: 0.0,
            86: 1630.0,
            87: 7.0,
            88: 364.0,
            89: 0.0,
            90: 67404.0,
            91: 9.0,
            92: 0.0,
            93: 0.0,
            94: 7685.0,
            95: 0.0,
            96: 1017.0,
            97: 0.0,
            98: 2831.0,
            99: 0.0,
            100: 2963.0,
            101: 0.0,
            102: 854.0,
            103: 0.0,
            104: 0.0,
            105: 0.0,
            106: 0.0,
            107: 0.0,
            108: 0.0,
            109: 0.0,
            110: 0.0,
            111: 0.0,
            112: 0.0,
            113: 0.0,
            114: 0.0,
            115: 0.0,
            116: 0.0,
            117: 0.0,
            118: 0.0,
            119: 0.0,
            120: 0.0,
            121: 0.0,
            122: 0.0,
            123: 0.0,
            124: 0.0,
            125: 0.0,
            126: 67744.0,
            127: 22.0,
            128: 264.0,
            129: 0.0,
            260: 197.0,
            268: 0.0,
            265: 0.0,
            269: 0.0,
            261: 0.0,
            266: 1198.0,
            267: 0.0,
            262: 2629.0,
            258: 775.0,
            257: 0.0,
            263: 0.0,
            259: 0.0,
            264: 163.0,
            250: 10326.0,
            251: 0.0,
            252: 1228.0,
            253: 0.0,
            254: 2769.0,
            255: 0.0,
        }

        # smoke test for the repr
        s = Series(ser)
        result = s.value_counts()
        str(result)