|
|
|
@ -59,10 +59,9 @@ REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
|
|
|
|
|
|
|
|
|
|
TEST(GradOpBuilder, MutiInOut) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
|
|
|
|
|
"mult_io",
|
|
|
|
|
{{"In1", {"in1"}},
|
|
|
|
|
{"In2_mult", {"in2_1", "in2_2", "in2_3"}},
|
|
|
|
|
{"In3", {"in3"}}},
|
|
|
|
|
"mult_io", {{"In1", {"in1"}},
|
|
|
|
|
{"In2_mult", {"in2_1", "in2_2", "in2_3"}},
|
|
|
|
|
{"In3", {"in3"}}},
|
|
|
|
|
{{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {}));
|
|
|
|
|
std::shared_ptr<f::OperatorBase> grad_test_op =
|
|
|
|
|
f::OpRegistry::CreateGradOp(*test_op);
|
|
|
|
@ -92,10 +91,9 @@ TEST(GradOpBuilder, MutiInOut) {
|
|
|
|
|
|
|
|
|
|
TEST(GradOpBuilder, IOIgnoredInGradient) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
|
|
|
|
|
"io_ignored",
|
|
|
|
|
{{"In1", {"in1"}},
|
|
|
|
|
{"In2_mult", {"in2_1", "in2_2"}},
|
|
|
|
|
{"In3_mult", {"in3_1", "in3_2"}}},
|
|
|
|
|
"io_ignored", {{"In1", {"in1"}},
|
|
|
|
|
{"In2_mult", {"in2_1", "in2_2"}},
|
|
|
|
|
{"In3_mult", {"in3_1", "in3_2"}}},
|
|
|
|
|
{{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {}));
|
|
|
|
|
std::shared_ptr<f::OperatorBase> grad_test_op =
|
|
|
|
|
f::OpRegistry::CreateGradOp(*test_op);
|
|
|
|
|