transpose_mkldnn code change to meet Paddle standards (#22591)

revert-22710-feature/integrated_ps_api
Adam 5 years ago committed by GitHub
parent 8f035fb637
commit ab610a34ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,7 +40,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>();
if (ndims == 1) {
output->ShareDataWith(*input);
framework::TensorCopy(*input, input->place(), output);
output->set_format(input->format());
return;
}
@ -85,7 +86,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> reversed_axis(axis);
int ndims = axis.size();
if (ndims == 1) {
x_grad->ShareDataWith(*out_grad);
framework::TensorCopy(*out_grad, out_grad->place(), x_grad);
x_grad->set_format(out_grad->format());
return;
}

Loading…
Cancel
Save