fix backward test case

tonyyang-svail-feed-op-desgin
dongzhihong 8 years ago
parent 3dc4f46f81
commit 494b3bda7d

@ -159,7 +159,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"add", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
"sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
{{"Out", {insert_add_out}}}, {})});
}
}

@ -133,15 +133,18 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").AsDuplicable();
AddOutput("Out", "out");
AddInput("X", "the input tensors of sum operator.")
.AsDuplicable()
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment("");
}
};
} // namespace framework
} // namespace paddle
@ -154,7 +157,7 @@ REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
REGISTER_OP(sum, f::NOP, f::SumOpMaker, sum_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
f::NOP);
@ -283,7 +286,7 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add", bwd_net->ops_[2]->Type());
ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
}
TEST(Backward, op_register_grad_not_for_network) {

Loading…
Cancel
Save