Refine InferShape for recurrent_network_op.

* the tensor only contains shape and does not hold memory when inferring shape.
cblas_new
dangqingqing 8 years ago
parent 0973c2c97b
commit 6cfb9a3262

File diff suppressed because it is too large Load Diff

@ -72,19 +72,22 @@ struct ArgumentName {
*/
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<Link>& inlinks,
const size_t seq_len);
const size_t seq_len,
bool infer_shape);
/**
* Process outputs of step nets and merge to variables.
*/
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<Link>& outlinks,
const size_t seq_len);
const size_t seq_len,
bool infer_shape);
void LinkMemories(std::vector<std::shared_ptr<Scope>>& step_scopes,
const std::vector<MemoryAttr>& memories,
size_t step_id,
int offset);
const size_t step_id,
const int offset,
bool infer_shape);
void InitArgument(const ArgumentName& name, Argument* arg);
@ -125,7 +128,7 @@ protected:
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
}
void InitMemories(std::shared_ptr<Scope> step_scopes) const;
void InitMemories(std::shared_ptr<Scope> step_scopes, bool infer_shape) const;
private:
std::unique_ptr<rnn::Argument> arg_;
@ -149,7 +152,8 @@ public:
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const;
void LinkBootMemoryGradients(std::shared_ptr<Scope> step_scopes) const;
void LinkBootMemoryGradients(std::shared_ptr<Scope> step_scopes,
bool infer_shape) const;
/**
* InferShape must be called before Run.

@ -56,7 +56,7 @@ protected:
w->GetMutable<Tensor>()->mutable_data<float>(
make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
for (auto boot : std::vector<std::string>{"h_boot"}) {
LOG(INFO) << "create global variable " << boot;
Variable* h_boot = scope_->CreateVariable(boot);
h_boot->GetMutable<Tensor>()->mutable_data<float>(
@ -80,7 +80,6 @@ protected:
op_desc.add_inputs("x0");
op_desc.add_inputs("x1");
// boot_memories 3
op_desc.add_inputs("x_boot");
op_desc.add_inputs("h_boot");
// step net 5
op_desc.add_inputs("step_net");
@ -92,7 +91,7 @@ protected:
auto _input_format = std::vector<int>{
0, // in_link
3, // memories
5 // step_net
4 // step_net
};
auto input_format = op_desc.add_attrs();
input_format->set_name("input_format");
@ -130,12 +129,11 @@ protected:
inlink_alias->add_strings(item);
}
// pre memories
for (const auto& item :
std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
for (const auto& item : std::vector<std::string>{"rnn/h@pre"}) {
pre_memories->add_strings(item);
}
// memories
for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
for (const auto& item : std::vector<std::string>{"rnn/h"}) {
memories->add_strings(item);
}
// output alias
@ -152,14 +150,11 @@ protected:
LOG(INFO) << "create variable step_net";
Variable* var = scope_->CreateVariable("step_net");
auto net = var->GetMutable<NetOp>();
// rnn/s is net's input or output?
net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
net->inputs_ = {"rnn/s", "rnn/h"};
net->AddOp(
OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
net->AddOp(
OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
net->CompleteAddOp();
}
@ -303,7 +298,7 @@ protected:
std::vector<std::shared_ptr<Scope>>* step_scopes =
scope_->GetVariable("step_scopes")
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10, true);
}
void LinkeMemories() {
@ -318,7 +313,7 @@ protected:
scope_->GetVariable("step_scopes")
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
for (int i = 1; i < 10; ++i) {
rnn::LinkMemories(*step_scopes, memories, i, -1);
rnn::LinkMemories(*step_scopes, memories, i, -1, true);
}
}
@ -347,7 +342,7 @@ TEST(RecurrentOp, LinkMemories) {
scope->CreateVariable("pre_h");
auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>();
float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
for (int i = 0; i < 15 * 20; ++i) {
for (int j = 0; j < 15 * 20; ++j) {
data[i] = rand() * (1. / (double)RAND_MAX);
}
step_scopes.push_back(scope);
@ -362,7 +357,7 @@ TEST(RecurrentOp, LinkMemories) {
memories.push_back(mem_attr);
for (int i = 1; i < len; ++i) {
rnn::LinkMemories(step_scopes, memories, i, -1);
rnn::LinkMemories(step_scopes, memories, i, -1, false);
}
// check
for (int i = 0; i < len - 1; ++i) {
@ -372,13 +367,13 @@ TEST(RecurrentOp, LinkMemories) {
->GetVariable("pre_h")
->GetMutable<Tensor>()
->data<float>();
for (size_t i = 0; i < 15 * 20; ++i) {
ASSERT_FLOAT_EQ(a[i], b[i]);
for (size_t j = 0; j < 15 * 20; ++j) {
ASSERT_FLOAT_EQ(a[j], b[j]);
}
}
for (int i = len - 2; i >= 0; --i) {
rnn::LinkMemories(step_scopes, memories, i, 1);
rnn::LinkMemories(step_scopes, memories, i, 1, false);
}
// check
for (int i = len - 2; i >= 0; --i) {
@ -390,8 +385,8 @@ TEST(RecurrentOp, LinkMemories) {
->GetVariable("h")
->GetMutable<Tensor>()
->data<float>();
for (size_t i = 0; i < 15 * 20; ++i) {
ASSERT_FLOAT_EQ(a[i], b[i]);
for (size_t j = 0; j < 15 * 20; ++j) {
ASSERT_FLOAT_EQ(a[j], b[j]);
}
}
}

Loading…
Cancel
Save