|
|
|
@ -60,6 +60,16 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class NoGradOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "X input");
|
|
|
|
|
AddOutput("Y", "Y output");
|
|
|
|
|
AddComment("NoGradOp, same input output. no Grad");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FcOp : public NetOp {
|
|
|
|
|
public:
|
|
|
|
|
void Init() override {
|
|
|
|
@ -139,6 +149,7 @@ REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
|
|
|
|
|
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
|
|
|
|
|
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
|
|
|
|
|
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
|
|
|
|
|
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker);
|
|
|
|
|
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
|
|
|
|
|
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
|
|
|
|
|
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
|
|
|
|
@ -266,9 +277,11 @@ TEST(Backward, net_shared_weight) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, op_register_grad_not_for_network) {
|
|
|
|
|
auto fwd =
|
|
|
|
|
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
|
|
|
|
|
{{"temporary_index", std::vector<int>{1}}});
|
|
|
|
|
// auto fwd =
|
|
|
|
|
// f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
|
|
|
|
|
// {{"temporary_index", std::vector<int>{1}}});
|
|
|
|
|
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("nograd", {"x"}, {"x"}, {});
|
|
|
|
|
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -316,11 +329,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
|
|
|
|
|
TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
|
|
|
|
|
auto backward = f::Backward(*fwd, {"a"});
|
|
|
|
|
ASSERT_FALSE(backward->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
|
ASSERT_EQ(net->ops_.size(), 1UL);
|
|
|
|
|
|
|
|
|
|
auto &grad_mul = *net->ops_[0];
|
|
|
|
|
auto &grad_mul = *backward;
|
|
|
|
|
ASSERT_EQ(grad_mul.type_, "mul_grad");
|
|
|
|
|
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
|
|
|
|
|
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
|
|
|
|
|