"fix return net error"

cblas_new
dongzhihong 8 years ago
parent 1197420598
commit 302046aa51

@ -60,6 +60,16 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
}
};
class NoGradOpMaker : public OpProtoAndCheckerMaker {
public:
NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X input");
AddOutput("Y", "Y output");
AddComment("NoGradOp, same input output. no Grad");
}
};
class FcOp : public NetOp {
public:
void Init() override {
@ -139,6 +149,7 @@ REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
@ -266,9 +277,11 @@ TEST(Backward, net_shared_weight) {
}
TEST(Backward, op_register_grad_not_for_network) {
auto fwd =
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
{{"temporary_index", std::vector<int>{1}}});
// auto fwd =
// f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
// {{"temporary_index", std::vector<int>{1}}});
auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
@ -316,11 +329,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"a"});
ASSERT_FALSE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 1UL);
auto &grad_mul = *net->ops_[0];
auto &grad_mul = *backward;
ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);

Loading…
Cancel
Save