!7345 fix a bug case in reshape redistribution

Merge pull request !7345 from yao_yf/reshape_redistribution_all_scene_support_add (pull/7345/MERGE)
commit a5e8c1eab3, committed by mindspore-ci-bot via Gitee, 4 years ago

@@ -770,7 +770,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-bool FindReshape(const CNodePtr &cnode) {
+bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
     return false;
   }
@@ -780,7 +780,16 @@ bool FindReshape(const CNodePtr &cnode) {
   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
   MS_EXCEPTION_IF_NULL(prim);
-  return (prim->name() == RESHAPE);
+  if (prim->name() == RESHAPE) {
+    auto operator_info = cnode->user_data<OperatorInfo>();
+    std::string op_info_name = operator_info->name();
+    if (op_cache->find(op_info_name) != op_cache->end()) {
+      return false;
+    }
+    op_cache->insert(op_info_name);
+    return true;
+  }
+  return false;
 }
 // find previous node, then obtain its strategy_cost_ vector to get its layout vector.
@@ -871,9 +880,10 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
 }
 
 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
+  std::unordered_set<std::string> op_cache;
   for (auto node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if (!FindReshape(cnode)) {
+    if (!FindReshape(cnode, &op_cache)) {
       continue;
     }
     MS_ASSERT(cnode->inputs().size() == 3);
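
The cache works as a seen-set keyed by operator-info name: ReshapeCostCompute creates one std::unordered_set per pass and FindReshape only returns true the first time a given Reshape operator is encountered, so its cost is not computed twice. A minimal, self-contained sketch of that pattern (ShouldProcess and the operator names below are illustrative only, not part of the patch):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical helper mirroring the op_cache check added to FindReshape:
// each operator name is handled at most once per pass.
bool ShouldProcess(const std::string &op_name, std::unordered_set<std::string> *op_cache) {
  if (op_cache->find(op_name) != op_cache->end()) {
    return false;  // already seen, skip it
  }
  op_cache->insert(op_name);
  return true;
}

int main() {
  std::unordered_set<std::string> op_cache;
  std::vector<std::string> reshape_ops = {"Reshape-op0", "Reshape-op1", "Reshape-op0"};
  for (const auto &name : reshape_ops) {
    if (ShouldProcess(name, &op_cache)) {
      std::cout << "compute reshape cost for " << name << std::endl;
    }
  }
  return 0;  // "Reshape-op0" and "Reshape-op1" are each processed once
}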

@@ -36,11 +36,14 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrange
   while (!is_unified) {
     std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
     if (temp_layout_ptr == nullptr) {
-      return nullptr;
+      out_layout_ptr->SetExpandAble(false);
+      return out_layout_ptr;
     }
     out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
     if (out_layout_ptr == nullptr) {
-      return nullptr;
+      std::shared_ptr<ReshapeLayoutTransfer> layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
+      layout_ptr->SetExpandAble(false);
+      return layout_ptr;
     }
     is_unified = out_layout_ptr->IsSameTensorShape();
   }
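
Rather than propagating nullptr when the two tensor shapes cannot be unified, the function now returns a layout-transfer object with its expandable flag cleared, so the caller keeps a usable result and detects the failed unification through the flag instead of a null pointer. A minimal sketch of that return-with-flag pattern (LayoutTransfer, expand_able() and UnifyShapes() below are illustrative stand-ins, not the MindSpore API):

#include <iostream>
#include <memory>

// Hypothetical stand-in for ReshapeLayoutTransfer; only the flag matters here.
class LayoutTransfer {
 public:
  bool expand_able() const { return expand_able_; }
  void SetExpandAble(bool flag) { expand_able_ = flag; }

 private:
  bool expand_able_ = true;
};

// On failure, return the last usable object with the flag cleared instead of
// nullptr, so the caller still has something to branch on.
std::shared_ptr<LayoutTransfer> UnifyShapes(bool can_unify) {
  auto out = std::make_shared<LayoutTransfer>();
  if (!can_unify) {
    out->SetExpandAble(false);
  }
  return out;
}

int main() {
  auto layout = UnifyShapes(false);
  if (layout != nullptr && !layout->expand_able()) {
    std::cout << "shapes not unified, take the fallback path" << std::endl;
  }
  return 0;
}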

@@ -58,7 +58,11 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens
     MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
     return Status::SUCCESS;
   } else {
-    MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "invalid origin tensor layout " << this->OriginToString();
+    } else {
+      MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    }
     return Status::FAILED;
   }
 }
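
The severity downgrade follows the new layout_transfer_ flag: layouts built speculatively during a reshape layout transfer (see set_layout_transfer(true) further down) may legitimately fail validation, so the failure is logged as a warning there and stays an error everywhere else. A small self-contained sketch of that conditional-severity idea (LogWarning, LogError and InitLayout are illustrative stand-ins, not the MS_LOG macros):

#include <iostream>
#include <string>

enum class Status { SUCCESS, FAILED };

// Hypothetical stand-ins for MS_LOG(WARNING) / MS_LOG(ERROR).
void LogWarning(const std::string &msg) { std::cerr << "[WARNING] " << msg << std::endl; }
void LogError(const std::string &msg) { std::cerr << "[ERROR] " << msg << std::endl; }

Status InitLayout(bool valid, bool layout_transfer) {
  if (valid) {
    return Status::SUCCESS;
  }
  // An invalid layout is an expected outcome while probing layout transfers,
  // so it only warrants a warning; otherwise it is a genuine error.
  if (layout_transfer) {
    LogWarning("invalid origin tensor layout");
  } else {
    LogError("invalid origin tensor layout");
  }
  return Status::FAILED;
}

int main() {
  InitLayout(false, true);   // expected failure during a layout transfer -> warning
  InitLayout(false, false);  // unexpected failure -> error
  return 0;
}
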
@@ -90,7 +94,11 @@ bool TensorLayout::IsValidTensorLayout() const {
     return false;
   }
   if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
-    MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    } else {
+      MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    }
     return false;
   }
   return true;
@@ -214,6 +222,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDevice
     return nullptr;
   }
   TensorLayout tensor_layout_new;
+  tensor_layout_new.set_layout_transfer(true);
   Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
   if (status != Status::SUCCESS) {
     return nullptr;
@@ -391,9 +400,9 @@ TensorLayout TensorLayout::SqueezeShape() const {
 }
 
 TensorLayout TensorLayout::TransferRepeatLayout() const {
-  Shape dev_mat(device_arrangement_.array());
-  Shape tensor_map(tensor_map_.GetDimSize(), -1);
-  Shape tensor_shape(tensor_shape_.array());
+  Shape dev_mat(device_arrangement_origin_.array());
+  Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
+  Shape tensor_shape(tensor_shape_origin_.array());
   TensorLayout repeat;
   repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
   return repeat;

@@ -46,6 +46,10 @@ class TensorLayout {
   void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
 
+  bool layout_transfer() const { return layout_transfer_; }
+
+  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }
+
   int32_t get_field_size() const { return field_size_; }
 
   void set_field_size(int32_t field_size) { field_size_ = field_size; }
@@ -113,14 +117,15 @@ class TensorLayout {
   int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;
 
   Arrangement device_arrangement_origin_;
+  Map tensor_map_origin_;
   Arrangement tensor_shape_origin_;
   Arrangement device_arrangement_;
+  Map tensor_map_;
   Arrangement tensor_shape_;
-  Map tensor_map_;
-  Map tensor_map_origin_;
   bool skip_redistribution_ = false;
-  int32_t field_size_ = 0;
   bool uniform_split_ = true;
+  bool layout_transfer_ = false;
+  int32_t field_size_ = 0;
   Shape opt_shard_slice_shape_;
   std::string opt_shard_group_ = "";
 };

@@ -43,7 +43,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
   TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
   TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
   MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
-  MS_LOG(DEBUG) << "reshape to_layout " << to_repeat.ToString();
+  MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
   MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
   MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
   MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();

@@ -204,3 +204,35 @@ def test_reshape_unexpand_6():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x)
+
+def test_reshape_unexpand_7():
+    class Net(nn.Cell):
+        def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
+                     mul_size=(32, 1, 220, 220)):
+            super().__init__()
+            mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+            self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+            self.mul = P.Mul()
+            self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
+                                  kernel_size=5, has_bias=True, weight_init='ones',
+                                  bias_init='ones', pad_mode='valid')
+            self.softmax = nn.Softmax(axis=axis)
+            self.relu = nn.ReLU()
+            self.reshape = P.Reshape()
+            self.input_shape = input_shape
+
+        def construct(self, inputs):
+            x = self.conv(inputs)
+            x = self.softmax(x)
+            x = self.relu(x)
+            x = self.mul(x, self.mul_weight)
+            x = self.reshape(x, self.input_shape)
+            return x
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    net.set_auto_parallel()
+    _executor.compile(net, x)
