modified cast32 to cast64

pull/8475/head
gaojing 4 years ago
parent 0de9d3e5b7
commit 335f3b95c1

@ -71,11 +71,11 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
{split_v_output0_shape, split_v_output1_shape}, split_v.get());
AnfAlgo::SetNodeAttr(kAttrSizeSplits,
MakeValue(std::vector<int>{SizeToInt((origin_output2_shape[2] + 15) / 16),
SizeToInt((origin_output3_shape[1] + 15) / 16)}),
MakeValue(std::vector<int64_t>{SizeToLong((origin_output2_shape[2] + 15) / 16),
SizeToLong((origin_output3_shape[1] + 15) / 16)}),
split_v);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), split_v);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(2)), split_v);
basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
matmul_nodes.emplace_back(matmul);
@ -88,15 +88,15 @@ void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const std::vector<std::vector<size_t>> &split_shapes,
const std::vector<TypeId> &split_types, const std::vector<int> &size_split,
const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
size_t num_split_x) {
std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
input};
auto lstm_split = func_graph->NewCNode(lstm_split_input);
AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), lstm_split);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToInt(num_split_x)), lstm_split);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), lstm_split);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(num_split_x)), lstm_split);
return lstm_split;
}
@ -112,7 +112,7 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
std::vector<std::vector<size_t>> split_shapes;
std::vector<TypeId> split_types;
std::vector<int> size_split;
std::vector<int64_t> size_split;
for (size_t i = 0; i < num_split_x; ++i) {
split_shapes.emplace_back(split_c_dims);
split_types.emplace_back(kNumberTypeFloat32);
@ -238,9 +238,9 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
lstm_x_concat.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_x_concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_x_concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_x_concat);
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_x_concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_x_concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), lstm_x_concat);
// Create lstm_gage_concat
auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
@ -248,8 +248,8 @@ AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
{{origin_input7_shape[0], origin_input7_shape[1], 4 * origin_input7_shape[2]}},
lstm_gage_concat.get());
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(num_split_x)), lstm_gage_concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(num_split_x)}), lstm_gage_concat);
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), lstm_gage_concat);
outputs->emplace_back(lstm_x_concat);
@ -274,9 +274,10 @@ AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
std::vector<std::vector<size_t>> shapes = {shape1, shape2};
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(0), split_v);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(2), split_v);
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int>{SizeToInt(origin_input6_shape[0] - 1), 1}), split_v);
AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(SizeToLong(0)), split_v);
AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(2)), split_v);
AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int64_t>{SizeToLong(origin_input6_shape[0] - 1), 1}),
split_v);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
return split_v;
}
@ -315,9 +316,9 @@ AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic
std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat);
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
return concat;
}
@ -338,9 +339,9 @@ AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_
origin_output0_shape[2] + h_concat_output_shape[2]};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
return concat;
}
@ -373,9 +374,9 @@ AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dy
std::vector<size_t> shape = {origin_input0_shape[0], origin_input0_shape[1], origin_input0_shape[2] + shape_tmp[2]};
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(2), concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int>{2}), concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(2), concat);
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat);
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(2)), concat);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
return concat;
}
@ -410,7 +411,7 @@ AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
return reduce_sum;
@ -427,7 +428,7 @@ AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dyn
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 1)},
{AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 1)}, reduce_sum.get());
// Set attr
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int>{0, 1}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sum);
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
return reduce_sum;

@ -28,7 +28,6 @@ do
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i || exit
cp ../../*.py .
cp ../../*.sh .
cp -r ../../src .
cp -r ../../config .
export RANK_ID=$i

Loading…
Cancel
Save