!2687 [CT][MS][Auto-Parallel] Double recursion does not support the GatherV2 operator

Merge pull request !2687 from Chong/zc
pull/2687/MERGE
mindspore-ci-bot committed via Gitee
commit 512d8e8510

@@ -164,9 +164,34 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &gr
   return strategies;
 }
-std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<std::vector<int32_t>> strategies;
-  strategies.push_back(*s);
+  int32_t axis = 0;
+  auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2));
+  if (axis_input < 0) {
+    axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
+  }
+  axis = axis_input;
+  if (axis >= SizeToInt(s.size())) {
+    MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is out of range.";
+  }
+  s[axis] = 1;
+  strategies.push_back(s);
+  auto pos = ops[iter_ops]->name().find("Info");
+  auto name = ops[iter_ops]->name().substr(0, pos);
+  if (name == "GatherV2") {
+    return strategies;
+  }
+  std::vector<int32_t> s_indices;
+  for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
+    s_indices.push_back(1);
+  }
+  strategies.push_back(s_indices);
   return strategies;
 }
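The essential behavior of the new PrepareGatherV2 is the axis handling: a negative axis attribute is wrapped around the rank of the first input, validated against the strategy length, and the gathered dimension is pinned to a parallel degree of 1, since splitting the axis being gathered along would scatter the looked-up rows across devices. A minimal standalone sketch of that logic (the plain types and the MakeGatherParamStrategy name are illustrative stand-ins for the OperatorInfo plumbing, not code from the patch):

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: normalize a possibly negative gather axis against the
// input rank, then keep the gathered dimension unsplit (parallel degree 1).
std::vector<int32_t> MakeGatherParamStrategy(std::vector<int32_t> s, int32_t axis_input) {
  if (axis_input < 0) {
    axis_input += static_cast<int32_t>(s.size());  // e.g. axis -1 on rank 2 becomes 1
  }
  if (axis_input < 0 || axis_input >= static_cast<int32_t>(s.size())) {
    throw std::out_of_range("GatherV2's axis is out of range.");
  }
  s[axis_input] = 1;  // do not partition the axis being gathered along
  return s;
}

int main() {
  // A {4, 2} strategy for a rank-2 parameter, gathering along axis 0,
  // becomes {1, 2}: the gathered axis is replicated, the other stays split.
  auto s = MakeGatherParamStrategy({4, 2}, 0);
  return (s[0] == 1 && s[1] == 2) ? 0 : 1;
}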
@@ -607,7 +632,7 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
     return PrepareBiasAdd(s_ptr);
   }
   if (ops[iter_ops]->type() == GATHERV2) {
-    return PrepareGatherV2(s_ptr);
+    return PrepareGatherV2(ops, iter_ops, basic_stra);
  }
   if (ops[iter_ops]->type() == L2_NORMALIZE) {
     return PrepareL2Normalize(ops, iter_ops, basic_stra);

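Threading ops and iter_ops through this call is what lets PrepareGatherV2 read the axis from input_value() and inspect the operator name. The name check in PrepareGatherV2 assumes OperatorInfo names follow an "<OpType>Info…" pattern, so cutting the string at "Info" recovers the operator type, which distinguishes GatherV2 from variants such as SparseGatherV2 that share the same type constant. A hypothetical illustration (the OpTypeFromInfoName helper and the sample names are assumptions, not code from the patch):

#include <iostream>
#include <string>

// Hypothetical helper mirroring the name check in PrepareGatherV2 above.
std::string OpTypeFromInfoName(const std::string &info_name) {
  auto pos = info_name.find("Info");  // std::string::npos if "Info" is absent
  return info_name.substr(0, pos);    // substr(0, npos) yields the whole string
}

int main() {
  std::cout << OpTypeFromInfoName("GatherV2Info7") << '\n';        // GatherV2
  std::cout << OpTypeFromInfoName("SparseGatherV2Info3") << '\n';  // SparseGatherV2
  return 0;
}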
@@ -38,7 +38,8 @@ std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vect
 std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,

@@ -40,7 +40,7 @@ const TensorParam MakeTensor(int n, int c, int h, int w) {
   return tensor;
 }
-Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops) {
+Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
   Graph::NodeType NewOp;
   NewOp.name = ops[iter_ops]->name();
   NewOp.info = InfoType::kApplication;
@@ -140,7 +140,7 @@ std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo
   return graph;
 }
-void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph) {
+void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
   for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
     for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
       size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);

@@ -111,7 +111,7 @@ const std::map<std::string, OperatorType> DictOpType{
 const TensorParam MakeTensor(int n, int c, int h, int w);
-Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops);
+Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops);
 OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    Graph::NodeType NewTensor);
@@ -122,7 +122,7 @@ TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &o
 std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                   const std::vector<std::vector<std::string>> &input_tensor_names);
-void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph);
+void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph);
 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const std::string &input_name);

@@ -93,7 +93,7 @@ double GetWeights(const Graph::NodeType &node) {
 }
 // Sort all the nodes by their weights
-std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<std::pair<double, size_t>> weight_to_node_index;
@@ -124,7 +124,7 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
 // Get optimal strategy to partition the target node
 StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                          std::shared_ptr<Graph> graph) {
+                          const std::shared_ptr<Graph> &graph) {
   bool enable_conv_chw_partition = false;
   MS_EXCEPTION_IF_NULL(graph);
@@ -191,7 +191,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
 }
 // Partition graph into all devices.
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
+Status PartitionForAllDevices(const size_t num_device, const double device_memory,
+                              const std::shared_ptr<Graph> &graph) {
   if (num_device < 1) {
     MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
   }
@@ -261,7 +262,7 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
   return Node;
 }
-Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   if (num_device == 0) {
     MS_LOG(EXCEPTION) << "Failure: device number is 0.";

@@ -32,19 +32,19 @@
 namespace mindspore {
 namespace parallel {
-std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph);
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph);
 double GetWeights(const Graph::NodeType &node);
 StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                          std::shared_ptr<Graph> graph);
+                          const std::shared_ptr<Graph> &graph);
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
-Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
 size_t GetDataTypeSize(const TensorType &type);
 }  // namespace parallel

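Beyond the GatherV2 fix itself, the recurring change across rec_parse_graph and rec_partition is mechanical: std::shared_ptr<Graph> parameters taken by value are now taken by const reference, so each call no longer pays an atomic reference-count increment and decrement while the caller's pointer already keeps the graph alive. A minimal sketch of the difference (the Graph stub and function names here are illustrative, not from the patch):

#include <cassert>
#include <memory>

struct Graph {};  // stand-in for the real parallel::Graph

// By value: copying the shared_ptr bumps the atomic use count on every call.
void ByValue(std::shared_ptr<Graph> g) { assert(g.use_count() >= 2); }

// By const reference: no copy and no reference-count traffic; the callee can
// still read the pointee, and the caller's pointer guarantees its lifetime.
void ByConstRef(const std::shared_ptr<Graph> &g) { assert(g.use_count() == 1); }

int main() {
  auto graph = std::make_shared<Graph>();
  ByValue(graph);     // use count briefly reaches 2
  ByConstRef(graph);  // use count stays at 1
  return 0;
}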