diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index 81f276182f..e2d01fb779 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -57,6 +57,15 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
+  // param slice shape need 32Byte aligned
+  auto param_shape = inputs_shape_.at(0);
+  auto param_strategy = strategy->GetInputDim().at(0);
+  auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
+  if (slice_shape % 8 != 0) {
+    MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
+    return FAILED;
+  }
+
   // only support 1-dim and 2-dim param
   if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
     MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
@@ -71,14 +80,12 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
-  auto param_strategy = strategy->GetInputDim().at(0);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
     MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
     return FAILED;
   }
 
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
-  Shape param_shape = inputs_shape_.at(0);
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
     MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
@@ -158,12 +165,12 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   } else {
     out_dev_matrix_shape_ = dev_matrix_shape_;
   }
-  auto product_out =
-    std::accumulate(out_dev_matrix_shape_.begin(), out_dev_matrix_shape_.end(), 1, std::multiplies<int>());
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  if (product_out == 1) {
-    out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), dev_num);
+  auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
+  auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
+  if (param_product * index_product < SizeToInt(dev_num)) {
+    out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product)));
   }
 
   return SUCCESS;
@@ -174,7 +181,7 @@ Status GatherV2PInfo::InferTensorMap() {
   // param_strategy(axis) != 1
   size_t param_size = inputs_shape_.at(0).size();
   size_t index_size = inputs_shape_.at(1).size();
-  size_t total_size = dev_matrix_shape_.size();
+  size_t total_size = param_size + index_size;
   std::vector<int32_t> tensor_map_index;
  std::vector<int32_t> tensor_map_params;
   auto param_strategy = strategy_->GetInputDim().at(0);
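Reviewer note: a minimal Python sketch of the two checks added above, for readers tracing the arithmetic. Assuming float32 (4-byte) elements, a 32-byte-aligned last dimension means each slice must hold a multiple of 8 elements, which is where the `slice_shape % 8` test comes from. The helper names below are illustrative, not MindSpore APIs.

```python
# Illustrative sketch (not MindSpore code) of the checks added in
# GatherV2PInfo::CheckStrategy and InferDevMatrixShape above.
from functools import reduce
from operator import mul


def last_dim_aligned(param_shape, param_strategy):
    # Length of the last dim of each slice after splitting by the strategy.
    slice_shape = param_shape[-1] // param_strategy[-1]
    # Assumption: float32 (4 B) elements, so 8 elements == 32 bytes.
    return slice_shape % 8 == 0


def pad_out_dev_matrix(out_dev_matrix, param_strategy, index_strategy, dev_num):
    # Mirrors the new padding logic: if the strategies use fewer devices
    # than dev_num, prepend the leftover factor to the device matrix.
    param_product = reduce(mul, param_strategy, 1)
    index_product = reduce(mul, index_strategy, 1)
    if param_product * index_product < dev_num:
        out_dev_matrix = [dev_num // (param_product * index_product)] + out_dev_matrix
    return out_dev_matrix


print(last_dim_aligned([64, 32], [1, 8]))           # False: slice of 4 elements (16 B)
print(last_dim_aligned([64, 64], [1, 8]))           # True:  slice of 8 elements (32 B)
print(pad_out_dev_matrix([2, 4], [2, 4], [1], 16))  # [2, 2, 4]
```

This also retires the old `product_out == 1` special case: any shortfall between the strategy products and `dev_num` is now folded into a leading dimension, not just the fully-unsplit case.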
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index 417c3ca45c..2720cb33e1 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -67,8 +67,8 @@ def test_gatherv2_semi_auto0():
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
 
-    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
 
 
@@ -79,8 +79,8 @@ def test_gatherv2_semi_auto1():
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
 
-    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
 
 
@@ -91,8 +91,8 @@ def test_gatherv2_semi_auto2():
     net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
     net.set_auto_parallel()
 
-    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
-    y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
 
 
@@ -103,7 +103,7 @@ def test_gatherv2_semi_auto3():
     net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
     net.set_auto_parallel()
 
-    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
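Reviewer note: the test shape changes follow directly from the new alignment check. With float32 data, a last dimension of 32 fails once a strategy splits it 8 ways (slice of 4 elements = 16 bytes), while 64 stays aligned for every split up to 8. The split-by-8 case is an assumption here: `strategy1`/`strategy2` are defined earlier in the test file and not shown in this diff.

```python
# Why the test inputs moved from a last dim of 32 to 64 (float32 assumed):
for last_dim in (32, 64):
    for split in (1, 2, 4, 8):
        slice_len = last_dim // split
        ok = slice_len % 8 == 0
        print(f"last_dim={last_dim} split={split}: "
              f"slice={slice_len} ({slice_len * 4} B) aligned={ok}")
# Only last_dim=32 with split=8 fails (slice of 4 elements = 16 B), so
# widening the inputs to a last dim of 64 keeps every tested strategy valid.
```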