diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc index 6d7c7fe350..b27e33f284 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -488,10 +488,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector RedistributionLayoutTransfer::UnifyDevice if (unified_device_arrangement_ptr == nullptr) { return nullptr; } + Shape in_expand_shape; + Status status = ExpandShape(unified_device_arrangement_ptr->from_in().tensor_shape().array(), + unified_device_arrangement_ptr->to_in().tensor_shape().array(), &in_expand_shape); + if (status != Status::SUCCESS) { + MS_LOG(INFO) << "The shape of from and to cannot transfer by unify"; + unified_device_arrangement_ptr->SetExpandAble(false); + return unified_device_arrangement_ptr; + } return unified_device_arrangement_ptr->UnifyDeviceArrangementAndTensorShape(); } } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h index fc9583f38a..20b571059c 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h @@ -35,12 +35,15 @@ class ReshapeLayoutTransfer : public LayoutTransfer { std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( const Arrangement &expand_shape) const; std::shared_ptr ExchangeFromAndTo() const; + bool ExpandAble() const { return is_expand_able_; } + bool FromTensorShapeCanBeExpandByTo() const; + bool ToTensorShapeCanBeExpandByFrom() const; + void SetExpandAble(const bool is_expand_able) { is_expand_able_ = is_expand_able; } private: Status CheckValidTransfer() override; std::shared_ptr ComputeExpandedFromTensorShapeByTo() const; - bool FromTensorShapeCanBeExpandByTo() const; - bool ToTensorShapeCanBeExpandByFrom() const; + bool is_expand_able_ = true; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc index 9828558508..760b748883 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc @@ -97,11 +97,11 @@ Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape * int64_t value = 1; for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) { if (*iter == 0) { - MS_LOG(ERROR) << "element of shape_accum should not be zero"; + MS_LOG(WARNING) << "element of shape_accum should not be zero"; return Status::FAILED; } if ((*iter) % value != 0) { - MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; + MS_LOG(WARNING) << "shape_accum is not a accumulate product in ascending order"; return Status::FAILED; } (void)shape->insert(shape->begin(), static_cast((*iter) / value)); diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc index 135cbc37ab..daea00cd82 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc @@ -390,6 +390,15 @@ TensorLayout TensorLayout::SqueezeShape() const { return out; } +TensorLayout TensorLayout::TransferRepeatLayout() const { + Shape dev_mat(device_arrangement_.array()); + Shape tensor_map(tensor_map_.GetDimSize(), -1); + Shape tensor_shape(tensor_shape_.array()); + TensorLayout repeat; + repeat.InitFromVector(dev_mat, tensor_map, tensor_shape); + return repeat; +} + // Generate a totally shard tensor slice shape for parallel optimizer Status TensorLayout::GenerateOptShardSliceShape() { MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString(); diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h index c412ed93d0..8b0e9a662b 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -88,6 +88,8 @@ class TensorLayout { TensorLayout SqueezeShape() const; + TensorLayout TransferRepeatLayout() const; + Status GenerateOptShardSliceShape(); Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; } diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc index 79991d1f09..cd2e248ce5 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc @@ -39,6 +39,42 @@ Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout & return Status::SUCCESS; } +RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListUnExpand(bool is_cost_model) { + TensorLayout from_repeat = from_origin_.TransferRepeatLayout(); + TensorLayout to_repeat = to_origin_.TransferRepeatLayout(); + MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString(); + MS_LOG(DEBUG) << "reshape to_layout " << to_repeat.ToString(); + MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString(); + MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString(); + MS_LOG(DEBUG) << "reshape from_ " << from_.ToString(); + MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); + OperatorVector operator_vector; + OutPutInfoVector output_info_vector; + if (InferRedistribution(from_origin_, from_repeat, &operator_vector, &output_info_vector, is_cost_model) == + Status::FAILED) { + return nullptr; + } + if (from_repeat.slice_shape().array() != to_repeat.slice_shape().array()) { + reshape_flag_ = true; + ConstructOperator constructor; + constructor.UpdateTensorShape(from_repeat.slice_shape().array()); + Arrangement shape = to_repeat.slice_shape(); + MS_LOG(DEBUG) << "reshape " << shape.ToString(); + if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { + return nullptr; + } else { + (void)operator_vector.push_back(constructor.GetOperator()); + (void)output_info_vector.push_back(std::make_pair(false, 0)); + } + } + if (InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model) == + Status::FAILED) { + return nullptr; + } + return std::make_shared>( + std::make_pair(operator_vector, output_info_vector)); +} + RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { // Step 1: Match device arrangement between from_ and to_ RedistributionLayoutTransfer layout_transfer; @@ -51,6 +87,10 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL MS_LOG(ERROR) << "Infer tensor layout return nullptr!"; return nullptr; } + if (!ptr->ExpandAble()) { + expand_able_ = false; + return InferTensorRedistributionOperatorListUnExpand(is_cost_model); + } TensorLayout from_layout = ptr->from_in(); TensorLayout to_layout = ptr->to_in(); MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString(); @@ -61,27 +101,17 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); // Step 2: Infer redistribution and insert operators RedistributionOperatorInfer operator_infer(construct_op_flag_); - if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { - MS_LOG(ERROR) << "Init operatorInfer failed!"; - return nullptr; - } OperatorVector operator_vector; OutPutInfoVector output_info_vector; - if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) { - MS_LOG(ERROR) << "Infer redistribution failed!"; + if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) != + Status::SUCCESS) { return nullptr; - } else { - operator_vector = operator_infer.operator_vector(); - output_info_vector = operator_infer.output_info_vector(); - operator_list_ = operator_infer.operator_list(); } - // Step 3: Infer reshape and insert operators if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) { MS_LOG(ERROR) << "Construct Reshape operator failed!"; return nullptr; } - return std::make_shared>( std::make_pair(operator_vector, output_info_vector)); } @@ -136,6 +166,31 @@ Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const return Status::SUCCESS; } +Status TensorRedistribution::InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, + OutPutInfoVector *const output_info_vector, bool is_cost_model) { + RedistributionOperatorInfer operator_infer(construct_op_flag_); + if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { + MS_LOG(ERROR) << "Init operatorInfer failed"; + return Status::FAILED; + } + if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) { + MS_LOG(ERROR) << "Infer redistribution failed"; + return Status::FAILED; + } else { + for (auto op : operator_infer.operator_vector()) { + operator_vector->insert(operator_vector->end(), op); + } + for (auto info : operator_infer.output_info_vector()) { + output_info_vector->insert(output_info_vector->end(), info); + } + for (auto opc : operator_infer.operator_list()) { + operator_list_.insert(operator_list_.end(), opc); + } + } + return Status::SUCCESS; +} + Status TensorRedistribution::ComputeCost() { RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); if (redistribution_oplist_ptr == nullptr) { @@ -162,8 +217,13 @@ Status TensorRedistribution::ComputeCost() { } } if (reshape_flag()) { - Shape prev_slice_shape = from_.slice_shape().array(); - double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies()); + Shape prev_shape; + if (expand_able_) { + prev_shape = from_.slice_shape().array(); + } else { + prev_shape = from_.tensor_shape().array(); + } + double prev_prod = std::accumulate(prev_shape.begin(), prev_shape.end(), 1, std::multiplies()); computation_cost_ += 2.0 * prev_prod; memory_cost_ += 2.0 * prev_prod; } diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h index 2509e28553..1a7735611a 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h @@ -61,8 +61,12 @@ class TensorRedistribution { private: Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector); + Status InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector, + bool is_cost_model); Status ComputeConcatCost(double input_size, Shape attrs); Status ComputePermuteCost(double input_size, Shape attrs); + RedistributionOpListPtr InferTensorRedistributionOperatorListUnExpand(bool is_cost_model = false); TensorLayout from_origin_; TensorLayout to_origin_; TensorLayout from_; @@ -84,6 +88,7 @@ class TensorRedistribution { double memory_cost_; bool construct_op_flag_; bool keep_reshape_; + bool expand_able_ = true; }; } // namespace parallel } // namespace mindspore diff --git a/tests/ut/python/parallel/test_reshape_unexpand.py b/tests/ut/python/parallel/test_reshape_unexpand.py new file mode 100644 index 0000000000..aed4db905d --- /dev/null +++ b/tests/ut/python/parallel/test_reshape_unexpand.py @@ -0,0 +1,206 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.common.parameter import Parameter +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x): + predict = self.network(x) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x): + return grad_all(self.network)(x) + +def test_reshape_unexpand(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.mul = P.Mul().shard(((1, 8), (1, 1, 8))) + self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight") + + def construct(self, x): + weight = self.reshape(self.mul_weight, (1, 128, 96)) + out = self.mul(x, weight) + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([128, 96]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x) + +def test_reshape_unexpand_1(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.mul = P.Mul().shard(((1, 8), (1, 1, 8))) + self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight") + + def construct(self, x): + weight = self.reshape(self.mul_weight, (1, 128, 96)) + out = self.mul(x, weight) + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([128, 96]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x) + +def test_reshape_unexpand_2(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.mul = P.Mul().shard(((1, 4, 2), (4, 2))) + self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight") + + def construct(self, data): + x = self.reshape(self.mul_weight, (1, 128, 96)) + out = self.mul(x, self.mul_weight) + return out + + size = 8 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([128, 96]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x) + +def test_reshape_unexpand_3(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.relu1 = P.ReLU().shard(((4, 1),)) + self.relu2 = P.ReLU().shard(((1, 4),)) + + def construct(self, data): + x = self.relu1(data) + x = self.reshape(x, (3, 4)) + x = self.relu2(x) + return x + + size = 4 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([4, 3]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x) + +def test_reshape_unexpand_4(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.relu1 = P.ReLU().shard(((4, 1),)) + self.relu2 = P.ReLU().shard(((1, 2, 2),)) + + def construct(self, data): + x = self.relu1(data) + x = self.reshape(x, (3, 2, 2)) + x = self.relu2(x) + return x + + size = 4 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([4, 3]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x) + +def test_reshape_unexpand_5(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.relu1 = P.ReLU().shard(((2, 2, 1),)) + self.relu2 = P.ReLU().shard(((1, 4),)) + + def construct(self, data): + x = self.relu1(data) + x = self.reshape(x, (3, 4)) + x = self.relu2(x) + return x + + size = 4 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([2, 2, 3]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x) + +def test_reshape_unexpand_6(): + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.relu1 = P.ReLU().shard(((2, 1),)) + self.relu2 = P.ReLU().shard(((1, 1, 4),)) + + def construct(self, data): + x = self.relu1(data) + x = self.reshape(x, (1, 3, 4)) + x = self.relu2(x) + return x + + size = 4 + context.set_auto_parallel_context(device_num=size, global_rank=0) + x = Tensor(np.ones([4, 3]), dtype=ms.float32) + + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net.set_auto_parallel() + _executor.compile(net, x)