|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/op_desc.h"
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <mutex>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
@ -229,26 +230,26 @@ void OpDescBind::Flush() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using InferShapeFuncMap =
|
|
|
|
|
std::unordered_map<std::string /*op_type*/,
|
|
|
|
|
std::function<void(InferShapeContext *)>>;
|
|
|
|
|
|
|
|
|
|
static InferShapeFuncMap &InferShapeFuncs() {
|
|
|
|
|
static InferShapeFuncMap *g_map = nullptr;
|
|
|
|
|
if (g_map == nullptr) {
|
|
|
|
|
g_map = new InferShapeFuncMap();
|
|
|
|
|
auto &info_map = OpInfoMap::Instance();
|
|
|
|
|
// all registered kernels
|
|
|
|
|
for (auto &pair : OperatorWithKernel::AllOpKernels()) {
|
|
|
|
|
auto &info = info_map.Get(pair.first);
|
|
|
|
|
// use empty type here to avoid runtime checks.
|
|
|
|
|
static std::once_flag init_infer_shape_funcs;
|
|
|
|
|
|
|
|
|
|
static void InitInferShapeFuncs() {
|
|
|
|
|
std::call_once(init_infer_shape_funcs, [] {
|
|
|
|
|
auto &map = OpInfoMap::Instance();
|
|
|
|
|
auto &info_map = *map.mutable_map();
|
|
|
|
|
|
|
|
|
|
for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
|
|
|
|
|
auto op_type = kern_pair.first;
|
|
|
|
|
auto &op_info = info_map.at(op_type);
|
|
|
|
|
auto op =
|
|
|
|
|
static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
|
|
|
|
|
g_map->insert(
|
|
|
|
|
{pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
|
|
|
|
|
static_cast<OperatorWithKernel *>(op_info.Creator()("", {}, {}, {}));
|
|
|
|
|
if (op_info.infer_shape_) { // infer_shape has been registered.
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
op_info.infer_shape_ = [op](InferShapeContext *ctx) {
|
|
|
|
|
op->InferShape(ctx);
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return *g_map;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpDescBind::CheckAttrs() {
|
|
|
|
@ -265,13 +266,12 @@ void OpDescBind::CheckAttrs() {
|
|
|
|
|
|
|
|
|
|
void OpDescBind::InferShape(const BlockDescBind &block) const {
|
|
|
|
|
VLOG(3) << "CompileTime infer shape on " << Type();
|
|
|
|
|
auto &funcs = InferShapeFuncs();
|
|
|
|
|
auto it = funcs.find(this->Type());
|
|
|
|
|
if (it == funcs.end()) {
|
|
|
|
|
PADDLE_THROW("Operator %s has not been registered", this->Type());
|
|
|
|
|
}
|
|
|
|
|
InitInferShapeFuncs();
|
|
|
|
|
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(infer_shape),
|
|
|
|
|
"%s's infer_shape has not been registered", this->Type());
|
|
|
|
|
CompileTimeInferShapeContext ctx(*this, block);
|
|
|
|
|
it->second(&ctx);
|
|
|
|
|
infer_shape(&ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpDescBind::InferVarType(BlockDescBind *block) const {
|
|
|
|
|