#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import math

__all__ = [
    'long_type',
    'to_text',
    'to_bytes',
    'round',
    'floor_division',
    'get_exception_message',
]

if six.PY2:
    int_type = int
    long_type = long
else:
    int_type = int
    long_type = int


#  str and bytes related functions
def to_text(obj, encoding='utf-8', inplace=False):
    """
      All string in PaddlePaddle should be represented as a literal string.
    This function will convert object to a literal string without any encoding.
    Especially, if the object type is a list or set container, we will iterate
    all items in the object and convert them to literal string.

    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding

    Args:
        obj(unicode|str|bytes|list|set) : The object to be decoded.
        encoding(str) : The encoding format to decode a string
        inplace(bool) : If we change the original object or we create a new one

    Returns:
        Decoded result of obj
    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        if inplace:
            for i in six.moves.xrange(len(obj)):
                obj[i] = _to_text(obj[i], encoding)
            return obj
        else:
            return [_to_text(item, encoding) for item in obj]
    elif isinstance(obj, set):
        if inplace:
            for item in obj:
                obj.remove(item)
                obj.add(_to_text(item, encoding))
            return obj
        else:
            return set([_to_text(item, encoding) for item in obj])
    else:
        return _to_text(obj, encoding)


def _to_text(obj, encoding):
    """
    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding,
        or we just return the unicode string of object

    Args:
        obj(unicode|str|bytes) : The object to be decoded.
        encoding(str) : The encoding format

    Returns:
        decoded result of obj
    """
    if obj is None:
        return obj

    if isinstance(obj, six.binary_type):
        return obj.decode(encoding)
    elif isinstance(obj, six.text_type):
        return obj
    else:
        return six.u(obj)


def to_bytes(obj, encoding='utf-8', inplace=False):
    """
      All string in PaddlePaddle should be represented as a literal string.
    This function will convert object to a bytes with specific encoding.
    Especially, if the object type is a list or set container, we will iterate
    all items in the object and convert them to bytes.

    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding,
        or we just return the 8-bit string of object

    Args:
        obj(unicode|str|bytes|list|set) : The object to be encoded.
        encoding(str) : The encoding format to encode a string
        inplace(bool) : If we change the original object or we create a new one

    Returns:
        Decoded result of obj
    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        if inplace:
            for i in six.moves.xrange(len(obj)):
                obj[i] = _to_bytes(obj[i], encoding)
            return obj
        else:
            return [_to_bytes(item, encoding) for item in obj]
    elif isinstance(obj, set):
        if inplace:
            for item in obj:
                obj.remove(item)
                obj.add(_to_bytes(item, encoding))
            return obj
        else:
            return set([_to_bytes(item, encoding) for item in obj])
    else:
        return _to_bytes(obj, encoding)


def _to_bytes(obj, encoding):
    """
    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding,
        or we just return the 8-bit string of object

    Args:
        obj(unicode|str|bytes) : The object to be encoded.
        encoding(str) : The encoding format

    Returns:
        encoded result of obj
    """
    if obj is None:
        return obj

    assert encoding is not None
    if isinstance(obj, six.text_type):
        return obj.encode(encoding)
    elif isinstance(obj, six.binary_type):
        return obj
    else:
        return six.b(obj)


# math related functions
def round(x, d=0):
    """
    Compatible round which act the same behaviour in Python3.

    Args:
        x(float) : The number to round halfway.

    Returns:
        round result of x
    """
    if six.PY3:
        # The official walkaround of round in Python3 is incorrect
        # we implement accroding this answer: https://www.techforgeek.info/round_python.html
        if x > 0.0:
            p = 10**d
            return float(math.floor((x * p) + math.copysign(0.5, x))) / p
        elif x < 0.0:
            p = 10**d
            return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
        else:
            return math.copysign(0.0, x)
    else:
        import __builtin__
        return __builtin__.round(x, d)


def floor_division(x, y):
    """
    Compatible division which act the same behaviour in Python3 and Python2,
    whose result will be a int value of floor(x / y) in Python3 and value of
    (x / y) in Python2.

    Args:
        x(int|float) : The number to divide.
        y(int|float) : The number to be divided

    Returns:
        division result of x // y
    """
    return x // y


# exception related functions
def get_exception_message(exc):
    """
    Get the error message of a specific exception

    Args:
        exec(Exception) : The exception to get error message.

    Returns:
        the error message of exec
    """
    assert exc is not None

    if six.PY2:
        return exc.message
    else:
        return str(exc)