|
|
|
@ -103,6 +103,35 @@ Status UnsortedSegmentOpInfo::InferDevMatrixShape() {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status UnsortedSegmentOpInfo::InferMirrorOps() {
|
|
|
|
|
mirror_ops_.clear();
|
|
|
|
|
|
|
|
|
|
// Only the first input could be parameter.
|
|
|
|
|
Shape tensor_map = inputs_tensor_map_[0];
|
|
|
|
|
std::vector<Group> group;
|
|
|
|
|
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Create group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OperatorVector mirror_op;
|
|
|
|
|
OperatorVector op_for_value;
|
|
|
|
|
OperatorVector op_for_value2;
|
|
|
|
|
if (group.empty()) {
|
|
|
|
|
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
} else {
|
|
|
|
|
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
|
|
|
|
|
mirror_ops_.push_back(mirror_op);
|
|
|
|
|
mirror_ops_.push_back(op_for_value);
|
|
|
|
|
mirror_ops_.push_back(op_for_value2);
|
|
|
|
|
std::string group_name = group[0].name();
|
|
|
|
|
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// As the op converts the vector x1,x2,x3...,xr -> number of segments, xn,..,xr
|
|
|
|
|
// the dimension x1,x2,x3,..,xn is eliminated
|
|
|
|
|
// suppose the strategy of the inputs is (a,b,c,d), (a,b)
|
|
|
|
@ -221,9 +250,6 @@ Status UnsortedSegmentOpInfo::InferForwardCommunication() {
|
|
|
|
|
std::vector<Group> group_list;
|
|
|
|
|
Shape tmp_group_tensor_map = outputs_tensor_map_.at(0);
|
|
|
|
|
if (repeated_calc_num_ > 1) {
|
|
|
|
|
for (size_t i = 1; i < tmp_group_tensor_map.size(); ++i) {
|
|
|
|
|
tmp_group_tensor_map[i] += 1;
|
|
|
|
|
}
|
|
|
|
|
tmp_group_tensor_map.push_back(0);
|
|
|
|
|
}
|
|
|
|
|
if (CreateGroupByTensorMap(tmp_group_tensor_map, &group_list) != SUCCESS) {
|
|
|
|
|