|
|
|
@ -44,6 +44,16 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
|
|
|
|
|
out_var_name);
|
|
|
|
|
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
SaveLodTensor(filename, place, out_var);
|
|
|
|
|
} else if (out_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
SaveSelectedRows(filename, scope, place, out_var);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
false,
|
|
|
|
|
"Load only support LoDTensor and SelectedRows, %s has wrong type",
|
|
|
|
|
iname);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -91,7 +101,7 @@ class LoadOp : public framework::OperatorBase {
|
|
|
|
|
const platform::Place &place,
|
|
|
|
|
framework::Variable *var) const {
|
|
|
|
|
|
|
|
|
|
auto &selectedRows = var->Get<framework::SelectedRows>();
|
|
|
|
|
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
|
|
|
|
|
|
|
|
|
|
// get device context from pool
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|