|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
@ -932,3 +933,14 @@ REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_VERSION(matmul)
|
|
|
|
|
.AddCheckpoint(
|
|
|
|
|
R"ROC(Register matmul for adding the attribute of
|
|
|
|
|
fused_reshape_Y)ROC",
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc().NewAttr(
|
|
|
|
|
"fused_reshape_Y",
|
|
|
|
|
"In order to support the function of fused the input Y "
|
|
|
|
|
" and input X into the input X when "
|
|
|
|
|
"using the operator of matmul, and get raw shape of input Y.",
|
|
|
|
|
std::vector<int>{}));
|
|
|
|
|