@@ -13,6 +13,7 @@
 limitations under the License. */

 #include "paddle/framework/backward.h"

+#include <gtest/gtest.h>
 #include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
@@ -142,6 +143,7 @@ REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
 REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
 REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
 REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
 REGISTER_GRADIENT_OP(fc, fc_grad, f::EmptyOp);
 REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
 REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
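Note on the block above: these ops are stubs. `f::EmptyOp` carries no computation, so the `REGISTER_OP` / `REGISTER_GRADIENT_OP` pairs only exercise the registry's bookkeeping: registering `add_grad` as the gradient type of `add` is what lets `Backward` look up and instantiate the right gradient op later. A minimal sketch of that lookup, assuming a `rowwise_add_grad` op is registered for `rowwise_add` (as the tests below imply):

    // Create a forward op, then ask the registry for its registered gradient op.
    auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
    auto gop = f::OpRegistry::CreateGradOp(*fwd);  // yields the registered grad op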
@@ -160,6 +162,18 @@ TEST(Backward, simple_op_grad) {
   // LOG(INFO) << gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX());
 }

+TEST(Backward, simple_op_not_need_grad) {
+  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
+  ASSERT_NE(fwd, nullptr);
+  auto gop = f::Backward(*fwd, {"X"});
+  ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
+                      "X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            gop->outputs_.end());
+
+  auto no_input_gop = f::Backward(*fwd, {"X", "b"});
+  ASSERT_NE(no_input_gop, nullptr);
+}
+
 TEST(Backward, net_fc_backward_normal) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
       "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
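The added `simple_op_not_need_grad` test pins down the semantics of the second argument to `f::Backward`: every variable in that no-grad set has its gradient dropped, so `"X" + f::OperatorBase::GRAD_VAR_SUFFIX()` must not appear among the gradient op's `outputs_`. And even when every input is listed (`{"X", "b"}`), `Backward` still returns a non-null op rather than failing.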
@@ -217,6 +231,8 @@ TEST(Backward, net_input_of_network_not_need_grad) {
                                          bwd_net->outputs_.begin(), bwd_net->outputs_.end());
   all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());

+  LOG(INFO) << bwd_net->DebugString();
+  LOG(INFO) << bwd_net->ops_.size();
   for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
     ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
               all_output.end());
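The loop above relies on the gradient-naming convention used throughout the file: the gradient of a variable is the variable's name with `GRAD_VAR_SUFFIX()` appended (the commented-out log in `simple_op_grad` suggests the suffix is `@GRAD`, so the gradient of `W1` would be `W1@GRAD`). The two new `LOG(INFO)` lines are debugging aids only.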
@@ -230,9 +246,9 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
   auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
-  ASSERT_EQ(
-      f::OperatorBase::EMPTY_VAR_NAME(),
-      first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
-  LOG(INFO) << first_fc_grad->DebugString();
+  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
+            first_fc_grad->ops_[2]->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
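`EMPTY_VAR_NAME()` acts as the sentinel for a suppressed output: since the network input `X` needs no gradient, the last op inside the first fc's gradient sub-net must bind `X@GRAD` to the empty name rather than to a real variable, which is exactly what the reworked assertion checks.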
@@ -245,13 +261,14 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
+  LOG(INFO) << bwd_net->DebugString();
   ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
 }

 TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd = f::OpRegistry::CreateOp(
-      "fc", {"X", "W", "b"}, {"mul_result", "add_result", "Out"},
-      {{"temporary_index", std::vector<int>{1}}});
+  auto fwd =
+      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
+                              {{"temporary_index", std::vector<int>{1}}});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
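`net_shared_weight` covers gradient accumulation: when two forward ops read the same weight, the weight's total gradient is the sum of two partial gradients, so the backward net ends with a third op, asserted here to be of type `add_grad`, that combines them.

`op_register_grad_not_for_network` then documents a deliberate limitation: `OpRegistry::CreateGradOp` handles plain operators only and throws `EnforceNotMet` for `fc`, which expands into a sub-network (compare the fc gradient sub-net asserted in the earlier hunk); whole networks are expected to go through `f::Backward` instead.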
@@ -299,9 +316,11 @@ TEST(Backward, op_part_of_output_are_not_need) {
 TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
-  ASSERT_TRUE(!backward->IsNetOp());
+  ASSERT_TRUE(backward->IsNetOp());
+  auto net = static_cast<f::NetOp *>(backward.get());
+  ASSERT_EQ(net->ops_.size(), 1UL);

-  auto &grad_mul = *backward;
+  auto &grad_mul = *net->ops_[0];
   ASSERT_EQ(grad_mul.type_, "mul_grad");
   ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
   ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
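The arity check spells out a gradient op's interface: the `2UL + 1UL + 1UL` inputs are the forward op's two inputs (`a`, `b`), its one output (`out`), and that output's gradient, while the `2UL` outputs are one gradient per forward input — here with `a`'s gradient suppressed via the `{"a"}` no-grad set.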
@@ -324,11 +343,11 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                                     {"mul_out2", "tmp_out2", "out2"}, {}));
   net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
                                     {"mul_out3", "tmp_out3", "out3"}, {}));
-  net.CompleteAddOp();
+  net.CompleteAddOp(false);
   auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
   ASSERT_TRUE(backward->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(bwd_net->ops_.size(), 1UL);
+  ASSERT_EQ(bwd_net->ops_.size(), 3UL);

   auto &grad_fc = *bwd_net->ops_[0];
   ASSERT_EQ(grad_fc.type_, "fc_grad");
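Taken together, this hunk shows the intended end-to-end flow: build a net, complete it, and hand it to `Backward` together with the set of variables whose gradients are not needed. A minimal sketch of that flow, assuming the same registered `fc` stub op (variable names here are illustrative only):

    // Build a one-op net, then derive its backward net.
    f::NetOp net;
    net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
                                      {"mul_out1", "tmp_out1", "out1"}, {}));
    net.CompleteAddOp(false);  // mirrors the updated call in the test above
    auto bwd = f::Backward(net, {"x1"});  // no gradient needed for the data input
    if (bwd->IsNetOp()) {
      LOG(INFO) << static_cast<f::NetOp *>(bwd.get())->DebugString();
    }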