|
|
|
@ -76,12 +76,13 @@ inline void get_mid_dims(const framework::DDim &x_dims,
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < y_dims.size(); ++i) {
|
|
|
|
|
if (x_dims[i + axis] != y_dims[i]) {
|
|
|
|
|
PADDLE_ENFORCE(y_dims[i] == 1 || x_dims[i + axis] == 1,
|
|
|
|
|
"ShapeError: broadcast dimension mismatch. Operands "
|
|
|
|
|
"could not be broadcast together with the shape of "
|
|
|
|
|
"X = [%s] and the shape of Y = [%s]. Received [%d] "
|
|
|
|
|
"in X is not equal to [%d] in Y",
|
|
|
|
|
x_dims, y_dims, x_dims[i + axis], y_dims[i]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Broadcast dimension mismatch. Operands "
|
|
|
|
|
"could not be broadcast together with the shape of "
|
|
|
|
|
"X = [%s] and the shape of Y = [%s]. Received [%d] "
|
|
|
|
|
"in X is not equal to [%d] in Y.",
|
|
|
|
|
x_dims, y_dims, x_dims[i + axis], y_dims[i]));
|
|
|
|
|
*is_run_common_broadcast = 1;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -119,8 +120,15 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
|
|
|
|
|
int *x_dims_array, int *y_dims_array,
|
|
|
|
|
int *out_dims_array, const int max_dim,
|
|
|
|
|
const int axis) {
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
axis, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be great than or equal to 0, but received axis is %d.",
|
|
|
|
|
axis));
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be less than %d, but received axis is %d.",
|
|
|
|
|
max_dim, axis));
|
|
|
|
|
if (x_dims.size() > y_dims.size()) {
|
|
|
|
|
std::fill(y_dims_array, y_dims_array + axis, 1);
|
|
|
|
|
if (axis + y_dims.size() < max_dim) {
|
|
|
|
@ -138,13 +146,15 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < max_dim; i++) {
|
|
|
|
|
PADDLE_ENFORCE(x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
|
|
|
|
|
y_dims_array[i] <= 1,
|
|
|
|
|
"ShapeError: broadcast dimension mismatch. Operands could "
|
|
|
|
|
"not be broadcast together with the shape of X = [%s] and "
|
|
|
|
|
"the shape of Y = [%s]. Received [%d] in X is not equal to "
|
|
|
|
|
"[%d] in Y at i:%d",
|
|
|
|
|
x_dims, y_dims, x_dims_array[i], y_dims_array[i], i);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
|
|
|
|
|
y_dims_array[i] <= 1,
|
|
|
|
|
true, platform::errors::InvalidArgument(
|
|
|
|
|
"Broadcast dimension mismatch. Operands could "
|
|
|
|
|
"not be broadcast together with the shape of X = [%s] and "
|
|
|
|
|
"the shape of Y = [%s]. Received [%d] in X is not equal to "
|
|
|
|
|
"[%d] in Y at i:%d.",
|
|
|
|
|
x_dims, y_dims, x_dims_array[i], y_dims_array[i], i));
|
|
|
|
|
if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
|
|
|
|
|
(x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
|
|
|
|
|
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
|
|
|
|
@ -1690,8 +1700,15 @@ void ElemwiseGradComputeWithBroadcast(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
axis, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be great than or equal to 0, but received axis is %d.",
|
|
|
|
|
axis));
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be less than %d, but received axis is %d.",
|
|
|
|
|
max_dim, axis));
|
|
|
|
|
|
|
|
|
|
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
|
|
|
|
|
if (is_xsize_larger) {
|
|
|
|
@ -1758,8 +1775,15 @@ void CommonElementwiseBroadcastForward(
|
|
|
|
|
int axis, const bool is_xsize_larger = true) {
|
|
|
|
|
int max_dim = std::max(x_dims.size(), y_dims.size());
|
|
|
|
|
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
axis, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be great than or equal to 0, but received axis is %d.",
|
|
|
|
|
axis));
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be less than %d, but received axis is %d.",
|
|
|
|
|
max_dim, axis));
|
|
|
|
|
std::vector<int> x_dims_array(max_dim);
|
|
|
|
|
std::vector<int> y_dims_array(max_dim);
|
|
|
|
|
std::vector<int> out_dims_array(max_dim);
|
|
|
|
@ -1848,8 +1872,15 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
|
|
|
|
|
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
axis, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be great than or equal to 0, but received axis is %d.",
|
|
|
|
|
axis));
|
|
|
|
|
PADDLE_ENFORCE_LT(axis, max_dim,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Axis should be less than %d, but received axis is %d.",
|
|
|
|
|
max_dim, axis));
|
|
|
|
|
|
|
|
|
|
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
|
|
|
|
|
if (is_xsize_larger) {
|
|
|
|
@ -2723,7 +2754,9 @@ void FusedElemwiseAndActGradComputeEx(
|
|
|
|
|
const framework::DDim &x_dim = x->dims();
|
|
|
|
|
const framework::DDim &y_dim = y->dims();
|
|
|
|
|
if (UseIntermediateOut) {
|
|
|
|
|
PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
intermediate_out,
|
|
|
|
|
platform::errors::InvalidArgument("Intermediate out is null pointer."));
|
|
|
|
|
}
|
|
|
|
|
if (x_dim == y_dim) {
|
|
|
|
|
FusedElemwiseAndActGradComputeNoBroadcast<
|
|
|
|
@ -2768,9 +2801,11 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
|
|
|
|
|
framework::Tensor *out,
|
|
|
|
|
framework::Tensor *intermediate_out) {
|
|
|
|
|
if (KeepIntermediateOut) {
|
|
|
|
|
PADDLE_ENFORCE(intermediate_out,
|
|
|
|
|
"The save_intermediate_out is opened, "
|
|
|
|
|
"intermediate_out should not be nullptr.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
intermediate_out,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The save_intermediate_out is opened, intermediate "
|
|
|
|
|
"out is null pointer."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const framework::DDim &x_dim = x.dims();
|
|
|
|
|