load op add seletedRows

port
tangwei12 7 years ago
parent 549f0aa0d3
commit a501766ab1

@ -44,6 +44,16 @@ class LoadOp : public framework::OperatorBase {
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name);
if (out_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, scope, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
}
}
@ -91,7 +101,7 @@ class LoadOp : public framework::OperatorBase {
const platform::Place &place,
framework::Variable *var) const {
auto &selectedRows = var->Get<framework::SelectedRows>();
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();

Loading…
Cancel
Save