|
|
|
@ -69,23 +69,19 @@ struct ArgumentName {
|
|
|
|
|
* Prepare inputs for each step net.
|
|
|
|
|
*/
|
|
|
|
|
void SegmentInputs(const std::vector<framework::Scope*>& step_scopes,
|
|
|
|
|
const std::vector<Link>& inlinks,
|
|
|
|
|
const size_t seq_len,
|
|
|
|
|
const std::vector<Link>& inlinks, const size_t seq_len,
|
|
|
|
|
bool infer_shape_mode);
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Process outputs of step nets and merge to variables.
|
|
|
|
|
*/
|
|
|
|
|
void ConcatOutputs(const std::vector<framework::Scope*>& step_scopes,
|
|
|
|
|
const std::vector<Link>& outlinks,
|
|
|
|
|
const size_t seq_len,
|
|
|
|
|
const std::vector<Link>& outlinks, const size_t seq_len,
|
|
|
|
|
bool infer_shape_mode);
|
|
|
|
|
|
|
|
|
|
void LinkMemories(const std::vector<framework::Scope*>& step_scopes,
|
|
|
|
|
const std::vector<MemoryAttr>& memories,
|
|
|
|
|
const size_t step_id,
|
|
|
|
|
const int offset,
|
|
|
|
|
bool infer_shape_mode);
|
|
|
|
|
const std::vector<MemoryAttr>& memories, const size_t step_id,
|
|
|
|
|
const int offset, bool infer_shape_mode);
|
|
|
|
|
|
|
|
|
|
void InitArgument(const ArgumentName& name, Argument* arg);
|
|
|
|
|
|
|
|
|
@ -100,7 +96,7 @@ void InitArgument(const ArgumentName& name, Argument* arg);
|
|
|
|
|
// Refer to: https://arxiv.org/pdf/1502.02367.pdf
|
|
|
|
|
|
|
|
|
|
class RecurrentAlgorithm {
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
void Run(const framework::Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const;
|
|
|
|
|
|
|
|
|
@ -111,7 +107,7 @@ public:
|
|
|
|
|
*/
|
|
|
|
|
void InferShape(const framework::Scope& scope) const;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
protected:
|
|
|
|
|
/*
|
|
|
|
|
* The step scopes will be stored in the father scope as a variable.
|
|
|
|
|
*
|
|
|
|
@ -128,7 +124,7 @@ protected:
|
|
|
|
|
|
|
|
|
|
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<rnn::Argument> arg_;
|
|
|
|
|
mutable size_t seq_len_;
|
|
|
|
|
};
|
|
|
|
@ -144,7 +140,7 @@ class RecurrentGradientAlgorithm {
|
|
|
|
|
* lot, and the latter is a wrapper acts like an dapter for it to make RNN an
|
|
|
|
|
* operator.
|
|
|
|
|
*/
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
|
|
|
|
|
|
|
|
|
|
void Run(const framework::Scope& scope,
|
|
|
|
@ -158,20 +154,20 @@ public:
|
|
|
|
|
*/
|
|
|
|
|
void InferShape(const framework::Scope& scope) const;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
protected:
|
|
|
|
|
inline const std::vector<framework::Scope*>& GetStepScopes(
|
|
|
|
|
const framework::Scope& scope) const {
|
|
|
|
|
return *scope.FindVar(arg_->step_scopes)
|
|
|
|
|
->GetMutable<std::vector<framework::Scope*>>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<rnn::Argument> arg_;
|
|
|
|
|
mutable size_t seq_len_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RecurrentOp final : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
void Init() override;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -188,12 +184,12 @@ public:
|
|
|
|
|
|
|
|
|
|
static const rnn::ArgumentName kArgName;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
private:
|
|
|
|
|
RecurrentAlgorithm alg_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RecurrentGradientOp final : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
public:
|
|
|
|
|
void Init() override;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -210,7 +206,7 @@ public:
|
|
|
|
|
|
|
|
|
|
static const rnn::ArgumentName kArgName;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
private:
|
|
|
|
|
RecurrentGradientAlgorithm alg_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|