adapte Second order optimization ops

pull/1413/head
jjfeing 5 years ago
parent c8f69f5db2
commit c312b47ae1

@ -19,5 +19,6 @@ from .aicpu import *
if "Windows" not in platform.system():
from .akg.gpu import *
from .tbe import *
from ._custom_op import *
__all__ = []

@ -0,0 +1,16 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""custom ops"""

@ -23,6 +23,7 @@ from mindspore._checkparam import Validator as validator
# path of built-in op info register.
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op"
def op_info_register(op_info):
@ -47,7 +48,10 @@ def op_info_register(op_info):
op_lib = Oplib()
file_path = os.path.realpath(inspect.getfile(func))
# keep the path custom ops implementation.
imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path
if BUILT_IN_CUSTOM_OPS_REGISTER_PATH in file_path:
imply_path = file_path
else:
imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path
if not op_lib.reg_op(op_info_real, imply_path):
raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real))

Loading…
Cancel
Save