|
|
@ -78,14 +78,14 @@ class FcOp : public ops::NetOp {
|
|
|
|
{Output("mul_result")}, {}));
|
|
|
|
{Output("mul_result")}, {}));
|
|
|
|
auto b_name = Input("b");
|
|
|
|
auto b_name = Input("b");
|
|
|
|
std::string before_act = "mul_result";
|
|
|
|
std::string before_act = "mul_result";
|
|
|
|
if (b_name != EMPTY_VAR_NAME()) {
|
|
|
|
if (b_name != kEmptyVarName) {
|
|
|
|
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
|
|
|
|
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
|
|
|
|
{Output("add_result")}, {}));
|
|
|
|
{Output("add_result")}, {}));
|
|
|
|
before_act = "add_result";
|
|
|
|
before_act = "add_result";
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
auto out_varname = Output("add_result");
|
|
|
|
auto out_varname = Output("add_result");
|
|
|
|
if (out_varname != EMPTY_VAR_NAME()) {
|
|
|
|
if (out_varname != kEmptyVarName) {
|
|
|
|
this->Rename(out_varname, EMPTY_VAR_NAME());
|
|
|
|
this->Rename(out_varname, kEmptyVarName);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -163,13 +163,12 @@ TEST(Backward, simple_op_grad) {
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
auto gop = f::OpRegistry::CreateGradOp(*fwd);
|
|
|
|
auto gop = f::OpRegistry::CreateGradOp(*fwd);
|
|
|
|
ASSERT_EQ(4UL, gop->inputs_.size());
|
|
|
|
ASSERT_EQ(4UL, gop->inputs_.size());
|
|
|
|
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), gop->inputs_[0]);
|
|
|
|
ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
|
|
|
|
ASSERT_EQ("rowwise_add_grad", gop->type_);
|
|
|
|
ASSERT_EQ("rowwise_add_grad", gop->type_);
|
|
|
|
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
|
|
|
|
ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]);
|
|
|
|
ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
|
|
|
|
ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]);
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
|
|
|
|
ASSERT_EQ("X" + f::kGradVarSuffix, gop->Output("X" + f::kGradVarSuffix));
|
|
|
|
gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST(Backward, simple_op_not_need_grad) {
|
|
|
|
TEST(Backward, simple_op_not_need_grad) {
|
|
|
@ -177,7 +176,7 @@ TEST(Backward, simple_op_not_need_grad) {
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
auto gop = f::Backward(*fwd, {"X"});
|
|
|
|
auto gop = f::Backward(*fwd, {"X"});
|
|
|
|
ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
|
|
|
|
ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
|
|
|
|
"X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
"X" + f::kGradVarSuffix),
|
|
|
|
gop->outputs_.end());
|
|
|
|
gop->outputs_.end());
|
|
|
|
|
|
|
|
|
|
|
|
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
|
|
|
|
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
|
|
|
@ -210,9 +209,9 @@ TEST(Backward, net_fc_backward_normal) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd =
|
|
|
|
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
|
|
|
|
f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName},
|
|
|
|
{"mul_result", "add_result", "tmp"}, {});
|
|
|
|
{"mul_result", "add_result", "tmp"}, {});
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
@ -242,24 +241,21 @@ TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
|
|
|
|
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
|
|
|
|
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
|
|
|
|
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
|
|
|
|
all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
|
|
|
|
all_output.erase(f::kEmptyVarName);
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
|
|
|
|
for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
|
|
|
|
ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
ASSERT_NE(all_output.find(out + f::kGradVarSuffix), all_output.end());
|
|
|
|
all_output.end());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Not Generated X
|
|
|
|
// Not Generated X
|
|
|
|
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix), all_output.end());
|
|
|
|
all_output.end());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(2UL, bwd_net->ops_.size());
|
|
|
|
ASSERT_EQ(2UL, bwd_net->ops_.size());
|
|
|
|
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
|
|
|
|
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
|
|
|
|
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
|
|
|
|
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
|
|
|
|
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
|
|
|
|
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
|
|
|
|
ASSERT_EQ(
|
|
|
|
ASSERT_EQ(f::kEmptyVarName,
|
|
|
|
f::OperatorBase::EMPTY_VAR_NAME(),
|
|
|
|
first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
|
|
|
|
first_fc_grad->ops_[2]->Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_shared_weight) {
|
|
|
|
TEST(Backward, net_shared_weight) {
|
|
|
@ -311,17 +307,15 @@ TEST(Backward, op_part_of_output_are_not_need) {
|
|
|
|
ASSERT_EQ(1UL, fill_zero.inputs_.size());
|
|
|
|
ASSERT_EQ(1UL, fill_zero.inputs_.size());
|
|
|
|
ASSERT_EQ("Z", fill_zero.inputs_[0]);
|
|
|
|
ASSERT_EQ("Z", fill_zero.inputs_[0]);
|
|
|
|
ASSERT_EQ(1UL, fill_zero.outputs_.size());
|
|
|
|
ASSERT_EQ(1UL, fill_zero.outputs_.size());
|
|
|
|
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]);
|
|
|
|
ASSERT_EQ("Z" + f::kZeroVarSuffix, fill_zero.outputs_[0]);
|
|
|
|
|
|
|
|
|
|
|
|
auto &d_many_out = *net->ops_[1];
|
|
|
|
auto &d_many_out = *net->ops_[1];
|
|
|
|
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
|
|
|
|
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
|
|
|
|
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
|
|
|
|
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
|
|
|
|
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(),
|
|
|
|
ASSERT_EQ("Z" + f::kZeroVarSuffix, d_many_out.Input("z" + f::kGradVarSuffix));
|
|
|
|
d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
ASSERT_EQ("Y" + f::kGradVarSuffix, d_many_out.Input("y" + f::kGradVarSuffix));
|
|
|
|
ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(),
|
|
|
|
ASSERT_EQ("X" + f::kGradVarSuffix,
|
|
|
|
d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
d_many_out.Output("x" + f::kGradVarSuffix));
|
|
|
|
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
|
|
|
|
|
|
|
|
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
@ -331,12 +325,10 @@ TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
ASSERT_EQ(grad_mul.type_, "mul_grad");
|
|
|
|
ASSERT_EQ(grad_mul.type_, "mul_grad");
|
|
|
|
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
|
|
|
|
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
|
|
|
|
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
|
|
|
|
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
|
|
|
|
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix), f::kEmptyVarName);
|
|
|
|
f::OperatorBase::EMPTY_VAR_NAME());
|
|
|
|
ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix), "b" + f::kGradVarSuffix);
|
|
|
|
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
|
|
|
|
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
"out" + f::kGradVarSuffix);
|
|
|
|
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
|
|
|
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
|
|
|
ASSERT_EQ(grad_mul.Input("A"), "a");
|
|
|
|
ASSERT_EQ(grad_mul.Input("A"), "a");
|
|
|
|
ASSERT_EQ(grad_mul.Input("B"), "b");
|
|
|
|
ASSERT_EQ(grad_mul.Input("B"), "b");
|
|
|
|
ASSERT_EQ(grad_mul.Input("Out"), "out");
|
|
|
|
ASSERT_EQ(grad_mul.Input("Out"), "out");
|
|
|
@ -368,23 +360,4 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
|
|
|
|
EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
|
|
|
|
EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
|
|
|
|
EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
|
|
|
|
EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
|
|
|
|
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
|
|
|
|
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
|
|
|
f::OperatorBase::EMPTY_VAR_NAME());
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
|
|
|
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
|
|
|
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Output("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
|
|
|
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
|
|
|
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Input("X"), "out2");
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Input("W"), "w3");
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3");
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_fc.Input("Out"), "out3");
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|