From 022005b94afb953ea3daf6cb51813143c6f35ef1 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Thu, 15 Oct 2020 16:31:32 +0800
Subject: [PATCH] fix bug cases in reshape redistribution

---
 .../frontend/parallel/step_auto_parallel.cc        | 16 ++++++++--
 .../tensor_layout/reshape_layout_transfer.cc       |  7 ++--
 .../parallel/tensor_layout/tensor_layout.cc        | 19 ++++++++---
 .../parallel/tensor_layout/tensor_layout.h         | 11 +++++--
 .../tensor_layout/tensor_redistribution.cc         |  2 +-
 .../python/parallel/test_reshape_unexpand.py       | 32 +++++++++++++++++++
 6 files changed, 73 insertions(+), 14 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 370ace402e..8c2b37b327 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -770,7 +770,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-bool FindReshape(const CNodePtr &cnode) {
+bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
     return false;
   }
@@ -780,7 +780,16 @@
   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
   MS_EXCEPTION_IF_NULL(prim);
-  return (prim->name() == RESHAPE);
+  if (prim->name() == RESHAPE) {
+    auto operator_info = cnode->user_data<OperatorInfo>();
+    std::string op_info_name = operator_info->name();
+    if (op_cache->find(op_info_name) != op_cache->end()) {
+      return false;
+    }
+    op_cache->insert(op_info_name);
+    return true;
+  }
+  return false;
 }
 
 // find previous node, then obtain its strategy_cost_ vector to get its layout vector.
@@ -871,9 +880,10 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
 }
 
 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
+  std::unordered_set<std::string> op_cache;
   for (auto node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if (!FindReshape(cnode)) {
+    if (!FindReshape(cnode, &op_cache)) {
       continue;
     }
     MS_ASSERT(cnode->inputs().size() == 3);
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc
index 98f7cf78fa..60cfffc45c 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc
@@ -36,11 +36,14 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrange
   while (!is_unified) {
     std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
     if (temp_layout_ptr == nullptr) {
-      return nullptr;
+      out_layout_ptr->SetExpandAble(false);
+      return out_layout_ptr;
     }
     out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
     if (out_layout_ptr == nullptr) {
-      return nullptr;
+      std::shared_ptr<ReshapeLayoutTransfer> layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
+      layout_ptr->SetExpandAble(false);
+      return layout_ptr;
     }
     is_unified = out_layout_ptr->IsSameTensorShape();
   }
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc
index daea00cd82..16fafbd113 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc
@@ -58,7 +58,11 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens
     MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
     return Status::SUCCESS;
   } else {
-    MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "invalid origin tensor layout " << this->OriginToString();
+    } else {
+      MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    }
     return Status::FAILED;
   }
 }
@@ -90,7 +94,11 @@ bool TensorLayout::IsValidTensorLayout() const {
     return false;
   }
   if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
-    MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    } else {
+      MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    }
     return false;
   }
   return true;
@@ -214,6 +222,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDevice
     return nullptr;
   }
   TensorLayout tensor_layout_new;
+  tensor_layout_new.set_layout_transfer(true);
   Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
   if (status != Status::SUCCESS) {
     return nullptr;
@@ -391,9 +400,9 @@ TensorLayout TensorLayout::SqueezeShape() const {
 }
 
 TensorLayout TensorLayout::TransferRepeatLayout() const {
-  Shape dev_mat(device_arrangement_.array());
-  Shape tensor_map(tensor_map_.GetDimSize(), -1);
-  Shape tensor_shape(tensor_shape_.array());
+  Shape dev_mat(device_arrangement_origin_.array());
+  Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
+  Shape tensor_shape(tensor_shape_origin_.array());
   TensorLayout repeat;
   repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
   return repeat;
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
index 8b0e9a662b..fefeae3e06 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h
@@ -46,6 +46,10 @@ class TensorLayout {
 
   void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
 
+  bool layout_transfer() const { return layout_transfer_; }
+
+  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }
+
   int32_t get_field_size() const { return field_size_; }
 
   void set_field_size(int32_t field_size) { field_size_ = field_size; }
@@ -113,14 +117,15 @@ class TensorLayout {
   int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;
 
   Arrangement device_arrangement_origin_;
-  Map tensor_map_origin_;
   Arrangement tensor_shape_origin_;
   Arrangement device_arrangement_;
-  Map tensor_map_;
   Arrangement tensor_shape_;
+  Map tensor_map_;
+  Map tensor_map_origin_;
   bool skip_redistribution_ = false;
-  int32_t field_size_ = 0;
   bool uniform_split_ = true;
+  bool layout_transfer_ = false;
+  int32_t field_size_ = 0;
   Shape opt_shard_slice_shape_;
   std::string opt_shard_group_ = "";
 };
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc
index cd2e248ce5..4e45775123 100644
--- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc
@@ -43,7 +43,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
   TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
   TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
   MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
-  MS_LOG(DEBUG) << "reshape to_layout " << to_repeat.ToString();
+  MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
   MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
   MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
   MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
diff --git a/tests/ut/python/parallel/test_reshape_unexpand.py b/tests/ut/python/parallel/test_reshape_unexpand.py
index aed4db905d..792b982b65 100644
--- a/tests/ut/python/parallel/test_reshape_unexpand.py
+++ b/tests/ut/python/parallel/test_reshape_unexpand.py
@@ -204,3 +204,35 @@ def test_reshape_unexpand_6():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x)
+
+def test_reshape_unexpand_7():
+    class Net(nn.Cell):
+        def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
+                     mul_size=(32, 1, 220, 220)):
+            super().__init__()
+            mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+            self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+            self.mul = P.Mul()
+            self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
+                                  kernel_size=5, has_bias=True, weight_init='ones',
+                                  bias_init='ones', pad_mode='valid')
+            self.softmax = nn.Softmax(axis=axis)
+            self.relu = nn.ReLU()
+            self.reshape = P.Reshape()
+            self.input_shape = input_shape
+
+        def construct(self, inputs):
+            x = self.conv(inputs)
+            x = self.softmax(x)
+            x = self.relu(x)
+            x = self.mul(x, self.mul_weight)
+            x = self.reshape(x, self.input_shape)
+            return x
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    net.set_auto_parallel()
+    _executor.compile(net, x)