|
|
|
@ -71,11 +71,11 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
|
|
|
|
|
{split_v_output0_shape, split_v_output1_shape}, split_v.get());
|
|
|
|
|
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits,
|
|
|
|
|
MakeValue(std::vector<int>{SizeToInt((origin_output2_shape[2] + 15) / 16),
|
|
|
|
|
SizeToInt((origin_output3_shape[1] + 15) / 16)}),
|
|
|
|
|
MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16),
|
|
|
|
|
SizeToLong((origin_output3_shape[1] + 15) / 16)}),
|
|
|
|
|
split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
|
|
|
|
|
|
|
|
|
|
basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
|
|
|
|
|
matmul_nodes.emplace_back(matmul);
|
|
|
|
@ -88,15 +88,15 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
|
|
|
|
|
|
|
|
|
|
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|
|
|
|
const std::vector<std::vector<size_t>> &split_shapes,
|
|
|
|
|
const std::vector<TypeId> &split_types, const std::vector<int> &size_split,
|
|
|
|
|
const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
|
|
|
|
|
size_t num_split_x) {
|
|
|
|
|
std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
|
|
|
|
|
input};
|
|
|
|
|
auto lstm_split = func_graph->NewCNode(lstm_split_input);
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), lstm_split);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToInt(num_split_x)), lstm_split);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), lstm_split);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(num_split_x)), lstm_split);
|
|
|
|
|
return lstm_split;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -112,7 +112,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
|
|
|
|
|
size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
|
|
|
|
|
std::vector<std::vector<size_t>> split_shapes;
|
|
|
|
|
std::vector<TypeId> split_types;
|
|
|
|
|
std::vector<int> size_split;
|
|
|
|
|
std::vector<int64_t> size_split;
|
|
|
|
|
for (size_t i = 0; i < num_split_x; ++i) {
|
|
|
|
|
split_shapes.emplace_back(split_c_dims);
|
|
|
|
|
split_types.emplace_back(kNumberTypeFloat32);
|
|
|
|
@ -238,9 +238,9 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
|
|
|
|
|
auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
|
|
|
|
|
lstm_x_concat.get());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_x_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_x_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_x_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_x_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_x_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), lstm_x_concat);
|
|
|
|
|
|
|
|
|
|
// Create lstm_gage_concat
|
|
|
|
|
auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
|
|
|
|
@ -248,8 +248,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
|
|
|
|
|
{{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
|
|
|
|
|
lstm_gage_concat.get());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_gage_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_gage_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);
|
|
|
|
|
|
|
|
|
|
outputs->emplace_back(lstm_x_concat);
|
|
|
|
@ -274,9 +274,10 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
|
|
|
|
|
std::vector<std::vector<size_t>> shapes = {shape1, shape2};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int>{SizeToInt(origin_input6_shape[0] - 1), 1}), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(SizeToLong(0)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(2)), split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int64_t>{SizeToLong(origin_input6_shape[0] - 1), 1}),
|
|
|
|
|
split_v);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
|
|
|
|
|
return split_v;
|
|
|
|
|
}
|
|
|
|
@ -315,9 +316,9 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
|
|
|
|
|
std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
|
|
|
|
|
return concat;
|
|
|
|
|
}
|
|
|
|
@ -338,9 +339,9 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
|
|
|
|
|
origin_output0_shape[2] + h_concat_output_shape[2]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
|
|
|
|
|
return concat;
|
|
|
|
|
}
|
|
|
|
@ -373,9 +374,9 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
|
|
|
|
|
std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
|
|
|
|
|
return concat;
|
|
|
|
|
}
|
|
|
|
@ -410,7 +411,7 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
|
|
|
|
|
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sum);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
|
|
|
|
|
return reduce_sum;
|
|
|
|
@ -427,7 +428,7 @@ AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)},
|
|
|
|
|
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get());
|
|
|
|
|
// Set attr
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sum);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sum);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
|
|
|
|
|
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
|
|
|
|
|
return reduce_sum;
|
|
|
|
|