|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_compatible_info.h"
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
#include "paddle/fluid/string/string_helper.h"
|
|
|
|
@ -72,7 +73,7 @@ void OpCompatibleMap::InitOpCompatibleMap() {
|
|
|
|
|
op_compatible_map_["layer_norm"] = {"1.6.0", OpCompatibleType::bug_fix};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) {
|
|
|
|
|
CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) const {
|
|
|
|
|
auto it = op_compatible_map_.find(op_name);
|
|
|
|
|
if (it != op_compatible_map_.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
@ -82,7 +83,7 @@ CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpCompatibleType OpCompatibleMap::IsRequireMiniVersion(
|
|
|
|
|
std::string op_name, std::string str_current_version) {
|
|
|
|
|
std::string op_name, std::string str_current_version) const {
|
|
|
|
|
auto it = op_compatible_map_.find(op_name);
|
|
|
|
|
if (it != op_compatible_map_.end()) {
|
|
|
|
|
if (CompareVersion(str_current_version, it->second.required_version_)) {
|
|
|
|
@ -100,5 +101,40 @@ OpCompatibleType OpCompatibleMap::IsRequireMiniVersion(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool OpCompatibleMap::ConvertToProto(proto::OpCompatibleMap* desc) const {
|
|
|
|
|
desc->Clear();
|
|
|
|
|
desc->set_default_required_version(default_required_version_);
|
|
|
|
|
for (auto pair : op_compatible_map_) {
|
|
|
|
|
const CompatibleInfo& info = pair.second;
|
|
|
|
|
auto* pair_desc = desc->add_pair();
|
|
|
|
|
pair_desc->set_op_name(pair.first);
|
|
|
|
|
auto* info_desc = pair_desc->mutable_compatible_info();
|
|
|
|
|
info_desc->set_version(info.required_version_);
|
|
|
|
|
info_desc->set_type(
|
|
|
|
|
static_cast<proto::CompatibleInfo_Type>(info.compatible_type_));
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) {
|
|
|
|
|
std::string version = desc.default_required_version();
|
|
|
|
|
if (version.empty()) {
|
|
|
|
|
LOG(INFO) << "The default operator required version is missing."
|
|
|
|
|
" Please update the model version.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
op_compatible_map_.clear();
|
|
|
|
|
default_required_version_ = desc.default_required_version();
|
|
|
|
|
for (int i = 0; i < desc.pair_size(); ++i) {
|
|
|
|
|
const auto& pair_desc = desc.pair(i);
|
|
|
|
|
auto info_desc = pair_desc.compatible_info();
|
|
|
|
|
CompatibleInfo info(info_desc.version(),
|
|
|
|
|
static_cast<OpCompatibleType>(info_desc.type()));
|
|
|
|
|
std::pair<std::string, CompatibleInfo> pair(pair_desc.op_name(), info);
|
|
|
|
|
op_compatible_map_.insert(pair);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|