|
|
|
@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel {
|
|
|
|
|
context->Attrs().Get<bool>("transpose_Y"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
|
|
|
|
|
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
|
|
|
|
|
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
|
|
|
|
|
if (context->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
|
|
|
|
|
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> dim_out;
|
|
|
|
|
if (mat_dim_x.batch_size_ != 0) {
|
|
|
|
|
dim_out = framework::vectorize(dim_x);
|
|
|
|
|