Optimize assign op to avoid copying data from GPU to GPU (#21181)

* optimize assign op to avoid copying data from GPU to GPU, test=develop

* modified GetKernelTypeForVar so that it only avoids the device transform, test=develop
revert-21172-masked_select_api
Zhang Ting 6 years ago committed by GitHub
parent c91cb6c550
commit 01a9646323
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,6 +41,14 @@ class AssignOp : public framework::OperatorWithKernel {
}
protected:
// Builds the kernel type to use for a single input variable of this op.
//
// The data type and place are taken from `expected_kernel_type`, while the
// layout is taken from the input `tensor` itself. `var_name` is not consulted.
//
// NOTE(review): according to the commit description this override exists so
// that no device transform is applied to the input -- confirm against the
// framework's data-transform logic.
framework::OpKernelType GetKernelTypeForVar(
    const std::string &var_name, const framework::Tensor &tensor,
    const framework::OpKernelType &expected_kernel_type) const override {
  const auto &wanted_dtype = expected_kernel_type.data_type_;
  const auto &wanted_place = expected_kernel_type.place_;
  // Keep whatever layout the incoming tensor already has.
  const auto input_layout = tensor.layout();
  return framework::OpKernelType(wanted_dtype, wanted_place, input_layout);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(

@ -47,7 +47,7 @@ class AssignFunctor {
out_rows.set_height(rows.height());
auto &t = rows.value();
auto *m = out_rows.mutable_value();
framework::TensorCopy(t, t.place(), dev_ctx_, m);
framework::TensorCopy(t, dev_ctx_.GetPlace(), dev_ctx_, m);
}
template <typename T>
@ -60,7 +60,7 @@ class AssignFunctor {
framework::LoDTensor *out) const {
if (lod_tensor.numel() == 0) return;
auto &out_tensor = *out;
TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor);
out_tensor.set_lod(lod_tensor.lod());
}

Loading…
Cancel
Save