Fix repeat error

pull/8455/head
huangxinjing 4 years ago
parent 8aa78c2c8e
commit 3e9fac7f59

@ -103,6 +103,35 @@ Status UnsortedSegmentOpInfo::InferDevMatrixShape() {
return SUCCESS;
}
Status UnsortedSegmentOpInfo::InferMirrorOps() {
mirror_ops_.clear();
// Only the first input could be parameter.
Shape tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}
OperatorVector mirror_op;
OperatorVector op_for_value;
OperatorVector op_for_value2;
if (group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
mirror_ops_.push_back(op_for_value);
mirror_ops_.push_back(op_for_value2);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
// As the op converts the vector x1,x2,x3...,xr -> number of segments, xn,..,xr
// the dimension x1,x2,x3,..,xn is eliminated
// suppose the strategy of the inputs is (a,b,c,d), (a,b)
@ -221,9 +250,6 @@ Status UnsortedSegmentOpInfo::InferForwardCommunication() {
std::vector<Group> group_list;
Shape tmp_group_tensor_map = outputs_tensor_map_.at(0);
if (repeated_calc_num_ > 1) {
for (size_t i = 1; i < tmp_group_tensor_map.size(); ++i) {
tmp_group_tensor_map[i] += 1;
}
tmp_group_tensor_map.push_back(0);
}
if (CreateGroupByTensorMap(tmp_group_tensor_map, &group_list) != SUCCESS) {

@ -47,7 +47,7 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferMirrorOps() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;

Loading…
Cancel
Save