|
|
|
@ -57,6 +57,15 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// param slice shape need 32Byte aligned
|
|
|
|
|
|
|
|
auto param_shape = inputs_shape_.at(0);
|
|
|
|
|
|
|
|
auto param_strategy = strategy->GetInputDim().at(0);
|
|
|
|
|
|
|
|
auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
|
|
|
|
|
|
|
|
if (slice_shape % 8 != 0) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// only support 1-dim and 2-dim param
|
|
|
|
// only support 1-dim and 2-dim param
|
|
|
|
if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
|
|
|
|
if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
|
|
|
|
MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
|
|
|
|
MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
|
|
|
|
@ -71,14 +80,12 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
|
|
|
|
|
|
|
// axis=0, index_shape(0)%param_strategy(0) must be 0
|
|
|
|
// axis=0, index_shape(0)%param_strategy(0) must be 0
|
|
|
|
Shape index_shape = inputs_shape_.at(1);
|
|
|
|
Shape index_shape = inputs_shape_.at(1);
|
|
|
|
auto param_strategy = strategy->GetInputDim().at(0);
|
|
|
|
|
|
|
|
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
|
|
|
|
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
|
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
|
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
|
|
|
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
|
|
|
|
Shape param_shape = inputs_shape_.at(0);
|
|
|
|
|
|
|
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
|
|
|
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
|
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
|
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
@ -158,12 +165,12 @@ Status GatherV2PInfo::InferDevMatrixShape() {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
out_dev_matrix_shape_ = dev_matrix_shape_;
|
|
|
|
out_dev_matrix_shape_ = dev_matrix_shape_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto product_out =
|
|
|
|
|
|
|
|
std::accumulate(out_dev_matrix_shape_.begin(), out_dev_matrix_shape_.end(), 1, std::multiplies<int>());
|
|
|
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
|
|
|
if (product_out == 1) {
|
|
|
|
auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
|
|
|
|
out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), dev_num);
|
|
|
|
auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
|
|
|
|
|
|
|
|
if (param_product * index_product < SizeToInt(dev_num)) {
|
|
|
|
|
|
|
|
out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
return SUCCESS;
|
|
|
|
@ -174,7 +181,7 @@ Status GatherV2PInfo::InferTensorMap() {
|
|
|
|
// param_strategy(axis) != 1
|
|
|
|
// param_strategy(axis) != 1
|
|
|
|
size_t param_size = inputs_shape_.at(0).size();
|
|
|
|
size_t param_size = inputs_shape_.at(0).size();
|
|
|
|
size_t index_size = inputs_shape_.at(1).size();
|
|
|
|
size_t index_size = inputs_shape_.at(1).size();
|
|
|
|
size_t total_size = dev_matrix_shape_.size();
|
|
|
|
size_t total_size = param_size + index_size;
|
|
|
|
std::vector<int32_t> tensor_map_index;
|
|
|
|
std::vector<int32_t> tensor_map_index;
|
|
|
|
std::vector<int32_t> tensor_map_params;
|
|
|
|
std::vector<int32_t> tensor_map_params;
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|