|
|
|
@ -53,7 +53,9 @@ class MulGradKernel : public framework::OpKernel {
|
|
|
|
|
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto* device_context =
|
|
|
|
|
const_cast<platform::DeviceContext*>(ctx.device_context_);
|
|
|
|
|
// dX = dOut' * Y. dX: M x K, dOut : M x N, Y : K x N
|
|
|
|
|
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
|
|
|
|
|
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
|
|
|
|
|
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|