|
|
|
@ -5764,7 +5764,12 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|
|
|
|
Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, represents the updated `var`.
|
|
|
|
|
There are three outputs for Ascend environment.
|
|
|
|
|
- **var** (Tensor) - represents the updated `var`.
|
|
|
|
|
- **accum** (Tensor) - represents the updated `accum`.
|
|
|
|
|
- **linear** (Tensor) - represents the updated `linear`.
|
|
|
|
|
There is only one output for GPU environment.
|
|
|
|
|
- **var** (Tensor) - This value is alwalys zero and the input parameters has been updated in-place.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU``
|
|
|
|
@ -5773,8 +5778,8 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|
|
|
|
>>> import mindspore
|
|
|
|
|
>>> import mindspore.nn as nn
|
|
|
|
|
>>> import numpy as np
|
|
|
|
|
>>> from mindspore import Parameter
|
|
|
|
|
>>> from mindspore import Tensor
|
|
|
|
|
>>> from mindspore import Parameter, Tensor
|
|
|
|
|
>>> import mindspore.context as context
|
|
|
|
|
>>> from mindspore.ops import operations as ops
|
|
|
|
|
>>> class ApplyFtrlNet(nn.Cell):
|
|
|
|
|
... def __init__(self):
|
|
|
|
@ -5797,7 +5802,9 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|
|
|
|
>>> net = ApplyFtrlNet()
|
|
|
|
|
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 2)), mindspore.float32)
|
|
|
|
|
>>> output = net(input_x)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
>>> is_tbe = context.get_context("device_target") == "Ascend"
|
|
|
|
|
>>> if is_tbe:
|
|
|
|
|
... print(output)
|
|
|
|
|
(Tensor(shape=[2, 2], dtype=Float32, value=
|
|
|
|
|
[[ 4.61418092e-01, 5.30964255e-01],
|
|
|
|
|
[ 2.68715084e-01, 3.82065028e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
|
|
|
|
@ -5805,6 +5812,16 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|
|
|
|
[ 1.43758726e+00, 9.89177322e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
|
|
|
|
|
[[-1.86994812e+03, -1.64906018e+03],
|
|
|
|
|
[-3.22187836e+02, -1.20163989e+03]]))
|
|
|
|
|
>>> else:
|
|
|
|
|
... print(net.var.asnumpy())
|
|
|
|
|
[[0.4614181 0.5309642 ]
|
|
|
|
|
[0.2687151 0.38206503]]
|
|
|
|
|
... print(net.accum.asnumpy())
|
|
|
|
|
[[16.423655 9.645894 ]
|
|
|
|
|
[ 1.4375873 9.891773 ]]
|
|
|
|
|
... print(net.linear.asnumpy())
|
|
|
|
|
[[-1869.9479 -1649.0599]
|
|
|
|
|
[ -322.1879 -1201.6399]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|