|
|
@ -689,13 +689,15 @@ void BindImperative(py::module *m_ptr) {
|
|
|
|
x = linear(data)
|
|
|
|
x = linear(data)
|
|
|
|
print(x.numpy())
|
|
|
|
print(x.numpy())
|
|
|
|
)DOC")
|
|
|
|
)DOC")
|
|
|
|
.def("detach",
|
|
|
|
.def(
|
|
|
|
[](const imperative::VarBase
|
|
|
|
"detach",
|
|
|
|
&self) -> std::shared_ptr<imperative::VarBase> {
|
|
|
|
[](const imperative::VarBase &self)
|
|
|
|
|
|
|
|
-> std::shared_ptr<imperative::VarBase> {
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
self.Var().IsInitialized(), true,
|
|
|
|
self.Var().IsInitialized(), true,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"Tensor %s has not been initialized!", self.Name()));
|
|
|
|
"Tensor %s has not been initialized!", self.Name()));
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
self.Var().IsType<framework::LoDTensor>() ||
|
|
|
|
self.Var().IsType<framework::LoDTensor>() ||
|
|
|
|
self.Var().IsType<framework::SelectedRows>(),
|
|
|
|
self.Var().IsType<framework::SelectedRows>(),
|
|
|
@ -703,38 +705,23 @@ void BindImperative(py::module *m_ptr) {
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"Type of Tensor[%s] must be LoDTensor or SelectedRows!",
|
|
|
|
"Type of Tensor[%s] must be LoDTensor or SelectedRows!",
|
|
|
|
self.Name()));
|
|
|
|
self.Name()));
|
|
|
|
|
|
|
|
|
|
|
|
auto detach_var = std::make_shared<imperative::VarBase>(
|
|
|
|
auto detach_var = std::make_shared<imperative::VarBase>(
|
|
|
|
true, "detach_" + self.Name());
|
|
|
|
true, "detach_" + self.Name());
|
|
|
|
|
|
|
|
|
|
|
|
detach_var->SetPersistable(self.Persistable());
|
|
|
|
detach_var->SetPersistable(self.Persistable());
|
|
|
|
detach_var->SetType(self.Type());
|
|
|
|
detach_var->SetType(self.Type());
|
|
|
|
detach_var->SetDataType(self.DataType());
|
|
|
|
detach_var->SetDataType(self.DataType());
|
|
|
|
if (self.Var().IsType<framework::LoDTensor>()) {
|
|
|
|
|
|
|
|
const auto &origin_tensor =
|
|
|
|
|
|
|
|
self.Var().Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
origin_tensor.IsInitialized(), true,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"Tensor %s has not been initialized!", self.Name()));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto *detach_tensor =
|
|
|
|
// NOTE(liym27):
|
|
|
|
detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
|
|
|
|
// Call Variable::SharePlaceholderWith but not
|
|
|
|
detach_tensor->ShareDataWith(origin_tensor);
|
|
|
|
// Tensor::ShareDataWith or Tensor::ShareBufferWith, because
|
|
|
|
} else {
|
|
|
|
// `detach_var` should share the same TensorInplaceVersion with
|
|
|
|
const auto &origin_selected_rows =
|
|
|
|
// `self`, and only SharePlaceholderWith can also share the same
|
|
|
|
self.Var().Get<framework::SelectedRows>();
|
|
|
|
// TensorInplaceVersion, which is used to check whether inplace
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
// operations are correct.
|
|
|
|
origin_selected_rows.value().IsInitialized(), true,
|
|
|
|
detach_var->MutableVar()->SharePlaceholderWith(self.Var());
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"Tensor %s has not been initialized!", self.Name()));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto *detach_selected_rows =
|
|
|
|
|
|
|
|
detach_var->MutableVar()
|
|
|
|
|
|
|
|
->GetMutable<framework::SelectedRows>();
|
|
|
|
|
|
|
|
detach_selected_rows->set_height(origin_selected_rows.height());
|
|
|
|
|
|
|
|
detach_selected_rows->set_rows(origin_selected_rows.rows());
|
|
|
|
|
|
|
|
detach_selected_rows->mutable_value()->ShareDataWith(
|
|
|
|
|
|
|
|
origin_selected_rows.value());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
VLOG(3) << "The detached Tensor(" << detach_var->Name()
|
|
|
|
VLOG(3) << "The detached Tensor(" << detach_var->Name()
|
|
|
|
<< ") share data with " << self.Name();
|
|
|
|
<< ") share data with " << self.Name();
|
|
|
|
return detach_var;
|
|
|
|
return detach_var;
|
|
|
|