"tensor mutable data"

revert-3824-remove_grad_op_type
dongzhihong 8 years ago
parent 2799da6634
commit 0cf5bdec56

@ -51,9 +51,11 @@ class MulGradKernel : public framework::OpKernel {
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut' * Y. dX: M x K, dOut : M x N, Y : K x N
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);

Loading…
Cancel
Save