Rewrite StaticRNN with Executor (#5224)
* Init commit
* Make executor use ProgramDescBind
* Change Attribute from BlockDesc to BlockDescBind; since we will get the program desc in RNN, BlockDesc alone is not enough
* Add DeviceContext to Executor API
* Rewrite RNN
* Pass Python
* AddBiasOp does not care about num_flatten_dims
* Stash
* Fix macOS compile
* Pass RNN forward
* Add Python test
* Refactor test
* Make compile pass
* Add grad op maker
* First draft done
* Polish code
* Add grad op maker and grad infershape
* Polish code
* Fix backward.cc bug
* Fix infershape
* Rename function
* Add backward test
* Simplify recurrent test
* Update
* Pass unittest
* Add comments & refine test
* Add comments
* Refactor test
* Complete unittest
* Fix StepScopes enforce
* Remove unused unittest
* No type error
* Update
* Make RNN pass unittest
* Fix typo
parent 8cdb42c2b3
commit 0a32e74d13
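For orientation, the user-facing surface this rewrite targets is the Python StaticRNN builder exercised by the unit test at the bottom of this diff: step_input slices the sequence per time step, memory and update_memory declare and advance state, and output collects per-step results. A minimal sketch of that flow, assuming the same names as the test (fc is the v2 framework layer helper, hidden a (batch, 80, 100) sequence):

    # Minimal sketch of the StaticRNN surface from the deleted test below;
    # fc and hidden are taken from that test, not defined here.
    rnn = StaticRNN()
    with rnn.step():
        word = rnn.step_input(hidden)    # one (batch, 100) time slice
        prev = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)
        out = fc(input=[word, prev], size=32, act='sigmoid')
        rnn.update_memory(prev, out)     # state carried to the next step
        rnn.output(out)                  # collected across all time steps
    outs = rnn()                         # (batch, seq_len, 32)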
@@ -1,170 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"

namespace paddle {
namespace operators {

// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
// TODO(Superjom)
// 1. No-padding computing for sequences with indefinite length in one batch.
// 2. Hierarchical RNN for sequences with sub-sequences.
// 3. Internal Memory.
// 4. More complex RNN architectures, such as Gated Feedback RNN.
//    Refer to: https://arxiv.org/pdf/1502.02367.pdf

class RecurrentAlgorithm {
 public:
  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const;

  void Init(rnn::Argument* arg,
            std::unique_ptr<framework::OperatorBase>* stepnet) {
    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
    arg_ = arg;
    stepnet_ = stepnet;
  }

 protected:
  /*
   * The step scopes will be stored in the father scope as a variable.
   *
   * NOTE the scopes are reused in both the forward and backward, so just
   * create them once and expand their size if more steps are needed.
   */
  void CreateScopes(const framework::Scope& scope, size_t seq_len) const;

  const std::vector<framework::Scope*>& GetStepScopes(
      const framework::Scope& scope) const {
    return *scope.FindVar(arg_->step_scopes)
                ->GetMutable<std::vector<framework::Scope*>>();
  }

  void InitMemories(framework::Scope* step_scopes) const;

 private:
  std::unique_ptr<framework::OperatorBase>* stepnet_;
  rnn::Argument* arg_;
};

class RecurrentGradientAlgorithm {
  /**
   * RNN's backward algorithm.
   *
   * To accelerate the development of RecurrentGradientOp, we decouple RNN's
   * algorithm from `OperatorBase`'s implementation. The former contains the
   * core implementation of an RNN and stays stable even if the framework
   * changes a lot; the latter is a wrapper that acts as an adapter to make
   * the RNN an operator.
   */
 public:
  void Init(rnn::Argument* arg,
            std::unique_ptr<framework::OperatorBase>* stepnet) {
    PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
    arg_ = arg;
    stepnet_ = stepnet;
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const;

  void LinkBootMemoryGradients(framework::Scope* step_scopes) const;

 protected:
  inline const std::vector<framework::Scope*>& GetStepScopes(
      const framework::Scope& scope) const {
    return *scope.FindVar(arg_->step_scopes)
                ->GetMutable<std::vector<framework::Scope*>>();
  }

 private:
  rnn::Argument* arg_;
  std::unique_ptr<framework::OperatorBase>* stepnet_;
};

class RecurrentOp : public framework::OperatorBase {
 public:
  RecurrentOp(const std::string& type, const framework::VariableNameMap& inputs,
              const framework::VariableNameMap& outputs,
              const framework::AttributeMap& attrs);

  RecurrentOp(const RecurrentOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    // TODO(yuyang18): Implement copy ctor well.
    PADDLE_THROW("Not implemented");
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  void set_stepnet(std::unique_ptr<OperatorBase> net) {
    stepnet_ = std::move(net);
  }

  const OperatorBase& stepnet() const { return *stepnet_; }

  static const rnn::ArgumentName kArgName;

 private:
  RecurrentAlgorithm alg_;
  rnn::Argument arg_;
  std::unique_ptr<OperatorBase> stepnet_;
};

class RecurrentGradientOp : public framework::OperatorBase {
 public:
  RecurrentGradientOp(const std::string& type,
                      const framework::VariableNameMap& inputs,
                      const framework::VariableNameMap& outputs,
                      const framework::AttributeMap& attrs);

  RecurrentGradientOp(const RecurrentGradientOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    // TODO(yuyang18): Implement copy ctor.
    PADDLE_THROW("Not Implemented");
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
  }

  static const rnn::ArgumentName kArgName;

  /*
   * Set a stepnet that is created according to a RecurrentOp's stepnet.
   */
  void set_stepnet(std::unique_ptr<OperatorBase> net) {
    stepnet_ = std::move(net);
  }

  const OperatorBase& stepnet() const { return *stepnet_; }

 private:
  RecurrentGradientAlgorithm alg_;
  std::unique_ptr<OperatorBase> stepnet_;
  rnn::Argument arg_;
};

}  // namespace operators
}  // namespace paddle
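The NOTE on CreateScopes above is the key lifecycle detail being removed: step scopes live in the parent scope as a variable and are shared by forward and backward. Illustratively, the reuse policy that comment describes looks like the following sketch (names here are hypothetical, not the framework API):

    # Sketch of the scope-reuse policy from the CreateScopes comment above;
    # ensure_step_scopes and make_scope are illustrative names only.
    def ensure_step_scopes(step_scopes, seq_len, make_scope):
        # Scopes are created once and only grown when a longer sequence
        # arrives; forward and backward both index into this same list.
        while len(step_scopes) < seq_len:
            step_scopes.append(make_scope())
        return step_scopes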
@@ -1,38 +0,0 @@
import unittest

from paddle.v2.framework.layers import *
from paddle.v2.framework.framework import g_program


class TestRNN(unittest.TestCase):
    def test_rnn(self):
        img = data(
            shape=[
                80,  # sequence length
                22,  # image height
                22   # image width
            ],
            data_type='float32',
            name='image')
        hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2)
        self.assertEqual((-1, 80, 100), hidden.shape)
        hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2)
        self.assertEqual((-1, 80, 100), hidden.shape)

        rnn = StaticRNN()
        with rnn.step():
            hidden = rnn.step_input(hidden)
            self.assertEqual((-1, 100), hidden.shape)
            memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)

            rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid')
            self.assertEqual((-1, 32), rnn_out.shape)
            rnn.update_memory(memory, rnn_out)
            rnn.output(rnn_out)

        out = rnn()
        self.assertEqual((-1, 80, 32), out.shape)
        print g_program


if __name__ == '__main__':
    unittest.main()
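The shape assertions trace the contract: each step consumes one (batch, 100) slice of the (batch, 80, 100) sequence, emits (batch, 32), and rnn() stacks the 80 step outputs back along the time axis. A quick numpy check of that arithmetic (numpy is used here only for illustration; the batch size 4 is arbitrary):

    import numpy as np

    batch, seq_len, out_dim = 4, 80, 32
    step_outs = [np.zeros((batch, out_dim)) for _ in range(seq_len)]  # one per step
    stacked = np.stack(step_outs, axis=1)  # restore the time axis
    assert stacked.shape == (batch, seq_len, out_dim)  # matches (-1, 80, 32)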