|
|
|
|
@ -58,6 +58,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "A");
|
|
|
|
|
AddInput("Y", "B");
|
|
|
|
|
AddOutput("Out", "Out");
|
|
|
|
|
AddAttr<int>("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);
|
|
|
|
|
AddAttr<int>("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1);
|
|
|
|
|
AddComment("Mul");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
@ -453,6 +455,9 @@ TEST(Backward, default_attribute) {
|
|
|
|
|
AppendBackward(program, {});
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(block->AllOps().size(), 2UL);
|
|
|
|
|
EXPECT_EQ(boost::get<int>(op->GetAttr("x_num_col_dims")), 1);
|
|
|
|
|
EXPECT_EQ(boost::get<int>(op->GetAttr("y_num_col_dims")), 1);
|
|
|
|
|
|
|
|
|
|
f::OpDescBind *grad_op = block->AllOps()[1];
|
|
|
|
|
ASSERT_EQ(grad_op->Type(), "mul_grad");
|
|
|
|
|
EXPECT_EQ(boost::get<int>(grad_op->GetAttr("x_num_col_dims")), 1);
|
|
|
|
|
|