|
|
|
@ -19,7 +19,7 @@
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
using namespace paddle::framework;
|
|
|
|
using namespace paddle::framework; // NOLINT
|
|
|
|
|
|
|
|
|
|
|
|
namespace rnn {
|
|
|
|
namespace rnn {
|
|
|
|
|
|
|
|
|
|
|
|
@ -94,7 +94,7 @@ void InitArgument(const ArgumentName& name, Argument* arg);
|
|
|
|
}; // namespace rnn
|
|
|
|
}; // namespace rnn
|
|
|
|
|
|
|
|
|
|
|
|
// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
|
|
|
|
// The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now.
|
|
|
|
// TODO:
|
|
|
|
// TODO(Yan Chunwei):
|
|
|
|
// 1. No-padding computing for sequences with indifinite length in one batch.
|
|
|
|
// 1. No-padding computing for sequences with indifinite length in one batch.
|
|
|
|
// 2. Hierarchical RNN for sequence with sub-sequence.
|
|
|
|
// 2. Hierarchical RNN for sequence with sub-sequence.
|
|
|
|
// 3. Internal Memory.
|
|
|
|
// 3. Internal Memory.
|
|
|
|
@ -172,11 +172,9 @@ public:
|
|
|
|
/**
|
|
|
|
/**
|
|
|
|
* InferShape must be called before Run.
|
|
|
|
* InferShape must be called before Run.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
virtual void InferShape(const Scope& scope) const override {
|
|
|
|
void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
|
|
|
|
alg_.InferShape(scope);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual void Run(const Scope& scope,
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
alg_.Run(scope, dev_ctx);
|
|
|
|
alg_.Run(scope, dev_ctx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -194,11 +192,9 @@ public:
|
|
|
|
/**
|
|
|
|
/**
|
|
|
|
* InferShape must be called before Run.
|
|
|
|
* InferShape must be called before Run.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
virtual void InferShape(const Scope& scope) const override {
|
|
|
|
void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
|
|
|
|
alg_.InferShape(scope);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual void Run(const Scope& scope,
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
alg_.Run(scope, dev_ctx);
|
|
|
|
alg_.Run(scope, dev_ctx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|