|
|
@ -74,20 +74,17 @@ TEST_F(TestData, test_build_value) {
|
|
|
|
AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
|
|
|
|
AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
|
|
|
|
AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
|
|
|
|
AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
|
|
|
|
ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
|
|
|
|
ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
|
|
|
|
ASSERT_EQ(*func_tuple_built,
|
|
|
|
ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
|
|
|
|
ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// BuildValue(List(AbstractFunction)) should return kAnyValue;
|
|
|
|
// BuildValue(List(AbstractFunction)) should return kAnyValue;
|
|
|
|
AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
|
|
|
|
AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
|
|
|
|
ValuePtr func_list_built = abs_func_list->BuildValue();
|
|
|
|
ValuePtr func_list_built = abs_func_list->BuildValue();
|
|
|
|
ASSERT_EQ(*func_list_built,
|
|
|
|
ASSERT_EQ(*func_list_built, ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
|
|
|
|
ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
|
|
|
|
// BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
|
|
|
|
abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
|
|
|
|
abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
|
|
|
|
func_tuple_built = abs_func_tuple->BuildValue();
|
|
|
|
func_tuple_built = abs_func_tuple->BuildValue();
|
|
|
|
ASSERT_EQ(*func_tuple_built,
|
|
|
|
ASSERT_EQ(*func_tuple_built, ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
|
|
|
|
ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(TestData, test_build_type) {
|
|
|
|
TEST_F(TestData, test_build_type) {
|
|
|
@ -129,7 +126,7 @@ TEST_F(TestData, test_build_shape) {
|
|
|
|
AbstractBasePtr abstract_tup = FromValue(vec, true);
|
|
|
|
AbstractBasePtr abstract_tup = FromValue(vec, true);
|
|
|
|
std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape());
|
|
|
|
std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape());
|
|
|
|
ASSERT_TRUE(shape_tuple);
|
|
|
|
ASSERT_TRUE(shape_tuple);
|
|
|
|
const std::vector<BaseShapePtr>& ptr_vec = shape_tuple->shape();
|
|
|
|
const std::vector<BaseShapePtr> &ptr_vec = shape_tuple->shape();
|
|
|
|
ASSERT_EQ(ptr_vec.size(), 2);
|
|
|
|
ASSERT_EQ(ptr_vec.size(), 2);
|
|
|
|
|
|
|
|
|
|
|
|
ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]);
|
|
|
|
ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]);
|
|
|
@ -148,14 +145,14 @@ TEST_F(TestData, test_clone) {
|
|
|
|
ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack());
|
|
|
|
ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack());
|
|
|
|
ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack());
|
|
|
|
ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack());
|
|
|
|
|
|
|
|
|
|
|
|
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
|
|
|
|
AbstractFunctionPtr f1 =
|
|
|
|
AnalysisContext::DummyContext());
|
|
|
|
std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
|
|
|
|
AbstractBasePtr f2 = f1->Clone();
|
|
|
|
AbstractBasePtr f2 = f1->Clone();
|
|
|
|
ASSERT_TRUE(*f2 == *f1);
|
|
|
|
ASSERT_TRUE(*f2 == *f1);
|
|
|
|
|
|
|
|
|
|
|
|
AbstractList l1 = AbstractList({s1, s2});
|
|
|
|
AbstractList l1 = AbstractList({s1, s2});
|
|
|
|
AbstractBasePtr l2 = l1.Clone();
|
|
|
|
AbstractBasePtr l2 = l1.Clone();
|
|
|
|
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
|
|
|
|
AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
|
|
|
|
ASSERT_TRUE(l2_cast != nullptr);
|
|
|
|
ASSERT_TRUE(l2_cast != nullptr);
|
|
|
|
ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack());
|
|
|
|
ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack());
|
|
|
|
|
|
|
|
|
|
|
@ -184,19 +181,19 @@ TEST_F(TestData, test_broaden) {
|
|
|
|
AbstractBasePtr s2 = s1->Broaden();
|
|
|
|
AbstractBasePtr s2 = s1->Broaden();
|
|
|
|
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
|
|
|
|
ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
|
|
|
|
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
|
|
|
|
ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
|
|
|
|
ASSERT_TRUE(s2->GetValueTrack()->isa<AnyValue>());
|
|
|
|
ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
|
|
|
|
|
|
|
|
|
|
|
|
AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
|
|
|
|
AbstractFunctionPtr f1 =
|
|
|
|
AnalysisContext::DummyContext());
|
|
|
|
std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(), AnalysisContext::DummyContext());
|
|
|
|
AbstractBasePtr f2 = f1->Broaden();
|
|
|
|
AbstractBasePtr f2 = f1->Broaden();
|
|
|
|
ASSERT_TRUE(f2 == f1);
|
|
|
|
ASSERT_TRUE(f2 == f1);
|
|
|
|
|
|
|
|
|
|
|
|
AbstractList l1 = AbstractList({s1, s2});
|
|
|
|
AbstractList l1 = AbstractList({s1, s2});
|
|
|
|
AbstractBasePtr l2 = l1.Broaden();
|
|
|
|
AbstractBasePtr l2 = l1.Broaden();
|
|
|
|
AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
|
|
|
|
AbstractList *l2_cast = dynamic_cast<AbstractList *>(l2.get());
|
|
|
|
ASSERT_TRUE(l2_cast != nullptr);
|
|
|
|
ASSERT_TRUE(l2_cast != nullptr);
|
|
|
|
AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
|
|
|
|
AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
|
|
|
|
ASSERT_TRUE(csr->GetValueTrack()->isa<AnyValue>());
|
|
|
|
ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace abstract
|
|
|
|
} // namespace abstract
|
|
|
|