|
|
|
@ -69,7 +69,8 @@ class OpInfo {
|
|
|
|
|
|
|
|
|
|
const OpCreator& Creator() const {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(creator_,
|
|
|
|
|
"Operator's Creator has not been registered");
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Operator's Creator has not been registered."));
|
|
|
|
|
return creator_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -79,11 +80,12 @@ class OpInfo {
|
|
|
|
|
std::string type = proto_ ? proto_->type() : "unknown";
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
grad_op_maker_,
|
|
|
|
|
"Operator %s's GradOpMaker has not been "
|
|
|
|
|
"registered.\nPlease check whether %s_op has "
|
|
|
|
|
"grad_op.\nIf not, please set stop_gradient to True "
|
|
|
|
|
"for its input and output variables using var.stop_gradient=True.",
|
|
|
|
|
type.c_str(), type.c_str());
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Operator %s's GradOpMaker has not been "
|
|
|
|
|
"registered.\nPlease check whether (%s) operator has "
|
|
|
|
|
"gradient operator.\nIf not, please set stop_gradient to be True "
|
|
|
|
|
"for its input and output variables using var.stop_gradient=True.",
|
|
|
|
|
type.c_str(), type.c_str()));
|
|
|
|
|
return grad_op_maker_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -100,11 +102,12 @@ class OpInfo {
|
|
|
|
|
std::string type = proto_ ? proto_->type() : "unknown";
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
dygraph_grad_op_maker_,
|
|
|
|
|
"Operator %s's DygraphGradOpMaker has not been "
|
|
|
|
|
"registered.\nPlease check whether %s_op has "
|
|
|
|
|
"grad_op.\nIf not, please set stop_gradient to True "
|
|
|
|
|
"for its input and output variables using var.stop_gradient=True.",
|
|
|
|
|
type.c_str(), type.c_str());
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Operator %s's DygraphGradOpMaker has not been "
|
|
|
|
|
"registered.\nPlease check whether (%s) operator has "
|
|
|
|
|
"gradient operator.\nIf not, please set stop_gradient to be True "
|
|
|
|
|
"for its input and output variables using var.stop_gradient=True.",
|
|
|
|
|
type.c_str(), type.c_str()));
|
|
|
|
|
return dygraph_grad_op_maker_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -130,14 +133,17 @@ class OpInfoMap {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Insert(const std::string& type, const OpInfo& info) {
|
|
|
|
|
PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type);
|
|
|
|
|
PADDLE_ENFORCE_NE(Has(type), true,
|
|
|
|
|
platform::errors::AlreadyExists(
|
|
|
|
|
"Operator (%s) has been registered.", type));
|
|
|
|
|
map_.insert({type, info});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const OpInfo& Get(const std::string& type) const {
|
|
|
|
|
auto op_info_ptr = GetNullable(type);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered",
|
|
|
|
|
type);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
op_info_ptr,
|
|
|
|
|
platform::errors::NotFound("Operator (%s) is not registered.", type));
|
|
|
|
|
return *op_info_ptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|