fix a bug in op_version_registry, test=develop, test=op_version (#29994)

revert-31562-mean
石晓伟 5 years ago committed by GitHub
parent 3e0c492910
commit 53bb126510
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,29 +18,6 @@ namespace paddle {
namespace framework {
namespace compatible {
namespace {
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
}
OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(

@ -118,13 +118,44 @@ class OpUpdate : public OpUpdateBase {
OpUpdateType type_;
};
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
template <typename T>
OpAttrVariantT op_attr_wrapper(const T& val) {
return OpAttrVariantT{val};
}
template <int N>
OpAttrVariantT op_attr_wrapper(const char (&val)[N]) {
PADDLE_ENFORCE_EQ(
val[N - 1], 0,
platform::errors::InvalidArgument(
"The argument of operator register %c is illegal.", val[N - 1]));
return OpAttrVariantT{std::string{val}};
}
class OpVersionDesc {
public:
/* Compatibility upgrade */
template <typename T>
OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value);
const T& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
return std::move(*this);
}
template <typename T>
OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value);
const T& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
return std::move(*this);
}
OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);

@ -661,7 +661,7 @@ REGISTER_OP_VERSION(conv_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));
REGISTER_OP_VERSION(conv2d_transpose)
.AddCheckpoint(
@ -672,7 +672,7 @@ REGISTER_OP_VERSION(conv2d_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));
REGISTER_OP_VERSION(conv3d_transpose)
.AddCheckpoint(
@ -683,7 +683,7 @@ REGISTER_OP_VERSION(conv3d_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));
REGISTER_OP_VERSION(depthwise_conv2d_transpose)
.AddCheckpoint(
@ -694,4 +694,4 @@ REGISTER_OP_VERSION(depthwise_conv2d_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
std::vector<int>{}));

@ -489,4 +489,4 @@ REGISTER_OP_VERSION(fusion_gru)
"Scale_weights",
"The added attribute 'Scale_weights' is not yet "
"registered.",
{1.0f}));
std::vector<float>{1.0f}));

@ -184,7 +184,7 @@ REGISTER_OP_VERSION(unique)
.NewAttr("axis",
"The axis to apply unique. If None, the input will be "
"flattened.",
{})
std::vector<int>{})
.NewAttr("is_sorted",
"If True, the unique elements of X are in ascending order."
"Otherwise, the unique elements are not sorted.",

Loading…
Cancel
Save