Rewrite StaticRNN with Executor (#5224)
* Init commit * Make executor use ProgramDescBind * Change Attribute from BlockDesc to BlockDescBind * Since we will get the program desc in RNN, just BlockDesc is not enough. * Add DeviceContext to Executor API * Rewrite RNN * Pass Python * AddBiasOp does not care num_flatten_dims * Stash * Fix MacOS compile * Pass RNN forward * Add Python test * Refactor test * Make compile pass * Add gradopmaker * First draft done * Polish code * Add grad op maker and grad infershape * Polish code * Fix backward.cc bug * Fix infershape * Rename function * Add backward test * Simplify recurrent test * Update * Pass unittest * Add comments & refine test * Add comments * Refactor test * Complete unittest * Fix StepScopes enforce * Remove unused unittest * No type error * Update * Make RNN pass unittest * Fix typo
parent 8cdb42c2b3
commit 0a32e74d13
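The gist of the rewrite, per the commit notes above: the RNN no longer drives a hand-rolled stepnet, but builds a step block (a BlockDescBind inside the ProgramDescBind) and runs it once per time step through the Executor, which now takes a DeviceContext. A minimal illustrative sketch of that loop follows; `executor`, `step_block`, and `new_scope` are hypothetical stand-ins, not the actual Paddle API:

def run_static_rnn(executor, step_block, parent_scope, seq_len):
    # One sub-scope per time step, kept under the parent scope so the
    # backward pass can reuse them (illustrative sketch, not Paddle's API).
    step_scopes = []
    for t in range(seq_len):
        scope = parent_scope.new_scope()  # per-step variables live here
        # slicing step inputs and linking memories into `scope` would go here
        executor.run(step_block, scope)   # execute the step program once
        step_scopes.append(scope)
    return step_scopes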
@@ -1,170 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"

namespace paddle {
namespace operators {

// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
// TODO(Superjom)
// 1. No-padding computing for sequences with indefinite length in one batch.
// 2. Hierarchical RNN for sequences with sub-sequences.
// 3. Internal memory.
// 4. More complex RNN architectures, such as Gated Feedback RNN.
//    Refer to: https://arxiv.org/pdf/1502.02367.pdf

class RecurrentAlgorithm {
 public:
  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const;

  void Init(rnn::Argument* arg,
            std::unique_ptr<framework::OperatorBase>* stepnet) {
    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
    arg_ = arg;
    stepnet_ = stepnet;
  }

 protected:
  /*
   * The step scopes are stored in the parent scope as a variable.
   *
   * NOTE the scopes are reused in both the forward and the backward pass,
   * so they are created once and only expanded when more steps are needed.
   */
  void CreateScopes(const framework::Scope& scope, size_t seq_len) const;

  const std::vector<framework::Scope*>& GetStepScopes(
      const framework::Scope& scope) const {
    return *scope.FindVar(arg_->step_scopes)
                ->GetMutable<std::vector<framework::Scope*>>();
  }

  void InitMemories(framework::Scope* step_scopes) const;

 private:
  std::unique_ptr<framework::OperatorBase>* stepnet_;
  rnn::Argument* arg_;
};

class RecurrentGradientAlgorithm {
  /**
   * RNN's backward algorithm.
   *
   * To accelerate the development of RecurrentGradientOp, we decouple the
   * RNN algorithm from the `OperatorBase` implementation: the former
   * contains the core implementation of an RNN and stays stable even if
   * the framework changes a lot, while the latter is a wrapper that acts
   * as an adapter to turn the RNN into an operator.
   */
 public:
  void Init(rnn::Argument* arg,
            std::unique_ptr<framework::OperatorBase>* stepnet) {
    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
    arg_ = arg;
    stepnet_ = stepnet;
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const;

  void LinkBootMemoryGradients(framework::Scope* step_scopes) const;

 protected:
  inline const std::vector<framework::Scope*>& GetStepScopes(
      const framework::Scope& scope) const {
    return *scope.FindVar(arg_->step_scopes)
                ->GetMutable<std::vector<framework::Scope*>>();
  }

 private:
  rnn::Argument* arg_;
  std::unique_ptr<framework::OperatorBase>* stepnet_;
};

class RecurrentOp : public framework::OperatorBase {
 public:
  RecurrentOp(const std::string& type,
              const framework::VariableNameMap& inputs,
              const framework::VariableNameMap& outputs,
              const framework::AttributeMap& attrs);

  RecurrentOp(const RecurrentOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    // TODO(yuyang18): Implement copy ctor well.
    PADDLE_THROW("Not implemented");
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  void set_stepnet(std::unique_ptr<OperatorBase> net) {
    stepnet_ = std::move(net);
  }

  const OperatorBase& stepnet() const { return *stepnet_; }

  static const rnn::ArgumentName kArgName;

 private:
  RecurrentAlgorithm alg_;
  rnn::Argument arg_;
  std::unique_ptr<OperatorBase> stepnet_;
};

class RecurrentGradientOp : public framework::OperatorBase {
 public:
  RecurrentGradientOp(const std::string& type,
                      const framework::VariableNameMap& inputs,
                      const framework::VariableNameMap& outputs,
                      const framework::AttributeMap& attrs);

  RecurrentGradientOp(const RecurrentGradientOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    // TODO(yuyang18): Implement copy ctor.
    PADDLE_THROW("Not implemented");
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  static const rnn::ArgumentName kArgName;

  /*
   * Set a stepnet that is created according to a RecurrentOp's stepnet.
   */
  void set_stepnet(std::unique_ptr<OperatorBase> net) {
    stepnet_ = std::move(net);
  }
  const OperatorBase& stepnet() const { return *stepnet_; }

 private:
  RecurrentGradientAlgorithm alg_;
  std::unique_ptr<OperatorBase> stepnet_;
  rnn::Argument arg_;
};

}  // namespace operators
}  // namespace paddle
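CreateScopes in the removed header follows a grow-only pattern: step scopes are created once under the parent scope and only extended when a longer sequence arrives, so the forward and backward passes share them. A rough sketch of that pattern, purely for illustration (`new_scope` is a made-up stand-in, not Paddle's API):

def create_scopes(parent_scope, step_scopes, seq_len):
    # Reuse scopes from earlier runs; only append when this sequence is
    # longer than any seen before (grow-only, shared by forward/backward).
    while len(step_scopes) < seq_len:
        step_scopes.append(parent_scope.new_scope())  # hypothetical API
    return step_scopes[:seq_len]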
@@ -1,38 +0,0 @@
import unittest
from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import g_program


class TestRNN(unittest.TestCase):
    def test_rnn(self):
        img = data(
            shape=[
                80,  # sequence length
                22,  # image height
                22   # image width
            ],
            data_type='float32',
            name='image')
        hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2)
        self.assertEqual((-1, 80, 100), hidden.shape)
        hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2)
        self.assertEqual((-1, 80, 100), hidden.shape)

        rnn = StaticRNN()
        with rnn.step():
            hidden = rnn.step_input(hidden)
            self.assertEqual((-1, 100), hidden.shape)
            memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)

            rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid')
            self.assertEqual((-1, 32), rnn_out.shape)
            rnn.update_memory(memory, rnn_out)
            rnn.output(rnn_out)

        out = rnn()
        self.assertEqual((-1, 80, 32), out.shape)
        print g_program


if __name__ == '__main__':
    unittest.main()
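For intuition, the shapes the test asserts can be reproduced with a plain NumPy loop. This is only a reading aid for the StaticRNN semantics above (step_input slices the time axis, memory carries state across steps, output stacks per-step results), not Paddle code; the weight and inputs are random stand-ins:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

batch, seq_len, in_size, mem_size = 4, 80, 100, 32
rng = np.random.default_rng(0)
w = rng.standard_normal((in_size + mem_size, mem_size))  # stand-in fc weight

x = rng.standard_normal((batch, seq_len, in_size))  # plays the role of `hidden`
mem = np.zeros((batch, mem_size))                   # memory with init_value=0.0
outs = []
for t in range(seq_len):
    step = x[:, t, :]                               # step_input: (batch, 100)
    mem = sigmoid(np.concatenate([step, mem], axis=1) @ w)  # fc over [hidden, memory]
    outs.append(mem)                                # update_memory + output
out = np.stack(outs, axis=1)                        # (batch, 80, 32), as asserted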