|
|
|
@ -32,6 +32,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class PlainNetTest : public testing::Test {
|
|
|
|
|
public:
|
|
|
|
|
virtual void SetUp() {
|
|
|
|
|
net_ = std::make_shared<PlainNet>();
|
|
|
|
|
ASSERT_NE(net_, nullptr);
|
|
|
|
@ -50,6 +51,8 @@ class PlainNetTest : public testing::Test {
|
|
|
|
|
|
|
|
|
|
virtual void TearDown() {}
|
|
|
|
|
|
|
|
|
|
virtual void TestBody() {}
|
|
|
|
|
|
|
|
|
|
void TestOpKernel() {
|
|
|
|
|
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net_->inputs_);
|
|
|
|
|
AssertSameVectorWithoutOrder({"y", "z"}, net_->outputs_);
|
|
|
|
@ -67,6 +70,7 @@ class PlainNetTest : public testing::Test {
|
|
|
|
|
ASSERT_EQ(2, infer_shape_cnt);
|
|
|
|
|
ASSERT_EQ(2, run_cnt);
|
|
|
|
|
|
|
|
|
|
auto op2 = std::make_shared<TestOp>();
|
|
|
|
|
ASSERT_THROW(net_->AddOp(op2), EnforceNotMet);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -83,12 +87,12 @@ class PlainNetTest : public testing::Test {
|
|
|
|
|
|
|
|
|
|
TEST(OpKernel, all) {
|
|
|
|
|
PlainNetTest net;
|
|
|
|
|
net->TestOpKernel();
|
|
|
|
|
net.TestOpKernel();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(AddBackwardOp, TestAddBackwardOp) {
|
|
|
|
|
PlainNetTest net;
|
|
|
|
|
net->TestAddBackwardOp();
|
|
|
|
|
net.TestAddBackwardOp();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|