|
|
@ -358,11 +358,11 @@ LoDTensor* RNNAlgorithm::ArgCache::GetTensor(const framework::Scope& scope,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
const std::array<rnn::ArgumentName, 2> RNNAlgorithm::kArgNames{
|
|
|
|
const std::array<rnn::ArgumentName, 2> RNNAlgorithm::kArgNames{
|
|
|
|
rnn::ArgumentName{"step_unit", "step_scopes", "inputs", "outputs", "states",
|
|
|
|
{rnn::ArgumentName{"step_unit", "step_scopes", "inputs", "outputs",
|
|
|
|
"ex_states", "initial_states"},
|
|
|
|
"states", "ex_states", "initial_states"},
|
|
|
|
rnn::ArgumentName{"step_unit", "step_scopes@GRAD", "outputs@GRAD",
|
|
|
|
rnn::ArgumentName{"step_unit", "step_scopes@GRAD", "outputs@GRAD",
|
|
|
|
"inputs@GRAD", "states", "ex_states",
|
|
|
|
"inputs@GRAD", "states", "ex_states",
|
|
|
|
"initial_states@GRAD"}};
|
|
|
|
"initial_states@GRAD"}}};
|
|
|
|
|
|
|
|
|
|
|
|
void DynamicRecurrentOp::Run(const framework::Scope& scope,
|
|
|
|
void DynamicRecurrentOp::Run(const framework::Scope& scope,
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|