[OpDevOptimize] Add common infershape functions (#26096)
* add unchaged infershape function * add broadcast infershape function * fix bug * rename infershape functions * add UnaryOpUnchangedInferShapeCheckAxis * add error message * add test for common infer shape functions * dont update existed ops * dont update op_desc.h * add more test * add error check, refine error messagerevert-24895-update_cub
parent
2d95280e1f
commit
ffe52b4452
@ -0,0 +1,166 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/common_infer_shape_functions.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
// This file almostly contains all the infershape functions that are used in
|
||||
// operators.
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace details {
|
||||
inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
|
||||
const framework::DDim &y_dims,
|
||||
int *x_dims_array, int *y_dims_array,
|
||||
int *out_dims_array, const int max_dim,
|
||||
const int axis) {
|
||||
PADDLE_ENFORCE_GE(
|
||||
axis, 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"Axis should be great than or equal to 0, but received axis is %d.",
|
||||
axis));
|
||||
PADDLE_ENFORCE_LT(axis, max_dim,
|
||||
platform::errors::InvalidArgument(
|
||||
"Axis should be less than %d, but received axis is %d.",
|
||||
max_dim, axis));
|
||||
if (x_dims.size() > y_dims.size()) {
|
||||
std::fill(y_dims_array, y_dims_array + axis, 1);
|
||||
if (axis + y_dims.size() < max_dim) {
|
||||
std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
|
||||
}
|
||||
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array);
|
||||
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis);
|
||||
} else {
|
||||
std::fill(x_dims_array, x_dims_array + axis, 1);
|
||||
if (axis + x_dims.size() < max_dim) {
|
||||
std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
|
||||
}
|
||||
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
|
||||
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_dim; i++) {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
|
||||
y_dims_array[i] <= 1,
|
||||
true, platform::errors::InvalidArgument(
|
||||
"Broadcast dimension mismatch. Operands could "
|
||||
"not be broadcast together with the shape of X = [%s] and "
|
||||
"the shape of Y = [%s]. Received [%d] in X is not equal to "
|
||||
"[%d] in Y at i:%d.",
|
||||
x_dims, y_dims, x_dims_array[i], y_dims_array[i], i));
|
||||
if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
|
||||
(x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
|
||||
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
|
||||
} else {
|
||||
out_dims_array[i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace details
|
||||
|
||||
// shape input(0) -> output(0) without change.
|
||||
void UnaryOpUnchangedInferShape(framework::InferShapeContext *ctx) {
|
||||
auto x_name = ctx->GetInputNameByIdx(0);
|
||||
auto out_name = ctx->GetOutputNameByIdx(0);
|
||||
ctx->ShareDim(x_name, /*->*/ out_name);
|
||||
ctx->ShareLoD(x_name, /*->*/ out_name);
|
||||
}
|
||||
|
||||
// shape input(0) -> output(0) without change, check if axis in range [-Rank(x),
|
||||
// Rank(x)-1]
|
||||
void UnaryOpUnchangedInferShapeCheckAxis(framework::InferShapeContext *ctx) {
|
||||
auto x_name = ctx->GetInputNameByIdx(0);
|
||||
auto out_name = ctx->GetOutputNameByIdx(0);
|
||||
auto x_dim = ctx->GetInputDim(x_name);
|
||||
auto x_rank = x_dim.size();
|
||||
auto axis = ctx->Attrs().Get<int>("axis");
|
||||
PADDLE_ENFORCE_GE(
|
||||
axis, -x_rank,
|
||||
platform::errors::InvalidArgument(
|
||||
"Attr(axis) value should be in range [-R, R-1], "
|
||||
"R is the rank of Input(X). But received axis: %d, R: %d.",
|
||||
axis, x_rank));
|
||||
PADDLE_ENFORCE_LT(
|
||||
axis, x_rank,
|
||||
platform::errors::InvalidArgument(
|
||||
"Attr(axis) value should be in range [-R, R-1], "
|
||||
"R is the rank of Input(X). But received axis: %d, R: %d.",
|
||||
axis, x_rank));
|
||||
ctx->ShareDim(x_name, /*->*/ out_name);
|
||||
ctx->ShareLoD(x_name, /*->*/ out_name);
|
||||
}
|
||||
|
||||
// broadcast input(0) and input(1) -> output(0)
|
||||
void BinaryOpBroadcastInferShape(framework::InferShapeContext *ctx) {
|
||||
auto x_name = ctx->GetInputNameByIdx(0);
|
||||
auto y_name = ctx->GetInputNameByIdx(1);
|
||||
auto out_name = ctx->GetOutputNameByIdx(0);
|
||||
auto x_dims = ctx->GetInputDim(x_name);
|
||||
auto y_dims = ctx->GetInputDim(y_name);
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->GetInputsVarType(y_name).front(),
|
||||
framework::proto::VarType::LOD_TENSOR,
|
||||
platform::errors::InvalidArgument(
|
||||
"The var type of input %s should be LoDTensor, but got %s.",
|
||||
ctx->Inputs(y_name).front(), ctx->GetInputsVarType(y_name).front()));
|
||||
|
||||
if (ctx->GetInputsVarType(x_name).front() ==
|
||||
framework::proto::VarType::SELECTED_ROWS) {
|
||||
PADDLE_ENFORCE_EQ(y_dims.size(), 1u,
|
||||
platform::errors::InvalidArgument(
|
||||
"For binary broadcastable operator, if X is "
|
||||
"Sparse(VarType.SELECTED_ROWS"
|
||||
"), Y must be scalar, and the size of Y should be 1. "
|
||||
"But reveived the size of Y = %s.",
|
||||
y_dims.size()));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
y_dims[0], 1,
|
||||
platform::errors::InvalidArgument(
|
||||
"For binary broadcastable operator, if X is "
|
||||
"Sparse(VarType.SELECTED_ROWS"
|
||||
"), Y must be scalar, the first dimension of Y should be 1. "
|
||||
"But reveived the first dimension of Y = %s.",
|
||||
y_dims[0]));
|
||||
} else if (ctx->GetInputsVarType(x_name).front() !=
|
||||
framework::proto::VarType::LOD_TENSOR) {
|
||||
PADDLE_THROW(platform::errors::InvalidArgument(
|
||||
"For binary broadcastable operator, the var type of input X should "
|
||||
"be LOD_TENSOR, but got %s",
|
||||
ctx->GetInputsVarType(x_name).front()));
|
||||
}
|
||||
|
||||
if (x_dims == y_dims) {
|
||||
ctx->ShareDim(x_name, /*->*/ out_name);
|
||||
ctx->ShareLoD(x_name, /*->*/ out_name);
|
||||
} else {
|
||||
int max_dim = std::max(x_dims.size(), y_dims.size());
|
||||
int axis = ctx->Attrs().Get<int>("axis");
|
||||
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
|
||||
std::vector<int> x_dims_array(max_dim);
|
||||
std::vector<int> y_dims_array(max_dim);
|
||||
std::vector<int> out_dims_array(max_dim);
|
||||
details::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
|
||||
y_dims_array.data(), out_dims_array.data(),
|
||||
max_dim, axis);
|
||||
ctx->SetOutputDim(out_name, framework::make_ddim(out_dims_array));
|
||||
ctx->ShareLoD(x_name, /*->*/ out_name);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,35 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
|
||||
// This file almostly contains all the infershape functions that are used in
|
||||
// operators.
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// shape input(0) -> output(0) without change.
|
||||
void UnaryOpUnchangedInferShape(framework::InferShapeContext* ctx);
|
||||
// shape input(0) -> output(0) without change, check if axis in range [-Rank(x),
|
||||
// Rank(x)-1]
|
||||
void UnaryOpUnchangedInferShapeCheckAxis(framework::InferShapeContext* ctx);
|
||||
// broadcast input(0) and input(1) -> output(0)
|
||||
void BinaryOpBroadcastInferShape(framework::InferShapeContext* ctx);
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,145 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
#include "paddle/fluid/imperative/infer_shape_context.h"
|
||||
#include "paddle/fluid/imperative/layer.h"
|
||||
#include "paddle/fluid/operators/common_infer_shape_functions.h"
|
||||
|
||||
USE_OP(relu);
|
||||
USE_OP(elementwise_add);
|
||||
USE_OP(softmax);
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace details {
|
||||
|
||||
class DygraphInferShapeTest {
|
||||
public:
|
||||
void AddInput(const std::string& name, const framework::DDim& dim) {
|
||||
std::shared_ptr<imperative::VarBase> vin(
|
||||
new imperative::VarBase(false, name));
|
||||
vin->MutableVar()->GetMutable<framework::LoDTensor>()->Resize(dim);
|
||||
ins_[name] = {vin};
|
||||
}
|
||||
void AddOutput(const std::string& name, const framework::DDim& expected_dim) {
|
||||
std::shared_ptr<imperative::VarBase> vout(
|
||||
new imperative::VarBase(false, name));
|
||||
vout->MutableVar()
|
||||
->GetMutable<framework::LoDTensor>(); // InitializeVariable
|
||||
outs_[name] = {vout};
|
||||
expected_dims_[name] = expected_dim;
|
||||
}
|
||||
void AddAttrs(const framework::AttributeMap& attrs) { attrs_ = attrs; }
|
||||
void SetOpType(const std::string& op_type) { op_type_ = op_type; }
|
||||
void Run(std::function<void(framework::InferShapeContext* ctx)> infer_shape) {
|
||||
imperative::DygraphInferShapeContext<imperative::VarBase> ctx(
|
||||
&ins_, &outs_, &attrs_, op_type_);
|
||||
infer_shape(&ctx);
|
||||
for (const auto& pair : expected_dims_) {
|
||||
auto out = outs_[pair.first][0];
|
||||
ASSERT_EQ(pair.second,
|
||||
out->MutableVar()->GetMutable<framework::LoDTensor>()->dims());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
imperative::NameVarBaseMap ins_;
|
||||
imperative::NameVarBaseMap outs_;
|
||||
framework::AttributeMap attrs_;
|
||||
std::string op_type_;
|
||||
std::map<std::string, framework::DDim> expected_dims_;
|
||||
};
|
||||
} // namespace details
|
||||
|
||||
TEST(test_UnaryOpUnchangedInferShape, test_shape) {
|
||||
details::DygraphInferShapeTest test;
|
||||
test.AddInput("X", {2, 10});
|
||||
test.AddOutput("Out", {2, 10});
|
||||
test.SetOpType("relu");
|
||||
test.Run(UnaryOpUnchangedInferShape);
|
||||
}
|
||||
|
||||
TEST(test_BinaryOpBroadcastInferShape, test_same_shape) {
|
||||
details::DygraphInferShapeTest test;
|
||||
test.AddInput("X", {2, 3, 4, 5});
|
||||
test.AddInput("Y", {2, 3, 4, 5});
|
||||
test.AddOutput("Out", {2, 3, 4, 5});
|
||||
test.SetOpType("elementwise_add");
|
||||
test.Run(BinaryOpBroadcastInferShape);
|
||||
}
|
||||
|
||||
TEST(test_BinaryOpBroadcastInferShape, test_broadcast1) {
|
||||
details::DygraphInferShapeTest test;
|
||||
test.AddInput("X", {2, 3, 4, 5});
|
||||
test.AddInput("Y", {4, 5});
|
||||
test.AddOutput("Out", {2, 3, 4, 5});
|
||||
test.AddAttrs({
|
||||
{"axis", -1},
|
||||
});
|
||||
test.SetOpType("elementwise_add");
|
||||
test.Run(BinaryOpBroadcastInferShape);
|
||||
}
|
||||
|
||||
TEST(test_BinaryOpBroadcastInferShape, test_broadcast2) {
|
||||
details::DygraphInferShapeTest test;
|
||||
test.AddInput("X", {2, 10, 5, 1});
|
||||
test.AddInput("Y", {10, 1, 1});
|
||||
test.AddOutput("Out", {2, 10, 5, 1});
|
||||
test.AddAttrs({
|
||||
{"axis", -1},
|
||||
});
|
||||
test.SetOpType("elementwise_add");
|
||||
test.Run(BinaryOpBroadcastInferShape);
|
||||
}
|
||||
|
||||
TEST(test_BinaryOpBroadcastInferShape, test_broadcast3) {
|
||||
details::DygraphInferShapeTest test;
|
||||
test.AddInput("X", {10, 1, 1});
|
||||
test.AddInput("Y", {2, 10, 5, 5});
|
||||
test.AddOutput("Out", {2, 10, 5, 5});
|
||||
test.AddAttrs({
|
||||
{"axis", -1},
|
||||
});
|
||||
test.SetOpType("elementwise_add");
|
||||
test.Run(BinaryOpBroadcastInferShape);
|
||||
}
|
||||
|
||||
TEST(test_UnaryOpUnchangedInferShapeCheckAxis, test_shape) {
|
||||
details::DygraphInferShapeTest test;
|
||||
test.AddInput("X", {2, 10});
|
||||
test.AddOutput("Out", {2, 10});
|
||||
test.AddAttrs({
|
||||
{"axis", -1},
|
||||
});
|
||||
test.SetOpType("softmax");
|
||||
test.Run(UnaryOpUnchangedInferShapeCheckAxis);
|
||||
}
|
||||
|
||||
TEST(test_UnaryOpUnchangedInferShapeCheckAxis, test_axis_exception) {
|
||||
details::DygraphInferShapeTest test;
|
||||
test.AddInput("X", {2, 10});
|
||||
test.AddOutput("Out", {2, 10});
|
||||
test.AddAttrs({
|
||||
{"axis", 2},
|
||||
});
|
||||
test.SetOpType("softmax");
|
||||
ASSERT_ANY_THROW(test.Run(UnaryOpUnchangedInferShapeCheckAxis));
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue