Fix compile error

Replace `OperatorPtr` with `std::shared_ptr<OperatorBase>`
cblas_new
fengjiayi 8 years ago
parent 5f3bc2a44a
commit f4e25550cd

@ -9,8 +9,9 @@ namespace paddle {
namespace framework {
TEST(GradOpCreator, AddTwo) {
OperatorPtr add_op(OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
OperatorPtr grad_add_op = OpRegistry::CreateGradOp(add_op);
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x");

@ -298,9 +298,10 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
static OperatorPtr CreateGradOp(OperatorPtr op) {
static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get());
OperatorPtr grad_op(creator.Create());
std::shared_ptr<OperatorBase> grad_op(creator.Create());
grad_op->Init();
return grad_op;
}

Loading…
Cancel
Save