|
|
|
@ -116,12 +116,8 @@ class ExecutorTesterRandom : public ::testing::Test {
|
|
|
|
|
{{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
|
|
|
|
|
AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
|
|
|
|
|
{{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"w1"}}}, {},
|
|
|
|
|
{{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
|
|
|
|
|
init_root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"w2"}}}, {},
|
|
|
|
|
{{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
|
|
|
|
|
init_root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block);
|
|
|
|
|
|
|
|
|
|
// flush
|
|
|
|
|
init_program.Proto();
|
|
|
|
@ -163,12 +159,8 @@ class ExecutorTesterRandom : public ::testing::Test {
|
|
|
|
|
{"Grad", {"w2@GRAD"}}},
|
|
|
|
|
{{"ParamOut", {"w2"}}}, {}, root_block);
|
|
|
|
|
|
|
|
|
|
AddOp("fetch", {{"Input", {"w1"}}}, {},
|
|
|
|
|
{{"dims", std::vector<int>{input_dim, embed_dim}}, {"col", 0}},
|
|
|
|
|
root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"w2"}}}, {},
|
|
|
|
|
{{"dims", std::vector<int>{embed_dim, input_dim}}, {"col", 1}},
|
|
|
|
|
root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block);
|
|
|
|
|
|
|
|
|
|
// flush
|
|
|
|
|
program.Proto();
|
|
|
|
@ -197,10 +189,8 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test {
|
|
|
|
|
root_block);
|
|
|
|
|
AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}},
|
|
|
|
|
root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"a"}}}, {}, {{"dims", dim}, {"col", 0}},
|
|
|
|
|
root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"b"}}}, {}, {{"dims", dim}, {"col", 1}},
|
|
|
|
|
root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block);
|
|
|
|
|
AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block);
|
|
|
|
|
|
|
|
|
|
// flush
|
|
|
|
|
program.Proto();
|
|
|
|
|