|
|
|
@ -32,8 +32,6 @@ class TestStepParallel : public UT::Common {
|
|
|
|
|
void TearDown() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void TestStepParallel::SetUp() { UT::InitPythonPath(); }
|
|
|
|
|
|
|
|
|
|
void Init_Device_Manager() {
|
|
|
|
|
RankList dev_list;
|
|
|
|
|
|
|
|
|
@ -52,6 +50,11 @@ void Init_Device_Manager() {
|
|
|
|
|
g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TestStepParallel::SetUp() {
|
|
|
|
|
UT::InitPythonPath();
|
|
|
|
|
Init_Device_Manager();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
|
|
|
|
|
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
|
|
|
|
|
ParameterPtr param1 = func_graph->add_parameter();
|
|
|
|
@ -345,7 +348,6 @@ TEST_F(TestStepParallel, CreatOpInstance1) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestStepParallel, OperatorInstance) {
|
|
|
|
|
Init_Device_Manager();
|
|
|
|
|
// creat attrs and prim
|
|
|
|
|
PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>();
|
|
|
|
|
ValuePtr transpose_a = MakeValue(false);
|
|
|
|
@ -369,7 +371,6 @@ TEST_F(TestStepParallel, OperatorInstance) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ExtractInformation) {
|
|
|
|
|
Init_Device_Manager();
|
|
|
|
|
FuncGraphManagerPtr manager = Make_Manager();
|
|
|
|
|
FuncGraphSet graphs = manager->func_graphs();
|
|
|
|
|
FuncGraphPtr graph = *graphs.begin();
|
|
|
|
@ -379,7 +380,6 @@ TEST_F(TestStepParallel, ExtractInformation) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ExtractInformation2) {
|
|
|
|
|
Init_Device_Manager();
|
|
|
|
|
FuncGraphManagerPtr manager = Make_Manager(2);
|
|
|
|
|
FuncGraphSet graphs = manager->func_graphs();
|
|
|
|
|
FuncGraphPtr graph = *graphs.begin();
|
|
|
|
@ -389,7 +389,6 @@ TEST_F(TestStepParallel, ExtractInformation2) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ExtractInformation3) {
|
|
|
|
|
Init_Device_Manager();
|
|
|
|
|
FuncGraphManagerPtr manager = Make_Manager(3);
|
|
|
|
|
FuncGraphSet graphs = manager->func_graphs();
|
|
|
|
|
FuncGraphPtr graph = *graphs.begin();
|
|
|
|
@ -399,7 +398,6 @@ TEST_F(TestStepParallel, ExtractInformation3) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestStepParallel, ForwardCommunication1) {
|
|
|
|
|
Init_Device_Manager();
|
|
|
|
|
ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
|
|
|
|
|
ValuePtr attr1_value = MakeValue("0-1-2");
|
|
|
|
|
Attr attr0 = std::make_pair("op", attr0_value);
|
|
|
|
@ -499,7 +497,6 @@ TEST_F(TestStepParallel, ForwardCommunication3) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestStepParallel, GetTensorInLayout) {
|
|
|
|
|
Init_Device_Manager();
|
|
|
|
|
// creat attrs and prim
|
|
|
|
|
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
|
|
|
|
|
Shape inputs_x_dims = {64, 32};
|
|
|
|
|