|
|
|
@ -140,9 +140,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
|
|
|
|
|
// fail since "states" and "ex_states" cannot be found in main block.
|
|
|
|
|
// When memory optimization is enabled, "states", "ex_states" and their
|
|
|
|
|
// gradient should be skipped.
|
|
|
|
|
auto& ex_states =
|
|
|
|
|
auto ex_states =
|
|
|
|
|
boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
|
|
|
|
|
auto& states =
|
|
|
|
|
auto states =
|
|
|
|
|
boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
|
|
|
|
|
if (op_type == "recurrent") {
|
|
|
|
|
UpdateSkipVarSet(skip_vars, {ex_states, states});
|
|
|
|
@ -154,7 +154,7 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
|
|
|
|
|
UpdateSkipVarSet(
|
|
|
|
|
skip_vars,
|
|
|
|
|
{ToGradVarName(op_desc->Input("parameters")),
|
|
|
|
|
ToGradVarName(op_desc->Input("input")), ex_states, states,
|
|
|
|
|
ToGradVarName(op_desc->Input("inputs")), ex_states, states,
|
|
|
|
|
ToGradVarName(ex_states), ToGradVarName(states)});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|