|
|
|
@ -18,7 +18,7 @@
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/framework/tensor.h"
|
|
|
|
|
#include "paddle/operators/recurrent_network_op.h"
|
|
|
|
|
#include "paddle/operators/recurrent_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -55,7 +55,7 @@ protected:
|
|
|
|
|
w->GetMutable<Tensor>()->mutable_data<float>(
|
|
|
|
|
make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
|
|
|
|
|
|
|
|
|
|
for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
|
|
|
|
|
for (auto boot : std::vector<std::string>{"h_boot"}) {
|
|
|
|
|
LOG(INFO) << "create global variable " << boot;
|
|
|
|
|
Variable* h_boot = scope_.NewVar(boot);
|
|
|
|
|
h_boot->GetMutable<Tensor>()->mutable_data<float>(
|
|
|
|
@ -79,7 +79,6 @@ protected:
|
|
|
|
|
op_desc.add_inputs("x0");
|
|
|
|
|
op_desc.add_inputs("x1");
|
|
|
|
|
// boot_memories 3
|
|
|
|
|
op_desc.add_inputs("x_boot");
|
|
|
|
|
op_desc.add_inputs("h_boot");
|
|
|
|
|
// step net 5
|
|
|
|
|
op_desc.add_inputs("step_net");
|
|
|
|
@ -91,7 +90,7 @@ protected:
|
|
|
|
|
auto _input_format = std::vector<int>{
|
|
|
|
|
0, // in_link
|
|
|
|
|
3, // memories
|
|
|
|
|
5 // step_net
|
|
|
|
|
4 // step_net
|
|
|
|
|
};
|
|
|
|
|
auto input_format = op_desc.add_attrs();
|
|
|
|
|
input_format->set_name("input_format");
|
|
|
|
@ -129,12 +128,11 @@ protected:
|
|
|
|
|
inlink_alias->add_strings(item);
|
|
|
|
|
}
|
|
|
|
|
// pre memories
|
|
|
|
|
for (const auto& item :
|
|
|
|
|
std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
|
|
|
|
|
for (const auto& item : std::vector<std::string>{"rnn/h@pre"}) {
|
|
|
|
|
pre_memories->add_strings(item);
|
|
|
|
|
}
|
|
|
|
|
// memories
|
|
|
|
|
for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
|
|
|
|
|
for (const auto& item : std::vector<std::string>{"rnn/h"}) {
|
|
|
|
|
memories->add_strings(item);
|
|
|
|
|
}
|
|
|
|
|
// output alias
|
|
|
|
@ -151,14 +149,11 @@ protected:
|
|
|
|
|
LOG(INFO) << "create variable step_net";
|
|
|
|
|
Variable* var = scope_.NewVar("step_net");
|
|
|
|
|
auto net = var->GetMutable<NetOp>();
|
|
|
|
|
// rnn/s is net's input or output?
|
|
|
|
|
net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
|
|
|
|
|
net->inputs_ = {"rnn/s", "rnn/h"};
|
|
|
|
|
net->AddOp(
|
|
|
|
|
OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
|
|
|
|
|
|
|
|
|
|
net->AddOp(
|
|
|
|
|
OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
|
|
|
|
|
OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
|
|
|
|
|
net->CompleteAddOp();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -297,7 +292,10 @@ protected:
|
|
|
|
|
inlink.internal = "rnn/x";
|
|
|
|
|
auto step_scopes =
|
|
|
|
|
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
|
|
|
|
|
rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
|
|
|
|
|
rnn::SegmentInputs(*step_scopes,
|
|
|
|
|
std::vector<rnn::Link>{inlink},
|
|
|
|
|
10,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void LinkeMemories() {
|
|
|
|
@ -311,7 +309,8 @@ protected:
|
|
|
|
|
auto step_scopes =
|
|
|
|
|
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
|
|
|
|
|
for (int i = 1; i < 10; ++i) {
|
|
|
|
|
rnn::LinkMemories(*step_scopes, memories, i, -1);
|
|
|
|
|
rnn::LinkMemories(
|
|
|
|
|
*step_scopes, memories, i, -1, true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -333,14 +332,14 @@ TEST(RecurrentOp, LinkMemories) {
|
|
|
|
|
using namespace paddle::operators;
|
|
|
|
|
|
|
|
|
|
// create and init step scopes
|
|
|
|
|
int len = 10;
|
|
|
|
|
size_t len = 10;
|
|
|
|
|
std::vector<Scope*> step_scopes;
|
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
|
|
|
for (size_t i = 0; i < len; ++i) {
|
|
|
|
|
auto scope = new Scope();
|
|
|
|
|
scope->NewVar("pre_h");
|
|
|
|
|
auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
|
|
|
|
|
float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
|
|
|
|
|
for (int j = 0; j < 15 * 20; ++j) {
|
|
|
|
|
for (size_t j = 0; j < 15 * 20; ++j) {
|
|
|
|
|
data[j] = rand() * (1. / (double)RAND_MAX);
|
|
|
|
|
}
|
|
|
|
|
step_scopes.push_back(scope);
|
|
|
|
@ -354,24 +353,24 @@ TEST(RecurrentOp, LinkMemories) {
|
|
|
|
|
std::vector<rnn::MemoryAttr> memories;
|
|
|
|
|
memories.push_back(mem_attr);
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < len; ++i) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, memories, i, -1);
|
|
|
|
|
for (size_t i = 1; i < len; ++i) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
// check
|
|
|
|
|
for (int i = 0; i < len - 1; ++i) {
|
|
|
|
|
for (size_t i = 0; i < len - 1; ++i) {
|
|
|
|
|
const float* a =
|
|
|
|
|
step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>();
|
|
|
|
|
const float* b = step_scopes[i + 1]
|
|
|
|
|
->FindVar("pre_h")
|
|
|
|
|
->GetMutable<Tensor>()
|
|
|
|
|
->data<float>();
|
|
|
|
|
for (size_t i = 0; i < 15 * 20; ++i) {
|
|
|
|
|
ASSERT_FLOAT_EQ(a[i], b[i]);
|
|
|
|
|
for (size_t j = 0; j < 15 * 20; ++j) {
|
|
|
|
|
ASSERT_FLOAT_EQ(a[j], b[j]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = len - 2; i >= 0; --i) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, memories, i, 1);
|
|
|
|
|
rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
// check
|
|
|
|
|
for (int i = len - 2; i >= 0; --i) {
|
|
|
|
@ -379,8 +378,8 @@ TEST(RecurrentOp, LinkMemories) {
|
|
|
|
|
step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>();
|
|
|
|
|
const float* b =
|
|
|
|
|
step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>();
|
|
|
|
|
for (size_t i = 0; i < 15 * 20; ++i) {
|
|
|
|
|
ASSERT_FLOAT_EQ(a[i], b[i]);
|
|
|
|
|
for (size_t j = 0; j < 15 * 20; ++j) {
|
|
|
|
|
ASSERT_FLOAT_EQ(a[j], b[j]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -391,9 +390,3 @@ TEST(RecurrentOp, LinkMemories) {
|
|
|
|
|
|
|
|
|
|
USE_OP(add_two);
|
|
|
|
|
USE_OP(mul);
|
|
|
|
|
|
|
|
|
|
// int main() {
|
|
|
|
|
// //! TODO(yuyang18): Temporary disable this unit-test because implementation
|
|
|
|
|
// //! error.
|
|
|
|
|
// return 0;
|
|
|
|
|
//}
|