enhance the op_version_registry, test=develop (#28347)

* enhance the op_version_registry, test=develop

* add unittests, test=develop

* enhance the op_version_registry, test=develop

* fix bugs, test=develop

* revert pybind_boost_headers.h, test=develop

* fix a attribute bug, test=develop
TCChenlong-patch-1
石晓伟 5 years ago committed by GitHub
parent c1c3e21726
commit 21a63f6f90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,9 +23,9 @@ function(pass_library TARGET DEST)
cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(pass_library_DIR)
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
else()
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
endif()
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.

@ -13,3 +13,75 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace compatible {
namespace {
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
}
OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kNewInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kNewOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged(
const std::string& remark) {
infos_.emplace_back(new_update<OpUpdateType::kBugfixWithBehaviorChanged>(
OpBugfixInfo(remark)));
return std::move(*this);
}
OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.end(),
platform::errors::AlreadyExists(
"'%s' is registered in operator version more than once.", op_type));
op_version_map_.insert(
std::pair<std::string, OpVersion>{op_type, OpVersion()});
return op_version_map_[op_type];
}
uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const {
PADDLE_ENFORCE_NE(
op_version_map_.count(op_type), 0,
platform::errors::InvalidArgument(
"The version of operator type %s has not been registered.", op_type));
return op_version_map_.find(op_type)->second.version_id();
}
// Provide a fake registration item for pybind testing.
#include "paddle/fluid/framework/op_version_registry.inl"
} // namespace compatible
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -0,0 +1,42 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
REGISTER_OP_VERSION(for_pybind_test__)
.AddCheckpoint("Note 0", framework::compatible::OpVersionDesc()
.BugfixWithBehaviorChanged(
"BugfixWithBehaviorChanged Remark"))
.AddCheckpoint("Note 1", framework::compatible::OpVersionDesc()
.ModifyAttr("BOOL", "bool", true)
.ModifyAttr("FLOAT", "float", 1.23f)
.ModifyAttr("INT", "int32", -1)
.ModifyAttr("STRING", "std::string",
std::string{"hello"}))
.AddCheckpoint("Note 2",
framework::compatible::OpVersionDesc()
.ModifyAttr("BOOLS", "std::vector<bool>",
std::vector<bool>{true, false})
.ModifyAttr("FLOATS", "std::vector<float>",
std::vector<float>{2.56f, 1.28f})
.ModifyAttr("INTS", "std::vector<int32>",
std::vector<int32_t>{10, 100})
.NewAttr("LONGS", "std::vector<int64>",
std::vector<int64_t>{10000001, -10000001}))
.AddCheckpoint("Note 3", framework::compatible::OpVersionDesc()
.NewAttr("STRINGS", "std::vector<std::string>",
std::vector<std::string>{"str1", "str2"})
.ModifyAttr("LONG", "int64", static_cast<int64_t>(10000001))
.NewInput("NewInput", "NewInput_")
.NewOutput("NewOutput", "NewOutput_")
.BugfixWithBehaviorChanged(
"BugfixWithBehaviorChanged_"));

@ -21,7 +21,7 @@ namespace framework {
namespace compatible {
TEST(test_operator_version, test_operator_version) {
REGISTER_OP_VERSION(test__)
REGISTER_OP_VERSION(op_name__)
.AddCheckpoint(
R"ROC(Fix the bug of reshape op, support the case of axis < 0)ROC",
framework::compatible::OpVersionDesc().BugfixWithBehaviorChanged(
@ -56,6 +56,7 @@ TEST(test_operator_version, test_operator_version) {
}
TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
const std::string fake_op_name{"op_name__"};
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass"));
@ -90,7 +91,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass4)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 5)
.GE(fake_op_name, 5)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass4"));
@ -98,7 +99,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass5)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 4)
.GE(fake_op_name, 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass5"));
@ -106,7 +107,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass6)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("test__", 4)
.EQ(fake_op_name, 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass6"));
@ -114,7 +115,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass7)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.NE("test__", 4)
.NE(fake_op_name, 4)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass7"));

@ -104,6 +104,7 @@ endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} op_version_registry)
# FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)

@ -1,7 +1,7 @@
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator)
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry)
if (WITH_NCCL)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper)

