// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"

namespace paddle {
namespace framework {
namespace ir {

struct Param {
  std::string X = "concat_0.tmp_0";
  std::string C0 = "cell_init";
  std::string H0 = "hidden_init";
  std::string AttentionWeight = "attention_fc.w_0";
  std::string AttentionBias = "attention_fc.b_0";
  std::string AttentionScalar = "attention_output.w_0";
  std::string AttentionScalarBias = "attention_output.b_0";
  std::string LSTMWeight = "attention_w.new";
  std::string LSTMBias = "attention_b.new";
  std::string Hidden = "array_to_lod_tensor_0.tmp_0";
  std::string Cell = "at.cell.new";
  std::string AttentionedX = "at.x.new";
  std::string AttentionFCOut = "at.fc.new";
  std::string LSTMX = "at.lstmx.new";
  std::string LSTMOUT = "at.lstmout.new";
};

void PrepareParameters(Graph* graph, const Param& param);

void FindWhileOp(Graph* graph) {
  GraphPatternDetector gpd;
  std::unordered_set<int> fused_external_ops(
      {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
       57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});

  gpd.mutable_pattern()->NewNode(
      [&](Node* n) { return fused_external_ops.count(n->id()); }, "while");

  if (!graph->Has(kGraphvizMarkedNodeAttr)) {
    graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
  }
  auto& marked_nodes =
      graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);

  auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                    Graph* g) {
    auto* while_pat_node = gpd.pattern().RetrieveNode("while");
    auto* while_node = subgraph.at(while_pat_node);
    marked_nodes.insert(while_node);
  };
  gpd(graph, handle);

  Param param;
  // Add AttentionLSTM node
  OpDesc op_desc;
  op_desc.SetType("attention_lstm");

#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
  OP_SET_IN(X);
  OP_SET_IN(C0);
  OP_SET_IN(H0);
  OP_SET_IN(AttentionWeight);
  OP_SET_IN(AttentionBias);
  OP_SET_IN(AttentionScalar);
  OP_SET_IN(AttentionScalarBias);
  OP_SET_IN(LSTMWeight);
  OP_SET_IN(LSTMBias);

  OP_SET_OUT(Hidden);
  OP_SET_OUT(Cell);
  OP_SET_OUT(AttentionedX);
  OP_SET_OUT(AttentionFCOut);
  OP_SET_OUT(LSTMX);
  OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT

  auto* X = graph->RetriveNode(34);
  auto* LSTMOUT = graph->RetriveNode(81);
  auto* cell_init = graph->RetriveNode(6);
  auto* hidden_init = graph->RetriveNode(8);

  auto* lstm_op = graph->CreateOpNode(&op_desc);
  PrepareParameters(graph, param);

  IR_NODE_LINK_TO(X, lstm_op);
  IR_NODE_LINK_TO(cell_init, lstm_op);
  IR_NODE_LINK_TO(hidden_init, lstm_op);
  IR_NODE_LINK_TO(lstm_op, LSTMOUT);

  GraphSafeRemoveNodes(graph, marked_nodes);
}

#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
  CHECK_P1(x0);          \
  CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
  CHECK_P2(x0, x1);          \
  CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
  CHECK_P3(x0, x1, x2);          \
  CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
  CHECK_P4(x0, x1, x2, x3);          \
  CHECK_P1(x4);

void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
                       const LoDTensor& W_cell_w1, LoDTensor* out);

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out);

void PrepareParameters(Graph* graph, const Param& param) {
  // Check parameters
  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
  auto* scope = graph->Get<Scope*>(kParamScopeAttr);

  // Create new parameters.
  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
  scope->Var(param.Cell)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();

#define GATE_W(name__)                                               \
  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");            \
  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");            \
  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");            \
  CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);       \
  VLOG(4) << #name__ "_w0"                                           \
          << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_w1"                                           \
          << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
  VLOG(4) << #name__ "_b0"                                           \
          << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
  auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>();       \
  auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>();       \
  auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();

  GATE_W(forget);
  GATE_W(input);
  GATE_W(output);
  GATE_W(c);
#undef GATE_W

  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
  auto* attention_output_w = scope->FindVar("attention_output.w_0");
  auto* attention_output_b = scope->FindVar("attention_output.b_0");
  CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
           attention_output_b);

  auto* lstm_weight = scope->Var(param.LSTMWeight);
  auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
  auto* lstm_bias = scope->Var(param.LSTMBias);
  auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();

  // reshape attention_bias
  auto* attention_bias_t =
      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
  attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));

  auto* attention_scalar_bias_t =
      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
  attention_scalar_bias_t->Resize(
      make_ddim({1, attention_scalar_bias_t->dims()[0]}));

  PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
                    W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
                    lstm_weight_t);
  PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
                  lstm_bias_t);
}

// Prepare parameters
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
                       const LoDTensor& W_cell_w1, LoDTensor* out) {
  int D = W_forget_w0.dims()[0];
  int M = W_forget_w1.dims()[0];
  out->Resize(make_ddim({D + M, 4 * D}));
  VLOG(3) << "LSTMWeight resized to " << out->dims();

  float* out_data = out->mutable_data<float>(platform::CPUPlace());
  std::array<const float*, 4> tensors(
      {{W_forget_w0.data<float>(), W_input_w0.data<float>(),
        W_output_w0.data<float>(), W_cell_w0.data<float>()}});
  std::array<const float*, 4> tensors1(
      {{W_forget_w1.data<float>(), W_input_w1.data<float>(),
        W_output_w1.data<float>(), W_cell_w1.data<float>()}});

  for (int row = 0; row < D; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * row + D * col;
      const float* src = tensors[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }

  for (int row = 0; row < M; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * (D + row) + D * col;
      const float* src = tensors1[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }
}

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out) {
  std::array<const float*, 4> tensors(
      {{B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
        B_cell.data<float>()}});

  PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
  int D = B_forget.dims()[0];
  out->Resize(make_ddim({1, 4 * D}));
  auto* out_data = out->mutable_data<float>(platform::CPUPlace());
  for (size_t i = 0; i < tensors.size(); i++) {
    memcpy(out_data + D * i, tensors[i], D * sizeof(float));
  }
}

// Parameters

std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PDPattern external_pattern, subblock_pattern;

  // Use the following variables to tell whether this model is RNN1.
  // This fuse can only works on the RNN1 model.
  std::unordered_set<std::string> specified_vars({"data_lod_attention",
                                                  "cell_init", "hidden_init",
                                                  "data", "week", "minute"});
  int count = 0;
  for (auto* node : graph->Nodes()) {
    if (node->IsVar() && specified_vars.count(node->Name())) {
      ++count;
    }
  }
  if (count < specified_vars.size()) {
    return graph;
  }

  // Continue to fuse.
  FindWhileOp(graph.get());
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(attention_lstm_fuse_pass,
              paddle::framework::ir::AttentionLSTMFusePass);