|
|
|
@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/op_desc.h"
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -185,5 +188,38 @@ void OpDescBind::Sync() {
|
|
|
|
|
need_update_ = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using InferShapeFuncMap =
|
|
|
|
|
std::unordered_map<std::string /*op_type*/,
|
|
|
|
|
std::function<void(InferShapeContext *)>>;
|
|
|
|
|
|
|
|
|
|
static InferShapeFuncMap &InferShapeFuncs() {
|
|
|
|
|
static InferShapeFuncMap *g_map = nullptr;
|
|
|
|
|
if (g_map == nullptr) {
|
|
|
|
|
g_map = new InferShapeFuncMap();
|
|
|
|
|
auto &info_map = OpInfoMap::Instance();
|
|
|
|
|
// all registered kernels
|
|
|
|
|
for (auto &pair : OperatorWithKernel::AllOpKernels()) {
|
|
|
|
|
auto &info = info_map.Get(pair.first);
|
|
|
|
|
// use empty type here to avoid runtime checks.
|
|
|
|
|
auto op =
|
|
|
|
|
static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
|
|
|
|
|
g_map->insert(
|
|
|
|
|
{pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return *g_map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpDescBind::InferShape(const BlockDescBind &block) const {
|
|
|
|
|
auto &funcs = InferShapeFuncs();
|
|
|
|
|
auto it = funcs.find(this->Type());
|
|
|
|
|
if (it == funcs.end()) {
|
|
|
|
|
PADDLE_THROW("Operator %s has not been registered", this->Type());
|
|
|
|
|
}
|
|
|
|
|
CompileTimeInferShapeContext ctx(*this, block);
|
|
|
|
|
it->second(&ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|