@ -13,26 +13,136 @@
// limitations under the License.
#include "paddle/fluid/pybind/compatible.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace py = pybind11;
using paddle::framework::compatible::PassVersionCheckerRegistrar;
using paddle::framework::compatible::OpAttrVariantT;
using paddle::framework::compatible::OpUpdateInfo;
using paddle::framework::compatible::OpAttrInfo;
using paddle::framework::compatible::OpInputOutputInfo;
using paddle::framework::compatible::OpBugfixInfo;
using paddle::framework::compatible::OpUpdateType;
using paddle::framework::compatible::OpUpdateBase;
using paddle::framework::compatible::OpVersionDesc;
using paddle::framework::compatible::OpCheckpoint;
using paddle::framework::compatible::OpVersion;
namespace paddle {
namespace pybind {
void BindCompatible(py::module* m) {
namespace {
using paddle::framework::compatible::PassVersionCheckerRegistrar;
void BindPassVersionChecker(py::module *m) {
py::class_<PassVersionCheckerRegistrar>(*m, "PassVersionChecker")
.def_static("IsCompatible", [](const std::string& name) -> bool {
.def_static("IsCompatible", [](const std::string &name) -> bool {
auto instance = PassVersionCheckerRegistrar::GetInstance();
return instance.IsPassCompatible(name);
});
}
void BindPassCompatible(py::module *m) { BindPassVersionChecker(m); }
void BindOpUpdateInfo(py::module *m) {
py::class_<OpUpdateInfo>(*m, "OpUpdateInfo").def(py::init<>());
}
void BindOpAttrInfo(py::module *m) {
py::class_<OpAttrInfo, OpUpdateInfo>(*m, "OpAttrInfo")
.def(py::init<const std::string &, const std::string &,
const OpAttrVariantT &>())
.def(py::init<const OpAttrInfo &>())
.def("name", &OpAttrInfo::name)
.def("default_value", &OpAttrInfo::default_value)
.def("remark", &OpAttrInfo::remark);
}
void BindOpInputOutputInfo(py::module *m) {
py::class_<OpInputOutputInfo, OpUpdateInfo>(*m, "OpInputOutputInfo")
.def(py::init<const std::string &, const std::string &>())
.def(py::init<const OpInputOutputInfo &>())
.def("name", &OpInputOutputInfo::name)
.def("remark", &OpInputOutputInfo::remark);
}
void BindOpBugfixInfo(py::module *m) {
py::class_<OpBugfixInfo, OpUpdateInfo>(*m, "OpBugfixInfo")
.def(py::init<const std::string &>())
.def(py::init<const OpBugfixInfo &>())
.def("remark", &OpBugfixInfo::remark);
}
void BindOpCompatible(py::module *m) {
BindOpUpdateInfo(m);
BindOpAttrInfo(m);
BindOpInputOutputInfo(m);
BindOpBugfixInfo(m);
}
void BindOpUpdateType(py::module *m) {
py::enum_<OpUpdateType>(*m, "OpUpdateType")
.value("kInvalid", OpUpdateType::kInvalid)
.value("kModifyAttr", OpUpdateType::kModifyAttr)
.value("kNewAttr", OpUpdateType::kNewAttr)
.value("kNewInput", OpUpdateType::kNewInput)
.value("kNewOutput", OpUpdateType::kNewOutput)
.value("kBugfixWithBehaviorChanged",
OpUpdateType::kBugfixWithBehaviorChanged);
}
void BindOpUpdateBase(py::module *m) {
py::class_<OpUpdateBase>(*m, "OpUpdateBase")
.def("info", [](const OpUpdateBase &obj) { return obj.info(); },
py::return_value_policy::reference)
.def("type", &OpUpdateBase::type);
}
void BindOpVersionDesc(py::module *m) {
py::class_<OpVersionDesc>(*m, "OpVersionDesc")
// Pybind11 does not yet support the transfer of `const
// std::vector<std::unique_ptr<T>>&` type objects.
.def("infos", [](const OpVersionDesc &obj) {
auto pylist = py::list();
for (const auto &ptr : obj.infos()) {
auto pyobj = py::cast(*ptr, py::return_value_policy::reference);
pylist.append(pyobj);
}
return pylist;
});
}
void BindOpCheckpoint(py::module *m) {
py::class_<OpCheckpoint>(*m, "OpCheckpoint")
.def("note", &OpCheckpoint::note, py::return_value_policy::reference)
.def("version_desc", &OpCheckpoint::version_desc,
py::return_value_policy::reference);
}
void BindOpVersion(py::module *m) {
py::class_<OpVersion>(*m, "OpVersion")
.def("version_id", &OpVersion::version_id,
py::return_value_policy::reference)
.def("checkpoints", &OpVersion::checkpoints,
py::return_value_policy::reference);
// At least pybind v2.3.0 is required because of bug #1603 of pybind11.
m->def("get_op_version_map", &framework::compatible::get_op_version_map,
py::return_value_policy::reference);
}
} // namespace
void BindCompatible(py::module *m) {
BindPassCompatible(m);
BindOpCompatible(m);
BindOpUpdateType(m);
BindOpUpdateBase(m);
BindOpVersionDesc(m);
BindOpCheckpoint(m);
BindOpVersion(m);
}
} // namespace pybind
} // namespace paddle

@ -0,0 +1,83 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.utils as utils
import paddle.fluid as fluid
class OpLastCheckpointCheckerTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(OpLastCheckpointCheckerTest, self).__init__(methodName)
self.checker = utils.OpLastCheckpointChecker()
self.fake_op = 'for_pybind_test__'
def test_op_attr_info(self):
update_type = fluid.core.OpUpdateType.kNewAttr
info_list = self.checker.filter_updates(self.fake_op, update_type,
'STRINGS')
self.assertTrue(info_list)
self.assertEqual(info_list[0].name(), 'STRINGS')
self.assertEqual(info_list[0].default_value(), ['str1', 'str2'])
self.assertEqual(info_list[0].remark(), 'std::vector<std::string>')
def test_op_input_output_info(self):
update_type = fluid.core.OpUpdateType.kNewInput
info_list = self.checker.filter_updates(self.fake_op, update_type,
'NewInput')
self.assertTrue(info_list)
self.assertEqual(info_list[0].name(), 'NewInput')
self.assertEqual(info_list[0].remark(), 'NewInput_')
def test_op_bug_fix_info(self):
update_type = fluid.core.OpUpdateType.kBugfixWithBehaviorChanged
info_list = self.checker.filter_updates(self.fake_op, update_type)
self.assertTrue(info_list)
self.assertEqual(info_list[0].remark(), 'BugfixWithBehaviorChanged_')
class OpVersionTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(OpVersionTest, self).__init__(methodName)
self.vmap = fluid.core.get_op_version_map()
self.fake_op = 'for_pybind_test__'
def test_checkpoints(self):
version_id = self.vmap[self.fake_op].version_id()
checkpoints = self.vmap[self.fake_op].checkpoints()
self.assertEqual(version_id, 4)
self.assertEqual(len(checkpoints), 4)
self.assertEqual(checkpoints[2].note(), 'Note 2')
desc_1 = checkpoints[1].version_desc().infos()
self.assertEqual(desc_1[0].info().default_value(), True)
self.assertAlmostEqual(desc_1[1].info().default_value(), 1.23, 2)
self.assertEqual(desc_1[2].info().default_value(), -1)
self.assertEqual(desc_1[3].info().default_value(), 'hello')
desc_2 = checkpoints[2].version_desc().infos()
self.assertEqual(desc_2[0].info().default_value(), [True, False])
true_l = [2.56, 1.28]
self.assertEqual(len(true_l), len(desc_2[1].info().default_value()))
for i in range(len(true_l)):
self.assertAlmostEqual(desc_2[1].info().default_value()[i],
true_l[i], 2)
self.assertEqual(desc_2[2].info().default_value(), [10, 100])
self.assertEqual(desc_2[3].info().default_value(),
[10000001, -10000001])
if __name__ == '__main__':
unittest.main()

@ -17,6 +17,7 @@ from .profiler import Profiler
from .profiler import get_profiler
from .deprecated import deprecated
from .lazy_import import try_import
from .op_version import OpLastCheckpointChecker
from .install_check import run_check
from ..fluid.framework import unique_name
from ..fluid.framework import load_op_library

@ -0,0 +1,70 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..fluid import core
__all__ = ['OpLastCheckpointChecker']
def Singleton(cls):
_instance = {}
def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]
return _singleton
class OpUpdateInfoHelper(object):
def __init__(self, info):
self._info = info
def verify_key_value(self, name=''):
result = False
key_funcs = {
core.OpAttrInfo: 'name',
core.OpInputOutputInfo: 'name',
}
if name == '':
result = True
elif type(self._info) in key_funcs:
if getattr(self._info, key_funcs[type(self._info)])() == name:
result = True
return result
@Singleton
class OpLastCheckpointChecker(object):
def __init__(self):
self.raw_version_map = core.get_op_version_map()
self.checkpoints_map = {}
self._construct_map()
def _construct_map(self):
for op_name in self.raw_version_map:
last_checkpoint = self.raw_version_map[op_name].checkpoints()[-1]
infos = last_checkpoint.version_desc().infos()
self.checkpoints_map[op_name] = infos
def filter_updates(self, op_name, type=core.OpUpdateType.kInvalid, key=''):
updates = []
if op_name in self.checkpoints_map:
for update in self.checkpoints_map[op_name]:
if (update.type() == type) or (
type == core.OpUpdateType.kInvalid):
if OpUpdateInfoHelper(update.info()).verify_key_value(key):
updates.append(update.info())
return updates
Loading…
Cancel
Save