|
|
|
@ -15,8 +15,9 @@
|
|
|
|
|
#include "paddle/framework/backward.h"
|
|
|
|
|
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
#include "paddle/framework/net.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/operators/net_op.h"
|
|
|
|
|
#include "paddle/operators/type_alias.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -70,7 +71,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FcOp : public NetOp {
|
|
|
|
|
class FcOp : public ops::NetOp {
|
|
|
|
|
public:
|
|
|
|
|
void Init() override {
|
|
|
|
|
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
|
|
|
|
@ -182,7 +183,8 @@ TEST(Backward, simple_op_not_need_grad) {
|
|
|
|
|
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
|
|
|
|
|
ASSERT_NE(no_input_gop, nullptr);
|
|
|
|
|
ASSERT_TRUE(no_input_gop->IsNetOp());
|
|
|
|
|
ASSERT_EQ(0UL, std::static_pointer_cast<f::NetOp>(no_input_gop)->ops_.size());
|
|
|
|
|
ASSERT_EQ(0UL,
|
|
|
|
|
std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_fc_backward_normal) {
|
|
|
|
@ -191,7 +193,7 @@ TEST(Backward, net_fc_backward_normal) {
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(gop.get());
|
|
|
|
|
auto net = static_cast<ops::NetOp *>(gop.get());
|
|
|
|
|
|
|
|
|
|
ASSERT_NO_THROW(net->DebugString());
|
|
|
|
|
|
|
|
|
@ -214,7 +216,7 @@ TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(gop.get());
|
|
|
|
|
auto net = static_cast<ops::NetOp *>(gop.get());
|
|
|
|
|
|
|
|
|
|
ASSERT_NO_THROW(net->DebugString());
|
|
|
|
|
|
|
|
|
@ -228,7 +230,7 @@ TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
f::NetOp net;
|
|
|
|
|
ops::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
|
|
|
|
|
{"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
|
|
|
|
@ -236,7 +238,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
net.CompleteAddOp();
|
|
|
|
|
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
|
|
|
|
|
ASSERT_TRUE(bwd->IsNetOp());
|
|
|
|
|
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
|
|
|
|
|
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
|
|
|
|
|
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
|
|
|
|
@ -253,7 +255,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(2UL, bwd_net->ops_.size());
|
|
|
|
|
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
|
|
|
|
|
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
|
|
|
|
|
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
|
|
|
|
|
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
|
|
|
|
|
ASSERT_EQ(
|
|
|
|
|
f::OperatorBase::EMPTY_VAR_NAME(),
|
|
|
|
@ -261,14 +263,14 @@ TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_shared_weight) {
|
|
|
|
|
f::NetOp net;
|
|
|
|
|
ops::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
|
|
|
|
|
net.CompleteAddOp();
|
|
|
|
|
|
|
|
|
|
auto bwd = f::Backward(net, {});
|
|
|
|
|
ASSERT_TRUE(bwd->IsNetOp());
|
|
|
|
|
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
|
|
|
|
|
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
|
|
|
|
|
ASSERT_EQ(3UL, bwd_net->ops_.size());
|
|
|
|
|
ASSERT_EQ("add", bwd_net->ops_[2]->type_);
|
|
|
|
|
}
|
|
|
|
@ -285,7 +287,7 @@ TEST(Backward, op_all_input_are_not_need) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
|
|
|
|
|
auto backward = f::Backward(*fwd, {"X", "b"});
|
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
|
auto net = static_cast<ops::NetOp *>(backward.get());
|
|
|
|
|
ASSERT_TRUE(net->ops_.empty());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -293,7 +295,7 @@ TEST(Backward, op_all_output_are_not_need) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
|
|
|
|
|
auto backward = f::Backward(*fwd, {"Out"});
|
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
|
auto net = static_cast<ops::NetOp *>(backward.get());
|
|
|
|
|
ASSERT_TRUE(net->ops_.empty());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -301,7 +303,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
|
|
|
|
|
auto backward = f::Backward(*fwd, {"Z"});
|
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
|
auto net = static_cast<ops::NetOp *>(backward.get());
|
|
|
|
|
ASSERT_EQ(net->ops_.size(), 2UL);
|
|
|
|
|
|
|
|
|
|
auto &fill_zero = *net->ops_[0];
|
|
|
|
@ -341,7 +343,7 @@ TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
|
|
|
|
|
f::NetOp net;
|
|
|
|
|
ops::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
|
|
|
|
|
{"mul_out1", "add_out1", "out1"}, {}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
|
|
|
|
@ -351,7 +353,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
|
|
|
|
|
net.CompleteAddOp();
|
|
|
|
|
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
|
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
|
auto bwd_net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
|
auto bwd_net = static_cast<ops::NetOp *>(backward.get());
|
|
|
|
|
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
|
|
|
|
|
auto &grad_fc = *bwd_net->ops_[0];
|
|
|
|
|
EXPECT_EQ(grad_fc.inputs_.size(),
|
|
|
|
|