|
|
|
@ -137,6 +137,12 @@ class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
y_dims.push_back(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int batch_count = 0;
|
|
|
|
|
//
|
|
|
|
|
if (x_dims.size() > 3) {
|
|
|
|
|
batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
}
|
|
|
|
|
// Fix the dOut dimensions.
|
|
|
|
|
int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
|
|
|
|
|
|
|
|
|
@ -149,8 +155,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
M = transpose_x ? x_dims[2] : x_dims[1];
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
batchCountX = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
batchCountX = batch_count;
|
|
|
|
|
size_t mat_s = x_dims.size() - 2;
|
|
|
|
|
M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
|
|
|
|
|
}
|
|
|
|
@ -164,8 +169,7 @@ class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
N = transpose_y ? y_dims[1] : y_dims[2];
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
batchCountY = accumulate(y_dims.begin(), y_dims.end() - 2, 1,
|
|
|
|
|
std::multiplies<int>());
|
|
|
|
|
batchCountY = batch_count;
|
|
|
|
|
size_t mat_s = y_dims.size() - 2;
|
|
|
|
|
N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
|
|
|
|
|
}
|
|
|
|
@ -180,8 +184,6 @@ class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (batchCount) {
|
|
|
|
|
if (x_dims.size() > 3) {
|
|
|
|
|
dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
|
|
|
|
|
} else if (y_dims.size() > 3) {
|
|
|
|
|
dout_dims.insert(dout_dims.begin(), y_dims.begin(), y_dims.end() - 2);
|
|
|
|
|
} else {
|
|
|
|
|
dout_dims.insert(dout_dims.begin(), batchCount);
|
|
|
|
|
}
|
|
|
|
|