@@ -20,6 +20,7 @@
 #include <numeric>
 #include <functional>
 #include <utility>
+#include <algorithm>
 
 #include "parallel/device_matrix.h"
 #include "parallel/graph_util/generate_graph.h"
@@ -62,6 +63,55 @@ Status GatherV2PInfo::GetAttrs() {
     return FAILED;
   }
 
+  auto manual_split_iter = attrs_.find("manual_split");
+  if (manual_split_iter != attrs_.end()) {
+    param_split_shapes_.clear();
+    manual_split_ = true;
+    auto var = manual_split_iter->second->cast<ValueTuplePtr>();
+    MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
+
+    if (var->size() > 0) {
+      std::vector<ValuePtr> elements = var->value();
+      for (auto &ele : elements) {
+        if (ele->isa<ValueSequeue>()) {
+          auto value_tuple = ele->cast<ValueTuplePtr>();
+          std::vector<ValuePtr> value_vector = value_tuple->value();
+          if (value_vector.size() != 2) {
+            MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
+            return FAILED;
+          }
+          param_split_shapes_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[0])));
+          index_offsets_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[1])));
+        } else {
+          MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
+          return FAILED;
+        }
+      }
+
+      if (param_split_shapes_.empty()) {
+        MS_LOG(ERROR) << "Failed to extract param split strategy.";
+        return FAILED;
+      }
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status GatherV2PInfo::CheckManualSplit() {
+  auto param_shape = inputs_shape_.at(0);
+  int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
+                                            [](int32_t s, int32_t shape) { return s + shape; });
+  if (split_shape_sum < param_shape.at(0)) {
+    MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
+    return FAILED;
+  }
+
+  if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) {
+    MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
+    return FAILED;
+  }
+
   return SUCCESS;
 }
 
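For orientation: the manual_split value parsed above is a tuple of (split_size, index_offset) pairs, one per slice of the parameter's first dimension. The standalone sketch below uses hypothetical names and plain standard C++ rather than the framework's value types; it only mirrors what GetAttrs() extracts and what CheckManualSplit() verifies, namely that the slice sizes cover at least the whole first dimension and that every offset is non-negative.

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <utility>
    #include <vector>

    // Hypothetical helper: validate a manual_split value such as ((4, 0), (4, 4), (8, 8)).
    bool ValidateManualSplit(const std::vector<std::pair<int32_t, int32_t>> &manual_split, int32_t param_dim0) {
      std::vector<int32_t> split_shapes;  // mirrors param_split_shapes_
      std::vector<int32_t> offsets;       // mirrors index_offsets_
      for (const auto &ele : manual_split) {
        split_shapes.push_back(ele.first);
        offsets.push_back(ele.second);
      }
      // Sum of the slice sizes must not be smaller than the parameter's first dimension.
      int32_t sum = std::accumulate(split_shapes.begin(), split_shapes.end(), 0);
      if (sum < param_dim0) {
        return false;
      }
      // Every index offset must be non-negative.
      return std::none_of(offsets.begin(), offsets.end(), [](int32_t offset) { return offset < 0; });
    }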
@@ -103,6 +153,14 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
+  if (manual_split_) {
+    if (CheckManualSplit() != SUCCESS) {
+      return FAILED;
+    }
+    // when using manual_split, no need to check belowings.
+    return SUCCESS;
+  }
+
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
     MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
@@ -130,6 +188,11 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 }
 
 Status GatherV2PInfo::InferMirrorOps() {
+  // There is no mirror operators for manual split
+  if (manual_split_) {
+    return SUCCESS;
+  }
+
   mirror_ops_.clear();
   Shape input_a_tensor_map = inputs_tensor_map_.at(0);
   std::vector<Group> input_a_group;
@@ -160,6 +223,13 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   // infer input dev_matrix_shape
   auto param_strategy = strategy_->GetInputDim().at(0);
   auto index_strategy = strategy_->GetInputDim().at(1);
+
+  if (manual_split_) {
+    dev_matrix_shape_ = param_strategy;
+    out_dev_matrix_shape_ = dev_matrix_shape_;
+    return SUCCESS;
+  }
+
   dev_matrix_shape_ = param_strategy;
 
   // param_strategy(axis)!=1,
@@ -195,6 +265,12 @@ Status GatherV2PInfo::InferDevMatrixShape() {
 }
 
 Status GatherV2PInfo::InferTensorMap() {
+  if (manual_split_) {
+    inputs_tensor_map_.push_back({1, 0});
+    inputs_tensor_map_.push_back({-1, 1});
+    outputs_tensor_map_.push_back({-1, 1, 0});
+    return SUCCESS;
+  }
   // infer input tensor map
   // param_strategy(axis) != 1
   size_t param_size = inputs_shape_.at(0).size();
@@ -261,8 +337,13 @@ Status GatherV2PInfo::InferTensorInfo() {
   Shape input_shape = inputs_shape_.at(0);
   Shape input_index_shape = inputs_shape_.at(1);
   Shape output_shape = outputs_shape_.at(0);
+  int32_t rank = g_device_manager->global_rank();
   // infer tensor layout
   TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
+  if (manual_split_) {
+    input_shape[0] = param_split_shapes_[rank / dev_matrix_shape_[1]];
+    input_shape[0] = input_shape[0] * dev_matrix_shape_[0];
+  }
   if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
       (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
       (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) !=
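The two assignments to input_shape[0] above build a "fake" full shape so that the generic layout machinery ends up reporting exactly this rank's manual slice size. The numbers in the worked example below are assumed for illustration, not taken from the patch.

    // Assumed setup: 8 devices, dev_matrix_shape_ = {4, 2} (four manual slices, the other
    // parameter dimension split by 2), param_split_shapes_ = {5, 5, 5, 17}.
    // Rank 3 belongs to slice group 3 / dev_matrix_shape_[1] = 1, so:
    //   input_shape[0] = param_split_shapes_[1] * dev_matrix_shape_[0] = 5 * 4 = 20
    // When the layout is later built from dev_matrix_shape_ and tensor map {1, 0}, dimension 0
    // is divided by 4 again, giving a slice shape of 20 / 4 = 5, i.e. exactly the manually
    // specified slice size (hence the "fake slice shape" debug message added below).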
@@ -274,6 +355,9 @@ Status GatherV2PInfo::InferTensorInfo() {
   TensorInfo input_index_info(input_index_layout);
   TensorInfo output_tensor_info(output_tensor_layout);
 
+  Shape slice_shape = input_tensor_info.slice_shape();
+  MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
+
   inputs_tensor_info_.push_back(input_tensor_info);
   inputs_tensor_info_.push_back(input_index_info);
   outputs_tensor_info_.push_back(output_tensor_info);
@@ -312,6 +396,19 @@ Status GatherV2PInfo::InferBias() {
   return FAILED;
 }
 
+Status GatherV2PInfo::InferOffset() {
+  CheckGlobalDeviceManager();
+  size_t rank = g_device_manager->global_rank();
+  if (rank < index_offsets_.size()) {
+    index_offset_ = index_offsets_.at(rank);
+    MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
+    return SUCCESS;
+  }
+
+  MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size();
+  return FAILED;
+}
+
 Status GatherV2PInfo::InferGroup() {
   auto param_strategy = strategy_->GetInputDim().at(0);
   size_t dim = IntToSize(axis_);
@@ -410,6 +507,19 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
     MS_LOG(ERROR) << "GenerateGraph Init failed";
     return FAILED;
   }
+  if (manual_split_) {
+    if (InferOffset() != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
+      return FAILED;
+    }
+    auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
+    auto gather_v2 =
+      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub, CreatInt32Imm(axis_)});
+    std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
+    replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
+      std::make_pair(input_nodes, gather_v2));
+    return SUCCESS;
+  }
   if (InferBias() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
     return FAILED;
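The manual_split branch above rewrites the node into a Sub followed by the gather op: the incoming global indices are shifted by this rank's index_offset_ so that they address the local parameter slice, and the gather then runs on that slice. Below is a standalone sketch of that per-rank computation in plain C++ with hypothetical names, not the generated graph itself; it assumes every index routed to this rank falls inside its slice.

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for the replaced sub-graph:
    //   out = GatherV2(local_param, global_indices - index_offset, axis = 0)
    std::vector<std::vector<float>> LocalManualSplitGather(const std::vector<std::vector<float>> &local_param_rows,
                                                           const std::vector<int32_t> &global_indices,
                                                           int32_t index_offset) {
      std::vector<std::vector<float>> out;
      out.reserve(global_indices.size());
      for (int32_t idx : global_indices) {
        // Sub node: translate the global row index into this rank's local index.
        int32_t local = idx - index_offset;
        // GatherV2 node: pick the row from the local parameter slice.
        out.push_back(local_param_rows.at(static_cast<size_t>(local)));
      }
      return out;
    }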
@@ -444,6 +554,14 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 }
 
 ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
+  if (manual_split_) {
+    if (ComputeReplaceGraph(cnode) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
+      return nullptr;
+    }
+    return replace_graph_;
+  }
+
   auto param_strategy = strategy_->GetInputDim().at(0);
   // target_ == CPU, no need to raplace graph
   if (target_ == CPU) {