|
|
|
@ -188,34 +188,34 @@ class MatMulFactory {
|
|
|
|
|
memory::dims strides_y;
|
|
|
|
|
std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
|
|
|
|
|
|
|
|
|
|
const auto x_bs = mat_dim_x.batch_size_;
|
|
|
|
|
const auto y_bs = mat_dim_y.batch_size_;
|
|
|
|
|
auto x_bs = mat_dim_x.batch_size_;
|
|
|
|
|
auto y_bs = mat_dim_y.batch_size_;
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"If batch sizes of X and Y are positive,"
|
|
|
|
|
"they have to be equal."));
|
|
|
|
|
|
|
|
|
|
// Store 1 if both batches are zero, otherwise save the nonzero batch
|
|
|
|
|
const memory::dim BS = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
|
|
|
|
|
memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
|
|
|
|
|
const memory::dim M = mat_dim_x.height_;
|
|
|
|
|
const memory::dim N = mat_dim_y.width_;
|
|
|
|
|
const memory::dim K = mat_dim_x.width_;
|
|
|
|
|
|
|
|
|
|
batch_size_ = 1;
|
|
|
|
|
auto b = BS;
|
|
|
|
|
if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
|
|
|
|
|
if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
|
|
|
|
|
auto& x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto& y_dims = ctx.Input<Tensor>("Y")->dims();
|
|
|
|
|
batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
|
|
|
|
|
b = BS / batch_size_;
|
|
|
|
|
x_bs /= batch_size_;
|
|
|
|
|
y_bs /= batch_size_;
|
|
|
|
|
out_bs /= batch_size_;
|
|
|
|
|
}
|
|
|
|
|
memory::dims x_dims = {b, M, K};
|
|
|
|
|
memory::dims y_dims = {b, K, N};
|
|
|
|
|
memory::dims out_dims = {b, M, N};
|
|
|
|
|
memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
|
|
|
|
|
memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
|
|
|
|
|
memory::dims out_dims = {out_bs, M, N};
|
|
|
|
|
|
|
|
|
|
x_offset_ = b * M * K * sizeof(XT);
|
|
|
|
|
y_offset_ = b * K * N * sizeof(YT);
|
|
|
|
|
out_offset_ = b * M * N * sizeof(OT);
|
|
|
|
|
x_offset_ = x_bs * M * K * sizeof(XT);
|
|
|
|
|
y_offset_ = y_bs * K * N * sizeof(YT);
|
|
|
|
|
out_offset_ = out_bs * M * N * sizeof(OT);
|
|
|
|
|
|
|
|
|
|
// Translate transA and transB
|
|
|
|
|
if (strides_x.empty())
|
|
|
|
@ -226,7 +226,7 @@ class MatMulFactory {
|
|
|
|
|
: memory::dims{N * K, 1, K};
|
|
|
|
|
memory::dims out_strides = memory::dims{M * N, N, 1};
|
|
|
|
|
|
|
|
|
|
CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides);
|
|
|
|
|
CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides);
|
|
|
|
|
|
|
|
|
|
return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
|
|
|
|
|
}
|
|
|
|
|