add Tensor::IsSharedBufferWith method, test=develop (#23175)

Zeng Jinle 5 years ago committed by GitHub
parent 2787041246
commit 7ca77a90ac

@@ -104,7 +104,7 @@ void ShareTensorBufferFunctor::operator()(Scope *exec_scope) {
       // If in_var is inplaced in the previous batch and we want to fetch
       // in_var in the current batch, we have to reset memory of out_var
       // to avoid wrong calculation result.
-      if (in_tensor.Holder() == out_tensor->Holder()) {
+      if (out_tensor->IsSharedBufferWith(in_tensor)) {
        VLOG(1) << "Clear " << out_var_names_[i]
                << " because you may want to fetch an inplaced variable "
                << in_var_info->Name()

@@ -160,6 +160,10 @@ class Tensor {
     offset_ = tensor.offset_;
   }

+  bool IsSharedBufferWith(const Tensor& src) const {
+    return holder_ && holder_ == src.Holder();
+  }
+
   const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }

   size_t offset() const { return offset_; }
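To make the semantics of the new helper concrete, here is a minimal, self-contained sketch (hypothetical MiniTensor and Allocation types, not the real paddle::framework::Tensor) of what IsSharedBufferWith checks: sharing is decided by comparing the underlying Allocation holder rather than the raw data pointer, so two views at different offsets into one allocation still count as shared, while a tensor with a null holder never reports itself as shared.

// Illustrative sketch only; MiniTensor and Allocation are stand-ins, not Paddle types.
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct Allocation {                     // stand-in for memory::Allocation
  std::vector<char> bytes;
  explicit Allocation(size_t n) : bytes(n) {}
};

struct MiniTensor {                     // stand-in for framework::Tensor
  std::shared_ptr<Allocation> holder_;  // shared buffer
  size_t offset_ = 0;                   // view offset into the buffer

  const std::shared_ptr<Allocation>& Holder() const { return holder_; }

  bool IsSharedBufferWith(const MiniTensor& src) const {
    // Same semantics as the method added in this commit: a tensor with no
    // allocation shares nothing, even with another empty tensor.
    return holder_ && holder_ == src.Holder();
  }
};

int main() {
  auto buf = std::make_shared<Allocation>(64);

  MiniTensor a{buf, 0};
  MiniTensor b{buf, 16};                                // offset view of the same allocation
  MiniTensor c{std::make_shared<Allocation>(64), 0};    // separate allocation
  MiniTensor empty;

  assert(a.IsSharedBufferWith(b));           // same holder, different offsets
  assert(!a.IsSharedBufferWith(c));          // distinct allocations
  assert(!empty.IsSharedBufferWith(empty));  // null holder never "shares"
  return 0;
}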

@@ -1100,7 +1100,7 @@ void CommonElementwiseBroadcastBackward(
   // for inplace strategy. memset will make dx and dout clear and get wrong
   // result.
-  if (dx && dout.Holder() == dx->Holder()) {
+  if (dx && dx->IsSharedBufferWith(dout)) {
    dx->clear();
    dx->mutable_data<T>(x_dims, ctx.GetPlace());
  }
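The clear-then-reallocate step in this hunk exists because dx and dout can alias the same allocation under the inplace strategy, so zero-filling dx in place would also wipe dout. Below is a minimal sketch of that pattern, using hypothetical View and Buffer types rather than the real Paddle API.

// Illustrative sketch only; View/Buffer are stand-ins, not Paddle types.
#include <cassert>
#include <cstddef>
#include <cstring>
#include <memory>
#include <vector>

using Buffer = std::shared_ptr<std::vector<float>>;

struct View {
  Buffer holder;
  float* data() { return holder->data(); }
  bool SharesBufferWith(const View& other) const {
    return holder && holder == other.holder;
  }
  void ClearAndRealloc(size_t n) {
    // Analogous to dx->clear(); dx->mutable_data<T>(...): drop the shared
    // buffer and get a fresh one, breaking the aliasing with dout.
    holder = std::make_shared<std::vector<float>>(n);
  }
};

int main() {
  Buffer buf = std::make_shared<std::vector<float>>(4, 1.0f);
  View dout{buf};
  View dx{buf};  // inplace strategy: dx reuses dout's buffer

  if (dx.SharesBufferWith(dout)) {
    dx.ClearAndRealloc(4);
  }
  std::memset(dx.data(), 0, 4 * sizeof(float));  // now safe: dout is untouched

  assert((*dout.holder)[0] == 1.0f);  // dout still holds its original values
  return 0;
}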
