just add the op error message for the matmul xpu (#30246)

add the op error message for the matmul xpu
wawltor · 5 years ago · committed by GitHub
parent 6bfdef727e
commit fee424411a

@@ -127,10 +127,18 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(
         mat_dim_a.width_, mat_dim_b.height_,
-        platform::errors::InvalidArgument("Shape mistake in matmul_op"));
-    PADDLE_ENFORCE_EQ(
-        mat_dim_a.batch_size_, mat_dim_b.batch_size_,
-        platform::errors::InvalidArgument("Shape mistake in matmul_op"));
+        platform::errors::InvalidArgument("Shape mistake in matmul_op, the "
+                                          "first tensor width must be same as "
+                                          "second tensor height, but received "
+                                          "width:%d, height:%d",
+                                          mat_dim_a.width_, mat_dim_b.height_));
+    PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
+                      platform::errors::InvalidArgument(
+                          "Shape mistake in matmul_op, the two input"
+                          "tensor batch_size must be same, but received first "
+                          "tensor batch_size:%d, second "
+                          "tensor batch_size:%d",
+                          mat_dim_a.batch_size_, mat_dim_b.batch_size_));
     T alpha = static_cast<T>(context.Attr<float>("alpha"));
     auto &dev_ctx = context.template device_context<DeviceContext>();
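
Both checks follow the same pattern: PADDLE_ENFORCE_EQ compares the two operands and, on mismatch, reports the platform::errors::InvalidArgument built from a printf-style format string plus the offending values. The sketch below is a minimal standalone illustration of that pattern and of the kind of message a caller would see for mismatched shapes; it uses plain C++ with no Paddle headers, and MatDim and CheckMatmulShapes are hypothetical stand-ins for the mat_dim_a/mat_dim_b descriptors and the kernel code above (the message wording is paraphrased, not copied verbatim).

#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for the descriptor fields used in the kernel above.
struct MatDim {
  int batch_size_;
  int height_;
  int width_;
};

// Mirrors the two shape checks: width of the first operand must equal the
// height of the second, and the batch sizes must match; otherwise a
// formatted, value-carrying error is thrown.
void CheckMatmulShapes(const MatDim &a, const MatDim &b) {
  char buf[256];
  if (a.width_ != b.height_) {
    std::snprintf(buf, sizeof(buf),
                  "Shape mistake in matmul_op, the first tensor width must be "
                  "same as second tensor height, but received width:%d, "
                  "height:%d",
                  a.width_, b.height_);
    throw std::invalid_argument(buf);
  }
  if (a.batch_size_ != b.batch_size_) {
    std::snprintf(buf, sizeof(buf),
                  "Shape mistake in matmul_op, the two input tensor batch_size "
                  "must be same, but received first tensor batch_size:%d, "
                  "second tensor batch_size:%d",
                  a.batch_size_, b.batch_size_);
    throw std::invalid_argument(buf);
  }
}

int main() {
  MatDim a{1, 2, 3};  // batch_size 1, height 2, width 3
  MatDim b{1, 4, 5};  // height 4 does not match a.width_ == 3
  try {
    CheckMatmulShapes(a, b);
  } catch (const std::invalid_argument &e) {
    std::printf("%s\n", e.what());  // prints the width/height message
  }
  return 0;
}

Carrying the actual width/height and batch_size values in the message is what makes the new errors actionable compared with the bare "Shape mistake in matmul_op" text they replace.
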
@@ -251,12 +259,20 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
       }
     }
-    PADDLE_ENFORCE_EQ(
-        mat_dim_a.width_, mat_dim_b.height_,
-        platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
-    PADDLE_ENFORCE_EQ(
-        mat_dim_a.batch_size_, mat_dim_b.batch_size_,
-        platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
+    PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_,
+                      platform::errors::InvalidArgument(
+                          "Shape mistake in matmul_grad_op, the "
+                          "first tensor width must be same as second tensor "
+                          "height, but received "
+                          "width:%d, height:%d",
+                          mat_dim_a.width_, mat_dim_b.height_));
+    PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
+                      platform::errors::InvalidArgument(
+                          "Shape mistake in matmul_grad_op, the two input"
+                          "tensor batch_size must be same, but received first "
+                          "tensor batch_size:%d, second "
+                          "tensor batch_size:%d",
+                          mat_dim_a.batch_size_, mat_dim_b.batch_size_));
     T alpha = static_cast<T>(context.Attr<float>("alpha"));
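
The gradient kernel gets the same treatment. As a reminder of the runtime behaviour behind these checks: PADDLE_ENFORCE_EQ evaluates its first two arguments and, only when they differ, materialises the supplied error object and throws. The macro below is a deliberately simplified, hypothetical sketch of that check-and-throw control flow; it is not Paddle's actual macro, which additionally records file/line information and integrates with Paddle's own exception types.

#include <cstdio>
#include <stdexcept>

// Simplified illustration only: compare, format the message, throw.
// (Assumes at least one argument follows the format string.)
#define ENFORCE_EQ_SKETCH(a, b, fmt, ...)                     \
  do {                                                        \
    if ((a) != (b)) {                                         \
      char msg_[256];                                         \
      std::snprintf(msg_, sizeof(msg_), (fmt), __VA_ARGS__);  \
      throw std::runtime_error(msg_);                         \
    }                                                         \
  } while (0)

int main() {
  int width = 3, height = 4;  // deliberately mismatched
  try {
    ENFORCE_EQ_SKETCH(width, height,
                      "Shape mistake in matmul_grad_op, the first tensor "
                      "width must be same as second tensor height, but "
                      "received width:%d, height:%d",
                      width, height);
  } catch (const std::runtime_error &e) {
    std::printf("%s\n", e.what());
  }
  return 0;
}

Because the message is only formatted on the failure path, the richer text adds no cost to the common case where the shapes already match.
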
