adapt Softplus/SoftplusGrad/ApplyFtrlD for VM

pull/1062/head
liuxiao 5 years ago
parent 7068e708de
commit 4335e9bc83

@@ -33,6 +33,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"softmax", "softmax_v2"},
{"log_softmax", "log_softmax_v2"},
{"apply_momentum", "apply_momentum_d"},
{"apply_ftrl", "apply_ftrl_d"},
{"re_lu6", "relu6"},
{"re_lu6_grad", "relu6_grad"},
{"re_lu", "relu"},

@@ -384,7 +384,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
{string(kNameSign), ADPT_DESC(Sign)},
{string(kNameRound), ADPT_DESC(Round)},
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)},
{string(kNameDiag), ADPT_DESC(Diag)},
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},

@@ -1176,11 +1176,11 @@ ATTR_MAP(Round) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
// ApplyFtrl
INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
// Diag
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};

@@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(LarsV2Update)
DECLARE_OP_USE_OUTPUT(LarsV2Update)
DECLARE_OP_ADAPTER(Round)
DECLARE_OP_USE_OUTPUT(Round)
DECLARE_OP_ADAPTER(ApplyFtrl)
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
DECLARE_OP_ADAPTER(ApplyFtrlD)
DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
DECLARE_OP_ADAPTER(SparseApplyFtrlD)
DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
DECLARE_OP_ADAPTER(Diag)

@@ -326,6 +326,18 @@ def get_bprop_log_softmax(self):
    return bprop


@bprop_getters.register(P.Softplus)
def get_bprop_softplus(self):
    """Grad definition for `Softplus` operation."""
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
        dx = softplus_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.Tanh)
def get_bprop_tanh(self):
    """Grad definition for `Tanh` operation."""

@@ -100,6 +100,8 @@ from .round import _round_tbe
from .tanh import _tanh_tbe
from .tanh_grad import _tanh_grad_tbe
from .softmax import _softmax_tbe
from .softplus import _softplus_tbe
from .softplus_grad import _softplus_grad_tbe
from .square import _square_tbe
from .sqrt import _sqrt_tbe
from .transpose_d import _transpose_d_tbe

@@ -32,30 +32,32 @@ apply_ftrl_op_info = TBERegOp("ApplyFtrl") \
.input(6, "l2", False, "required", "all") \
.input(7, "lr_power", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.output(2, "linear", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_5HD) \
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_FracZ) \
DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_C1HWNCoC0) \
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_FracZ) \
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_C1HWNCoC0) \
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
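With the accum and linear outputs added, every .dtype_format(...) row above now lists one DataType per input and per output, i.e. 8 + 3 = 11 entries instead of the previous 9. A quick illustrative check using plain strings in place of the DataType constants (this is not the real TBERegOp object):

# One of the new rows from the registration above, written as a plain tuple.
f16_5hd_row = (
    "F16_5HD", "F16_5HD", "F16_5HD", "F16_5HD",              # var, accum, linear, grad
    "F16_5HD", "F16_Default", "F16_Default", "F16_Default",  # lr, l1, l2, lr_power
    "F16_5HD", "F16_5HD", "F16_5HD",                         # outputs: var, accum, linear
)
NUM_INPUTS, NUM_OUTPUTS = 8, 3
assert len(f16_5hd_row) == NUM_INPUTS + NUM_OUTPUTS  # 11 entries per row now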

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Softplus op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

softplus_op_info = TBERegOp("Softplus") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("softplus.so") \
    .compute_cost(10) \
    .kernel_name("softplus") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(softplus_op_info)
def _softplus_tbe():
    """Softplus TBE register"""
    return

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SoftplusGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

softplus_grad_op_info = TBERegOp("SoftplusGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("softplus_grad.so") \
    .compute_cost(10) \
    .kernel_name("softplus_grad") \
    .partial_flag(True) \
    .op_pattern("broadcast") \
    .input(0, "gradients", False, "required", "all") \
    .input(1, "features", False, "required", "all") \
    .output(0, "backprops", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(softplus_grad_op_info)
def _softplus_grad_tbe():
    """SoftplusGrad TBE register"""
    return

@@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid,
SigmoidCrossEntropyWithLogits,
SmoothL1Loss, Softmax,
SmoothL1Loss, Softmax, Softplus,
SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,

@@ -974,6 +974,23 @@ class StridedSliceGrad(PrimitiveWithInfer):
                'value': None}


class SoftplusGrad(PrimitiveWithInfer):
    """Computes gradient for the Softplus activation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output'])

    def infer_shape(self, dout_shape, x_shape):
        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, dout_dtype, x_dtype):
        args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
        validator.check_tensor_type_same(args, mstype.float_type, self.name)
        return x_dtype


class TanhGrad(PrimitiveWithInfer):
    """Computes gradient of hyperbolic tangent of input element-wise."""

@@ -183,6 +183,41 @@ class LogSoftmax(PrimitiveWithInfer):
        return logits


class Softplus(PrimitiveWithInfer):
    r"""
    Softplus activation function.

    Softplus is a smooth approximation to the ReLU function.
    The function is shown as follows:

    .. math::
        \text{output} = \log(1 + \exp(\text{input\_x}))

    Inputs:
        - **input_x** (Tensor) - The input tensor whose data type should be float.

    Outputs:
        Tensor, with the same type and shape as the `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
        >>> softplus = P.Softplus()
        >>> softplus(input_x)
        [1.3132615, 2.126928, 3.0485873, 4.01815, 5.0067153]
    """

    @prim_attr_register
    def __init__(self):
        """init Softplus"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, input_x):
        return input_x

    def infer_dtype(self, input_x):
        validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name)
        return input_x


class ReLU(PrimitiveWithInfer):
    r"""
    Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
@@ -2701,11 +2736,14 @@ class ApplyFtrl(PrimitiveWithInfer):
        self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
                                outputs=['output'])
        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
        self.is_tbe = context.get_context("device_target") == "Ascend"

    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                    lr_power_shape):
        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
        if self.is_tbe:
            return var_shape, var_shape, var_shape
        return var_shape

    def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
@@ -2717,6 +2755,8 @@ class ApplyFtrl(PrimitiveWithInfer):
        validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
        validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
        validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
        if self.is_tbe:
            return var_type, var_type, var_type
        return var_type
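With these infer changes, ApplyFtrl reports three outputs (var, accum, linear) when the device target is Ascend (the TBE path) and a single var tensor otherwise. A hedged sketch of how calling code might normalize the result; the helper name below is hypothetical and not part of this commit:

def ftrl_updated_var(out):
    """Return the updated var regardless of device target.

    On Ascend/TBE the primitive yields (var, accum, linear); on other
    targets it yields var alone, so callers that only need var can
    unpack defensively like this.
    """
    if isinstance(out, tuple):
        return out[0]
    return out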

@@ -185,6 +185,22 @@ class ScatterMax(nn.Cell):
        out = self.scatter_max(self.ref, indices, updates)
        return out


class ApplyFtrlNet(nn.Cell):
    def __init__(self):
        super(ApplyFtrlNet, self).__init__()
        self.apply_ftrl = P.ApplyFtrl()
        self.lr = 0.001
        self.l1 = 0.0
        self.l2 = 0.0
        self.lr_power = -0.5
        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")

    def construct(self, grad):
        out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
        return out


test_case_math_ops = [
    ('Neg', {
@@ -602,6 +618,14 @@ test_case_nn_ops = [
        'block': G.ReluGrad(),
        'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
        'skip': ['backward']}),
    ('Softplus', {
        'block': P.Softplus(),
        'desc_inputs': [[1, 3, 4, 4]],
        'desc_bprop': [[1, 3, 4, 4]]}),
    ('SoftplusGrad', {
        'block': G.SoftplusGrad(),
        'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
        'skip': ['backward']}),
    ('Elu', {
        'block': P.Elu(),
        'desc_inputs': [[2, 3, 4]],
@@ -869,9 +893,8 @@ test_case_nn_ops = [
        'desc_inputs': [[3, 2]],
        'desc_bprop': [[3, 2]]}),
    ('ApplyFtrl', {
        'block': P.ApplyFtrl(),
        'desc_const': [0.001, 0.0, 0.0, -0.5],
        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
        'block': ApplyFtrlNet(),
        'desc_inputs': [[3, 3]],
        'desc_bprop': [3, 3],
        'skip': ['backward']}),
    ('ApplyRMSProp', {
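Since the test entry now wraps P.ApplyFtrl in ApplyFtrlNet, the optimizer state (var, accum, linear) lives in the Cell's Parameters and only the gradient is fed in. A sketch of exercising the helper directly, assuming the test file's existing numpy and Tensor imports and a backend that provides the kernel:

# Illustrative usage of the ApplyFtrlNet helper defined above.
grad = Tensor(np.random.rand(3, 3).astype(np.float32))
net = ApplyFtrlNet()
out = net(grad)  # updated var; on the Ascend/TBE path a (var, accum, linear) tuple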
