@@ -16,8 +16,10 @@
 #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
 #include <vector>
 #include <memory>
+#include "backend/session/kernel_graph.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/trace_base.h"
+#include "utils/tensor_construct_utils.h"
 
 namespace mindspore {
 namespace opt {
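
The two added includes back the new helpers further down: backend/session/kernel_graph.h provides the KernelGraphPtr cast and NewValueNode call used in CreateValueNode, and utils/tensor_construct_utils.h provides TensorConstructUtils::CreateOnesTensor.
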
@@ -46,7 +48,7 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
 
   std::vector<size_t> output0_dims{origin_input9_shape[0], 4 * (((origin_input9_shape[1] + 15) / 16) * 16)};
   std::vector<size_t> output1_dims{input_i_shape[1], input_i_shape[2]};
-  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {output0_dims, output1_dims},
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
                                       basic_lstm_cell_c_state_grad.get());
   AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
   AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);
@@ -260,7 +262,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
   // Create lstm_gage_concat
   auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
   auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
-  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16},
                                       {{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
                                       lstm_gage_concat.get());
   AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
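
Both output-type changes above (output 0 of basic_lstm_cell_c_state_grad, and the lstm_gage_concat output) move the gate-gradient tensors from float32 to float16. Presumably this keeps dgate in the half-precision form that the BatchMatMul inserted below consumes directly; the cell-state gradient (output 1) stays float32.
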
@@ -413,6 +415,24 @@ AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &l
   return batch_matmul;
 }
 
+AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
+                              const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  // Create node
+  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
+                                           node, lstm_input_grad};
+  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
+  // Set infer data type and shape
+  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[0], IntToSize(1),
+                    AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, batch_matmul.get());
+  // Set attr
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
+  return batch_matmul;
+}
+
 AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                              const AnfNodePtr &batch_matmul) {
   MS_EXCEPTION_IF_NULL(func_graph);
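
CreateBatchMatMul2 left-multiplies the dgate output of lstm_input_grad, shaped (t, n, 4 * hidden), by `node` — in practice the (t, 1, n) all-ones tensor built by CreateValueNode in the next hunk — so each batch of the matmul collapses the n dimension and yields a (t, 1, 4 * hidden) partial bias gradient in float16.
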
@@ -430,18 +450,38 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
   return reduce_sum;
 }
 
+AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
+  auto origin_input7 = dynamic_rnn_grad_cnode->input(8);
+  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
+  auto t_size = origin_input7_shape[0];
+  auto n_size = origin_input7_shape[1];
+
+  std::vector<size_t> shape = {t_size, IntToSize(1), n_size};
+  std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)};
+  std::vector<int64_t> output_tensor = {(SizeToLong(n_size) + SizeToLong(15)) / SizeToLong(16) * SizeToLong(16) *
+                                        SizeToLong(16) * SizeToLong(t_size)};
+  auto tensor = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, output_tensor);
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
+  kernel_graph->AddValueNodeToGraph(value_node);
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, value_node.get());
+  return value_node;
+}
+
 AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
-                             const AnfNodePtr &lstm_input_grad) {
+                             const AnfNodePtr &lstm_input_grad, const AnfNodePtr &value_node) {
   MS_EXCEPTION_IF_NULL(func_graph);
   // Create node
+  auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
   std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
-                                               lstm_input_grad};
+                                               batch_matmul};
   auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
   // Set infer data type and shape
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)},
-                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get());
+  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
   // Set attr
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sum);
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
   AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
   AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
   return reduce_sum;
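
The net effect on db: previously it was a ReduceSum over axes {0, 1} of dgate; now it is a BatchMatMul against a ones vector (summing over n per time step) followed by a ReduceSum over axis {0} (summing over t). The two are arithmetically the same reduction; the matmul form presumably maps better onto the Ascend cube unit, which would also explain why output_tensor allocates (n rounded up to 16) * 16 * t elements while the inferred shape stays (t, 1, n) — padding for the 16x16 fractal layout. A standalone C++ sketch with made-up sizes (t = 3, n = 2, 4 * hidden = 8; not MindSpore code) checking the equivalence:

// Checks that BatchMatMul(ones(t,1,n), dgate(t,n,4h)) + ReduceSum over axis 0
// equals the old ReduceSum of dgate over axes {0, 1}.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t t = 3, n = 2, gate = 8;  // hypothetical t_size, batch, 4*hidden
  // dgate[step][batch][k], filled with arbitrary values
  std::vector<std::vector<std::vector<float>>> dgate(
      t, std::vector<std::vector<float>>(n, std::vector<float>(gate)));
  float v = 0.5f;
  for (auto &step : dgate)
    for (auto &row : step)
      for (auto &x : row) x = (v += 0.25f);

  // Old path: db[k] = sum over t and n of dgate[t][n][k]
  std::vector<float> db_old(gate, 0.0f);
  for (std::size_t i = 0; i < t; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t k = 0; k < gate; ++k) db_old[k] += dgate[i][j][k];

  // New path: per step, ones(1, n) x dgate[i] gives a (1, gate) row
  // (the BatchMatMul against the ones value node) ...
  std::vector<float> db_new(gate, 0.0f);
  for (std::size_t i = 0; i < t; ++i) {
    for (std::size_t k = 0; k < gate; ++k) {
      float row = 0.0f;
      for (std::size_t j = 0; j < n; ++j) row += 1.0f * dgate[i][j][k];
      db_new[k] += row;  // ... then ReduceSum over axis 0
    }
  }

  for (std::size_t k = 0; k < gate; ++k) assert(std::fabs(db_old[k] - db_new[k]) < 1e-3f);
  return 0;
}
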
@@ -486,8 +526,9 @@ const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph
     make_tuple_inputs.emplace_back(batch_matmul);
   }
 
+  auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
   // create reduce_sum_2
-  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad);
+  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node);
   make_tuple_inputs.emplace_back(db_reduce_sum);
   make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
   auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
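
Process now builds the ones value node once per DynamicRNNGrad and threads it into CreateDbReduceSum, which inserts the BatchMatMul ahead of the final ReduceSum. One thing worth double-checking: db is now inferred as float16 with shape {4 * hidden} instead of inheriting the original node's output 1 type and shape, so consumers of that tuple element must tolerate the narrower type.
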