|
|
|
@ -59,7 +59,9 @@ class MatMulKernel : public framework::OpKernel<T> {
|
|
|
|
|
RowMatrixFromVector(x.dims()), 0, context.Attr<bool>("transpose_X"));
|
|
|
|
|
auto mat_dim_b = math::CreateMatrixDescriptor(
|
|
|
|
|
ColumnMatrixFromVector(y.dims()), 0, context.Attr<bool>("transpose_Y"));
|
|
|
|
|
blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
|
|
|
|
|
auto scale = static_cast<T>(context.Attr<float>("scale"));
|
|
|
|
|
auto bias = static_cast<T>(context.Attr<float>("bias"));
|
|
|
|
|
blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, bias);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -185,7 +187,8 @@ class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(context);
|
|
|
|
|
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
|
|
|
|
|
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
|
|
|
|
|
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
|
|
|
|
|
blas.MatMul(a, mat_dim_a, b, mat_dim_b,
|
|
|
|
|
static_cast<T>(context.Attr<float>("scale")), out, T(0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CalcInputGrad(const framework::ExecutionContext &context,
|
|
|
|
@ -334,6 +337,8 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
R"DOC(If true, use the transpose of `Y`.
|
|
|
|
|
)DOC")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<float>("scale", "Scale").SetDefault(1.0f);
|
|
|
|
|
AddAttr<float>("bias", "Bias").SetDefault(0.0f);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
MatMul Operator.
|
|
|
|
|
|
|
|
|
|