!10885 add nll_loss operation
From: @jiangzg001 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghuipull/10885/MERGE
commit
274e0aa750
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "transform/graph_ir/op_declare/math_ops_declare.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
// NLLLoss
|
||||
INPUT_MAP(NLLLoss) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(target)}, {3, INPUT_DESC(weight)}};
|
||||
ATTR_MAP(NLLLoss) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(NLLLoss) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(total_weight)}};
|
||||
REG_ADPT_DESC(NLLLoss, kNameNLLLoss, ADPT_DESC(NLLLoss))
|
||||
|
||||
// NLLLossGrad
|
||||
INPUT_MAP(NLLLossGrad) = {{1, INPUT_DESC(x)},
|
||||
{2, INPUT_DESC(y_grad)},
|
||||
{3, INPUT_DESC(target)},
|
||||
{4, INPUT_DESC(weight)},
|
||||
{5, INPUT_DESC(total_weight)}};
|
||||
ATTR_MAP(NLLLossGrad) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(NLLLossGrad) = {{0, OUTPUT_DESC(x_grad)}};
|
||||
REG_ADPT_DESC(NLLLossGrad, kNameNLLLossGrad, ADPT_DESC(NLLLossGrad))
|
||||
} // namespace mindspore::transform
|
@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
|
||||
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "transform/graph_ir/op_declare/op_declare_macro.h"
|
||||
#include "ops/math_ops.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(NLLLoss)
|
||||
DECLARE_OP_USE_OUTPUT(NLLLoss)
|
||||
DECLARE_OP_ADAPTER(NLLLossGrad)
|
||||
DECLARE_OP_USE_OUTPUT(NLLLossGrad)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
|
@ -0,0 +1,40 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""NLLLoss op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
nll_loss_op_info = TBERegOp("NLLLoss") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("nll_loss.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("nll_loss") \
|
||||
.partial_flag(True) \
|
||||
.attr("reduction", "optional", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "target", False, "required", "all") \
|
||||
.input(2, "weight", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "total_weight", False, "optional", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(nll_loss_op_info)
|
||||
def _nll_loss_tbe():
|
||||
"""NLLLoss TBE register"""
|
||||
return
|
@ -0,0 +1,41 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""NLLLossGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
nll_loss_grad_op_info = TBERegOp("NLLLossGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("nll_loss_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("nll_loss_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("reduction", "optional", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "y_grad", False, "required", "all") \
|
||||
.input(2, "target", False, "required", "all") \
|
||||
.input(3, "weight", False, "required", "all") \
|
||||
.input(4, "total_weight", False, "required", "all") \
|
||||
.output(0, "x_grad", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(nll_loss_grad_op_info)
|
||||
def _nll_loss_grad_tbe():
|
||||
"""NLLLossGrad TBE register"""
|
||||
return
|
Loading…
Reference in new issue