#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import sys
import six
import re
from .graphviz import GraphPreviewGenerator
from .proto import framework_pb2
from google.protobuf import text_format

# Display names for variable kinds. Not referenced by the code visible in
# this module.
# NOTE(review): the index order presumably follows the variable-type enum in
# framework.proto, but the exact mapping from enum value to index is not
# established here -- confirm before indexing into this list.
_vartype2str_ = [
    "UNK",
    "LoDTensor",
    "SelectedRows",
    "FeedMinibatch",
    "FetchList",
    "StepScopes",
    "LodRankTable",
    "LoDTensorArray",
    "PlaceList",
]
# Display names for tensor data types, indexed directly by the integer
# data-type id (see repr_data_type and repr_tensor).
# NOTE(review): assumes ids 0-6 correspond to these names in this order;
# ids outside 0-6 are not covered by this table.
_dtype2str_ = [
    "bool",
    "int16",
    "int32",
    "int64",
    "float16",
    "float32",
    "float64",
]


def repr_data_type(type):
    """Map an integer data-type id to its readable name.

    Raises IndexError when the id is not covered by ``_dtype2str_``.
    """
    name = _dtype2str_[type]
    return name


def repr_tensor(proto):
    """Describe a tensor description as ``tensor(type=..., shape=...)``.

    Args:
        proto: an object exposing ``data_type`` (an integer data-type id) and
            ``dims`` (the shape), such as a TensorDesc message.

    Returns:
        str: the one-line description. Data-type ids that have no entry in
        ``_dtype2str_`` are shown as ``dtype(<id>)``.
    """
    dtype_id = int(proto.data_type)
    if 0 <= dtype_id < len(_dtype2str_):
        dtype_name = _dtype2str_[dtype_id]
    else:
        # Show unknown ids numerically instead of aborting the whole
        # program/block listing with an IndexError.
        dtype_name = "dtype(%d)" % dtype_id
    return "tensor(type={}, shape={})".format(dtype_name, str(proto.dims))


# Shared one-line layout for variable descriptions: a type label, the
# variable name, then type-specific details in parentheses.
reprtpl = "{ttype} {name} ({reprs})"


def repr_lodtensor(proto):
    """Describe a LOD_TENSOR variable; return None for any other type.

    Variables with ``lod_level > 0`` are labelled ``LoDTensor`` and their
    level is included; otherwise they are labelled as a plain ``Tensor``.
    """
    if proto.type.type == framework_pb2.VarType.LOD_TENSOR:
        lod_desc = proto.type.lod_tensor
        lod_level = lod_desc.lod_level
        tensor_text = repr_tensor(lod_desc.tensor)
        if lod_level > 0:
            label = "LoDTensor"
            detail = "level=%d, %s" % (lod_level, tensor_text)
        else:
            label = "Tensor"
            detail = tensor_text
        return reprtpl.format(ttype=label, name=proto.name, reprs=detail)
    return None


def repr_selected_rows(proto):
    """Describe a SELECTED_ROWS variable; return None for any other type."""
    is_selected_rows = (
        proto.type.type == framework_pb2.VarType.SELECTED_ROWS)
    if not is_selected_rows:
        return None

    tensor_text = repr_tensor(proto.type.selected_rows)
    return reprtpl.format(
        ttype="SelectedRows", name=proto.name, reprs=tensor_text)


def repr_tensor_array(proto):
    """Describe a LOD_TENSOR_ARRAY variable; return None for any other type.

    Returns:
        str: ``TensorArray <name> (level=<lod_level>, tensor(...))`` built
        from the variable's ``tensor_array`` description, or None when the
        variable is not a tensor array.
    """
    if proto.type.type != framework_pb2.VarType.LOD_TENSOR_ARRAY:
        return None

    # Both the LoD level and the element tensor come from the tensor_array
    # sub-message. ``lod_tensor`` describes LOD_TENSOR variables (see
    # repr_lodtensor) and is not the element description of an array.
    array_desc = proto.type.tensor_array
    return reprtpl.format(
        ttype="TensorArray",
        name=proto.name,
        reprs="level=%d, %s" % (array_desc.lod_level,
                                repr_tensor(array_desc.tensor)))


# Per-type variable formatters, tried in order by repr_var. Each returns a
# description string for the variable type it handles and None otherwise.
type_handlers = [
    repr_lodtensor,
    repr_selected_rows,
    repr_tensor_array,
]


def repr_var(vardesc):
    """Return a one-line description of a variable description.

    Each formatter in ``type_handlers`` is tried in order and the first
    non-empty result is returned. Variables whose type no handler
    recognizes get a generic ``Variable <name> (type=<type id>)`` line
    instead of None, so callers that join the results into text do not fail.
    """
    for handler in type_handlers:
        res = handler(vardesc)
        if res:
            return res
    # Fallback for variable types without a dedicated formatter.
    return reprtpl.format(
        ttype="Variable",
        name=vardesc.name,
        reprs="type=%s" % vardesc.type.type)


