"add simple net test"

cblas_new
dongzhihong 8 years ago
parent 404cc056b8
commit 65d2678720

@ -86,8 +86,6 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
// travesal subnet/op
for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) {
auto fwd = *it;
// for (auto& fwd : forwardNet.ops_) {
// auto bwd = Backward(*fwd, no_grad_names);
auto bwd = Backward(*fwd, no_grad_names);
net->AddOp(bwd);
for (size_t i = 0; i < bwd->outputs_.size(); ++i) {

@ -63,10 +63,10 @@ class FcOp : public NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
{Output("before_act")}, {}));
{Output("mul_out")}, {}));
auto b_name = Input("b");
if (b_name != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_out"), b_name},
{Output("before_act")}, {}));
}
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
@ -82,6 +82,7 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("mul_out", "mul output").SetTemporary();
AddOutput("before_act", "before act").SetTemporary();
AddOutput("Out", "");
AddComment("");
@ -140,6 +141,7 @@ TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
LOG(INFO) << gop->DebugString();
ASSERT_EQ(1UL, gop->inputs_.size());
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
@ -151,10 +153,18 @@ TEST(Backward, simple_op_grad) {
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST(Backward, simple_net_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {});
LOG(INFO) << gop->DebugString();
}
TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd =
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
ASSERT_NE(fwd, nullptr);
LOG(INFO) << fwd->DebugString();
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());

Loading…
Cancel
Save