update the operator registration for incompatible upgrade, test=develop (#29720)

revert-31562-mean
石晓伟 4 years ago committed by GitHub
parent 10edfb6f21
commit 8bd2879ef7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -62,6 +62,37 @@ OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged(
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteAttr(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteAttr>(OpAttrInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::ModifyInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kModifyInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::ModifyOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kModifyOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::DeleteOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kDeleteOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.end(),

@ -20,6 +20,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include <boost/none.hpp>
#include <boost/variant.hpp>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_version_proto.h"
@ -30,16 +31,17 @@ namespace framework {
namespace compatible {
using OpAttrVariantT =
boost::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string> /* AttrType::STRINGS */
boost::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string>, /* AttrType::STRINGS */
boost::none_t /* None */
>;
struct OpUpdateInfo {
@ -48,7 +50,7 @@ struct OpUpdateInfo {
struct OpAttrInfo : OpUpdateInfo {
OpAttrInfo(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value)
const OpAttrVariantT& default_value = boost::none)
: name_{name}, default_value_{default_value}, remark_{remark} {}
const std::string& name() const { return name_; }
@ -83,11 +85,18 @@ struct OpBugfixInfo : OpUpdateInfo {
enum class OpUpdateType {
kInvalid = 0,
/* Compatibility upgrade */
kModifyAttr,
kNewAttr,
kNewInput,
kNewOutput,
kBugfixWithBehaviorChanged,
/* Incompatible upgrade, only for existing registration. */
kDeleteAttr = 100,
kModifyInput,
kModifyOutput,
kDeleteInput,
kDeleteOutput,
};
class OpUpdateBase {
@ -111,6 +120,7 @@ class OpUpdate : public OpUpdateBase {
class OpVersionDesc {
public:
/* Compatibility upgrade */
OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
const OpAttrVariantT& default_value);
OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
@ -118,10 +128,23 @@ class OpVersionDesc {
OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);
/* Incompatible upgrade, only for existing registration. */
OpVersionDesc&& DeleteAttr(const std::string& name,
const std::string& remark);
OpVersionDesc&& ModifyInput(const std::string& name,
const std::string& remark);
OpVersionDesc&& ModifyOutput(const std::string& name,
const std::string& remark);
OpVersionDesc&& DeleteInput(const std::string& name,
const std::string& remark);
OpVersionDesc&& DeleteOutput(const std::string& name,
const std::string& remark);
public:
const std::vector<std::unique_ptr<OpUpdateBase>>& infos() const {
return infos_;
}
OpVersionDesc() = default;
OpVersionDesc(OpVersionDesc&&) = default;
OpVersionDesc& operator=(OpVersionDesc&&) = default;

@ -53,6 +53,19 @@ TEST(test_operator_version, test_operator_version) {
framework::compatible::OpVersionDesc()
.NewInput("X2", "The second input.")
.NewOutput("Y2", "The second output."));
REGISTER_OP_VERSION(op_name_0__)
.AddCheckpoint(
R"ROC(
Incompatible upgrade of attribute [height], input [X2] and output [Y2]
)ROC",
framework::compatible::OpVersionDesc()
.DeleteAttr("height",
"Parameters deleted due to interface alignment.")
.ModifyInput("X2", "Modify input due to interface alignment.")
.ModifyOutput("Y2", "Modify output due to interface alignment.")
.DeleteInput("X2", "Delete input due to interface alignment.")
.DeleteOutput("Y2", "Delete output due to interface alignment."));
}
TEST(test_pass_op_version_checker, test_pass_op_version_checker) {

Loading…
Cancel
Save