update infer mirror ops
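Remove the per-operator InferMirrorOps overrides in BatchParallelInfo, SoftmaxCrossEntropyWithLogitsInfo, and MatMulBase, and give OperatorInfo a default implementation that builds mirror ops from each input's tensor map. Update the matmul mirror-op test expectation accordingly, and move device-manager initialization into TestStepParallel::SetUp() so the individual tests no longer call Init_Device_Manager() themselves.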

pull/11299/head
yangzhenzhang 4 years ago
parent dfa6daaa57
commit 38ea8784c6

@@ -57,22 +57,6 @@ Status BatchParallelInfo::InferDevMatrixShape() {
   return SUCCESS;
 }
 
-Status BatchParallelInfo::InferMirrorOps() {
-  mirror_ops_.clear();
-  if (g_device_manager->DeviceNum() == 1) {
-    MS_LOG(INFO) << name_ << " : The device num is 1, no need to create mirror ops.";
-    return SUCCESS;
-  }
-
-  MS_LOG(INFO) << name_ << " : Batch parallel input number " << strategy_->GetInputNumber();
-  for (size_t i = 0; i < input_value_.size(); i++) {
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    OperatorVector op_vec = CreateMirrorOps(g_device_manager->world_group(), g_device_manager->DeviceNum());
-    mirror_ops_.push_back(op_vec);
-  }
-  return SUCCESS;
-}
-
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status BatchParallelInfo::InferTensorMap() {

@@ -44,7 +44,6 @@ class BatchParallelInfo : public OperatorInfo {
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override;
   Status InferForwardCommunication() override;
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;

@@ -48,7 +48,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status GetAttrs() override;
-  Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorMap() override;
   Status InferTensorInfo() override;

@@ -209,32 +209,6 @@ Status MatMulBase::InferDevMatrixShape() {
   return SUCCESS;
 }
 
-// all-reduce weight's grad
-Status MatMulBase::InferMirrorOps() {
-  mirror_ops_.clear();
-
-  Shape mat_b_tensor_map = inputs_tensor_map_[1];
-  std::vector<Group> mat_b_group;
-  if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) {
-    return FAILED;
-  }
-
-  OperatorVector op_for_inputs;  // op_for_inputs is empty
-  OperatorVector op_for_weight;
-
-  if (mat_b_group.empty()) {
-    MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
-    return SUCCESS;
-  } else {
-    op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum());
-    mirror_ops_.push_back(op_for_inputs);
-    mirror_ops_.push_back(op_for_weight);
-    MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name();
-  }
-
-  return SUCCESS;
-}
-
 Status MatMulBase::InferForwardCommunication() {
   forward_op_.clear();
   size_t dimension = origin_dev_matrix_shape_.size();

@@ -49,7 +49,6 @@ class MatMulBase : public OperatorInfo {
   Status SwapLastTwoElements(Shape *shape);
 
  protected:
-  Status InferMirrorOps() override;
   Status InferForwardCommunication() override;
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;

@@ -130,6 +130,46 @@ Status OperatorInfo::InferAttrs() {
   return SUCCESS;
 }
 
+Status OperatorInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+  if (inputs_shape_.empty()) {
+    MS_LOG(INFO) << name_ << ": The inputs size is empty";
+    return SUCCESS;
+  }
+
+  if (inputs_tensor_map_.size() != inputs_shape_.size()) {
+    MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
+    return FAILED;
+  }
+
+  bool group_is_empty = true;
+  for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
+    std::vector<Group> group;
+    if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
+      mirror_ops_.clear();
+      return FAILED;
+    }
+
+    OperatorVector mirror_op;
+    if (group.empty()) {
+      MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
+      mirror_ops_.push_back(mirror_op);
+      continue;
+    }
+
+    group_is_empty = false;
+    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+    mirror_ops_.push_back(mirror_op);
+  }
+
+  if (group_is_empty) {
+    mirror_ops_.clear();
+    MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
+  }
+  return SUCCESS;
+}
+
 Status OperatorInfo::InferRepeatedCalcInfo() {
   int64_t g_dev_list_size = stage_device_size_;
   int64_t dev_matrix_size =

@@ -187,10 +187,10 @@ class OperatorInfo {
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;
-  virtual Status InferMirrorOps() = 0;
   virtual Status GetAttrs() = 0;
   virtual Status InferTensorInfo() = 0;
   virtual Status InferDevMatrixShape() = 0;
+  virtual Status InferMirrorOps();
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
   void SetRepeatedCalcDevMatrix();
   void ResetTensorMapIfRepeatedCalc();

@@ -463,7 +463,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) {
   matmul1->Init(strategy);
   MirrorOps mirror_ops = matmul1->mirror_ops();
 
-  ASSERT_EQ(mirror_ops.size(), 0);  // all reduce only in -3 dim (strategy is 1);
+  ASSERT_EQ(mirror_ops.size(), 2);
 }
 
 TEST_F(TestMatmulInfo, InitTwice) {

@@ -32,8 +32,6 @@ class TestStepParallel : public UT::Common {
   void TearDown() {}
 };
 
-void TestStepParallel::SetUp() { UT::InitPythonPath(); }
-
 void Init_Device_Manager() {
   RankList dev_list;
@@ -52,6 +50,11 @@ void Init_Device_Manager() {
   g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
 }
 
+void TestStepParallel::SetUp() {
+  UT::InitPythonPath();
+  Init_Device_Manager();
+}
+
 CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
   ParameterPtr param1 = func_graph->add_parameter();
@@ -345,7 +348,6 @@ TEST_F(TestStepParallel, CreatOpInstance1) {
 }
 
 TEST_F(TestStepParallel, OperatorInstance) {
-  Init_Device_Manager();
   // creat attrs and prim
   PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>();
   ValuePtr transpose_a = MakeValue(false);
@@ -369,7 +371,6 @@ TEST_F(TestStepParallel, OperatorInstance) {
 }
 
 TEST_F(TestStepParallel, ExtractInformation) {
-  Init_Device_Manager();
   FuncGraphManagerPtr manager = Make_Manager();
   FuncGraphSet graphs = manager->func_graphs();
   FuncGraphPtr graph = *graphs.begin();
@@ -379,7 +380,6 @@ TEST_F(TestStepParallel, ExtractInformation) {
 }
 
 TEST_F(TestStepParallel, ExtractInformation2) {
-  Init_Device_Manager();
   FuncGraphManagerPtr manager = Make_Manager(2);
   FuncGraphSet graphs = manager->func_graphs();
   FuncGraphPtr graph = *graphs.begin();
@@ -389,7 +389,6 @@ TEST_F(TestStepParallel, ExtractInformation2) {
 }
 
 TEST_F(TestStepParallel, ExtractInformation3) {
-  Init_Device_Manager();
   FuncGraphManagerPtr manager = Make_Manager(3);
   FuncGraphSet graphs = manager->func_graphs();
   FuncGraphPtr graph = *graphs.begin();
@@ -399,7 +398,6 @@ TEST_F(TestStepParallel, ExtractInformation3) {
 }
 
 TEST_F(TestStepParallel, ForwardCommunication1) {
-  Init_Device_Manager();
   ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
   ValuePtr attr1_value = MakeValue("0-1-2");
   Attr attr0 = std::make_pair("op", attr0_value);
@@ -499,7 +497,6 @@ TEST_F(TestStepParallel, ForwardCommunication3) {
 }
 
 TEST_F(TestStepParallel, GetTensorInLayout) {
-  Init_Device_Manager();
   // creat attrs and prim
   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
   Shape inputs_x_dims = {64, 32};
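For readers skimming the diff, here is a minimal standalone sketch of the pattern this commit applies: InferMirrorOps changes from a pure-virtual method that every operator implemented into a base-class default that derived operators simply inherit. Everything below (OperatorBase, MatMulLike, the string-valued groups) is an illustrative stand-in, not MindSpore code; the real implementation also consults tensor maps and device groups, as in the OperatorInfo hunk above.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-ins for the real mindspore::parallel types.
using OperatorVector = std::vector<std::string>;
using MirrorOps = std::vector<OperatorVector>;

class OperatorBase {
 public:
  OperatorBase(std::string name, std::vector<std::string> input_groups)
      : name_(std::move(name)), input_groups_(std::move(input_groups)) {}
  virtual ~OperatorBase() = default;

  // Default implementation: one (possibly empty) mirror-op vector per input,
  // cleared entirely when no input needs mirroring, mirroring the control
  // flow of the new OperatorInfo::InferMirrorOps above.
  virtual bool InferMirrorOps() {
    mirror_ops_.clear();
    bool group_is_empty = true;
    for (const auto &group : input_groups_) {
      if (group.empty()) {           // this input is not mirrored across a group
        mirror_ops_.emplace_back();  // keep positions aligned with the inputs
        continue;
      }
      group_is_empty = false;
      mirror_ops_.push_back({"MirrorOperator(" + group + ")"});
    }
    if (group_is_empty) {
      mirror_ops_.clear();  // no mirror ops needed at all
    }
    return true;
  }

  const MirrorOps &mirror_ops() const { return mirror_ops_; }

 protected:
  std::string name_;
  std::vector<std::string> input_groups_;  // one group name per input ("" = none)
  MirrorOps mirror_ops_;
};

// After the commit, an operator like MatMul needs no InferMirrorOps override:
// it only describes its inputs and inherits the default behavior.
class MatMulLike : public OperatorBase {
 public:
  MatMulLike() : OperatorBase("MatMul", {"", "group-for-weight"}) {}
};

int main() {
  MatMulLike matmul;
  matmul.InferMirrorOps();
  std::cout << matmul.mirror_ops().size() << "\n";  // 2: empty vector for the
  return 0;                                         // input, one op for the weight
}

The property this sketch preserves from the real code is the one the updated test depends on: mirror_ops_ stays positionally aligned with the inputs (an empty vector marks an input that needs no mirroring) and is cleared entirely when no input needs one, which is why GetMirrorOPs4 now expects a size of 2 instead of 0.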

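The test-side change relies on googletest running a fixture's SetUp() before every TEST_F; that is why each test's explicit Init_Device_Manager() call can be deleted once TestStepParallel::SetUp() performs the initialization. A schematic sketch (hypothetical names, linked against gtest_main; not the actual test file):

#include <gtest/gtest.h>

static int init_count = 0;
// Stand-in for Init_Device_Manager() in the real test file.
static void InitDeviceManagerStub() { ++init_count; }

class StepParallelFixture : public ::testing::Test {
 protected:
  // googletest invokes SetUp() before each TEST_F, so every test body
  // starts with the device manager already initialized.
  void SetUp() override { InitDeviceManagerStub(); }
};

TEST_F(StepParallelFixture, RunsAfterSetUp) { EXPECT_GT(init_count, 0); }
TEST_F(StepParallelFixture, AlsoRunsAfterSetUp) { EXPECT_GT(init_count, 0); }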