def pprint_program_codes(program_desc, show_backward=False):
    """Return a readable, pseudo-code listing of every block of a program.

    Args:
        program_desc: the program to print. It must expose
            ``desc.num_blocks()`` and ``block(idx)``; presumably a fluid
            ``Program`` -- confirm against callers.
        show_backward: forwarded to ``pprint_block_codes``. When False (the
            default), gradient-related ops and ``@GRAD`` variables are left
            out of each block's listing.

    Returns:
        str: the per-block listings, in block-index order, joined by newlines.
    """
    num_blocks = program_desc.desc.num_blocks()
    return '\n'.join(
        pprint_block_codes(
            program_desc.block(block_idx), show_backward=show_backward)
        for block_idx in range(num_blocks))


def pprint_block_codes(block_desc, show_backward=False):
    """Return a readable, pseudo-code listing of a single block.

    Args:
        block_desc: a ``framework_pb2.BlockDesc``, or an object exposing
            ``desc.serialize_to_string()``, whose output is parsed into a
            ``BlockDesc`` first.
        show_backward: when False, ops and variables that look like gradient
            computation are omitted. An op counts as backward if its type
            ends with ``_grad`` or any input/output slot name or argument
            contains ``@GRAD``. A variable counts as backward if its name
            contains ``@GRAD``.

    Returns:
        str: a ``// block-<idx>  parent-<pidx>`` header followed by the
        variable listing and the operator listing.
    """

    def is_slot_backward(slot):
        # An op input/output slot is gradient-related if its parameter name
        # or any argument bound to it carries the "@GRAD" marker.
        return "@GRAD" in slot.parameter or any("@GRAD" in arg
                                                for arg in slot.arguments)

    def is_op_backward(op_desc):
        if op_desc.type.endswith('_grad'):
            return True
        return (any(is_slot_backward(slot) for slot in op_desc.inputs) or
                any(is_slot_backward(slot) for slot in op_desc.outputs))

    def is_var_backward(var_desc):
        return "@GRAD" in var_desc.name

    if not isinstance(block_desc, framework_pb2.BlockDesc):
        block_desc = framework_pb2.BlockDesc.FromString(
            block_desc.desc.serialize_to_string())

    var_reprs = []
    for var in block_desc.vars:
        if not show_backward and is_var_backward(var):
            continue
        # Fall back to the bare name in case repr_var returns None (e.g. for
        # a variable type it has no formatter for); otherwise the join below
        # would raise TypeError.
        var_reprs.append(repr_var(var) or var.name)

    op_reprs = [
        repr_op(op) for op in block_desc.ops
        if show_backward or not is_op_backward(op)
    ]

    tpl = "// block-{idx}  parent-{pidx}\n// variables\n{vars}\n\n// operators\n{ops}\n"
    return tpl.format(
        idx=block_desc.idx,
        pidx=block_desc.parent_idx,
        vars='\n'.join(var_reprs),
        ops='\n'.join(op_reprs), )


def repr_attr(desc):
    """Format one operator attribute as ``key=value``.

    Args:
        desc: an attribute description exposing ``name``, an integer
            ``type`` code, and the typed value field selected by that code
            (``i``, ``f``, ``s``, ``ints``, ``floats``, ``strings``, ``b``,
            ``bools``, ``block_idx``, ``l`` for codes 0-9).

    Returns:
        tuple: ``(text, (key, value))``, where ``text`` is ``"key=value"``.
        For an attribute named ``dtype``, ``value`` is replaced by the
        readable data-type name when one is known. Attribute type codes not
        in the table yield a placeholder value instead of raising.
    """
    tpl = "{key}={value}"
    # Position in this tuple is the attribute's ``type`` code.
    valgetter = (
        lambda attr: attr.i,
        lambda attr: attr.f,
        lambda attr: attr.s,
        lambda attr: attr.ints,
        lambda attr: attr.floats,
        lambda attr: attr.strings,
        lambda attr: attr.b,
        lambda attr: attr.bools,
        lambda attr: attr.block_idx,
        lambda attr: attr.l,
    )
    key = desc.name
    if 0 <= desc.type < len(valgetter):
        value = valgetter[desc.type](desc)
    else:
        # Unknown attribute kinds are shown as a placeholder instead of
        # aborting the whole op listing with IndexError.
        value = "<unsupported attr type %s>" % desc.type
    if key == "dtype":
        try:
            value = repr_data_type(value)
        except (IndexError, TypeError):
            # Data-type ids missing from the name table (or a placeholder
            # value) are deliberately kept as-is.
            pass
    return tpl.format(key=key, value=str(value)), (key, value)


def _repr_op_fill_constant(optype, inputs, outputs, attrs):
    """Render a ``fill_constant`` op as ``<outputs> = <value> [shape=...]``.

    Returns None for every other op type so the caller can fall back to
    another format. ``inputs`` is accepted for handler-signature
    compatibility but not used.
    """
    if optype != "fill_constant":
        return None
    joined_outputs = ','.join(outputs)
    fill_value = attrs['value']
    shape_text = str(attrs['shape'])
    return "{output} = {data} [shape={shape}]".format(
        output=joined_outputs, data=fill_value, shape=shape_text)


# Special-case op formatters, tried in order by repr_op. Each is called as
# handler(optype, inputs, outputs, attr_dict) and returns a replacement
# description, or a falsy value to let the generic format be used.
op_repr_handlers = [_repr_op_fill_constant, ]


def repr_op(opdesc):
    """Return a one-line, pseudo-code description of an operator.

    The generic form is ``outputs = optype(inputs[, is_target]) [{attrs}]``,
    where inputs are shown as ``parameter=argument`` pairs and outputs by
    argument only. The first handler in ``op_repr_handlers`` that returns a
    non-empty string overrides the generic form.
    """
    tpl = "{outputs} = {optype}({inputs}{is_target}) [{attrs}]"

    def args2value(args):
        # A single argument is shown bare; zero or several as a list.
        return args[0] if len(args) == 1 else str(list(args))

    inputs = [
        "%s=%s" % (var.parameter, args2value(var.arguments))
        for var in opdesc.inputs
    ]
    outputs = [args2value(var.arguments) for var in opdesc.outputs]

    attrs = []
    attr_dict = {}
    for attr in opdesc.attrs:
        attr_repr, (attr_key, attr_value) = repr_attr(attr)
        attrs.append(attr_repr)
        attr_dict[attr_key] = attr_value

    for handler in op_repr_handlers:
        res = handler(opdesc.type, inputs, outputs, attr_dict)
        if res:
            return res

    return tpl.format(
        outputs=', '.join(outputs),
        optype=opdesc.type,
        inputs=', '.join(inputs),
        attrs="{%s}" % ','.join(attrs),
        is_target=", is_target" if opdesc.is_target else "")


def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
    '''
    Generate a debug graph for block.

    Every declared variable becomes a node: persistable ones are added as
    parameters labelled with their type, the rest as arguments. Every op
    becomes an op node with edges from its input arguments and to its output
    arguments. Arguments an op references but the block does not declare
    are added as argument nodes on the fly.

    Args:
        block(Block): a block. Its ``desc.serialize_to_string()`` output is
            parsed into a ``framework_pb2.BlockDesc``.
        highlights(str|list of str|None): regular-expression patterns,
            matched with ``re.match`` semantics (anchored at the start).
            Variables (by name) and ops (by type) that match are highlighted,
            as are edges touching a node whose description matches. A single
            string is treated as one pattern; None disables highlighting.
        path(str): output path passed to the graph generator.

    Raises:
        TypeError: if a highlight pattern is not a string.
    '''
    if isinstance(highlights, six.string_types):
        # Iterating a bare string would treat each character as a pattern.
        highlights = [highlights]
    patterns = []
    for pattern in highlights or ():
        if not isinstance(pattern, six.string_types):
            raise TypeError(
                "highlight patterns must be strings, got %r" % (pattern, ))
        patterns.append(re.compile(pattern))

    def need_highlight(name):
        return any(p.match(name) for p in patterns)

    graph = GraphPreviewGenerator("some graph")
    # collect parameters and args
    protostr = block.desc.serialize_to_string()
    desc = framework_pb2.BlockDesc.FromString(six.binary_type(protostr))

    # draw parameters and args
    var_nodes = {}
    for var in desc.vars:
        # TODO(gongwb): format the var.type
        # create var
        if var.persistable:
            varn = graph.add_param(
                var.name,
                str(var.type).replace("\n", "<br />", 1),
                highlight=need_highlight(var.name))
        else:
            varn = graph.add_arg(var.name, highlight=need_highlight(var.name))
        var_nodes[var.name] = varn

    def add_op_link_var(op, var, op2var=False):
        # Connect op node ``op`` with every argument of slot ``var``; edges
        # point op -> variable when ``op2var`` is True, else variable -> op.
        for arg in var.arguments:
            if arg not in var_nodes:
                # add missing variables as argument
                var_nodes[arg] = graph.add_arg(
                    arg, highlight=need_highlight(arg))
            varn = var_nodes[arg]
            highlight = need_highlight(op.description) or need_highlight(
                varn.description)
            if op2var:
                graph.add_edge(op, varn, highlight=highlight)
            else:
                graph.add_edge(varn, op, highlight=highlight)

    for op in desc.ops:
        opn = graph.add_op(op.type, highlight=need_highlight(op.type))
        for var in op.inputs:
            add_op_link_var(opn, var, False)
        for var in op.outputs:
            add_op_link_var(opn, var, True)

    graph(path, show=False)