Modify Gelu and FastGelu to GeLU and FastGeLU

pull/12046/head
jinyaohui 4 years ago
parent 408159e301
commit 30a27b2adb

File diff suppressed because one or more lines are too long

@@ -22,7 +22,7 @@ HALF = 0.5
 def expand_gelu(expand_info):
-    """Gelu expander"""
+    """GeLU expander"""
     # cal formula are:
     # gelu(x) is 0.5 * x * (1.0 + tanh(y))
     # y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
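
For readers checking the expander against the comments above, the tanh approximation of GeLU can be verified with a few lines of NumPy. This is an illustrative sketch, not part of the commit:

    import numpy as np

    def gelu_tanh_approx(x):
        # gelu(x) = 0.5 * x * (1.0 + tanh(y)),  y = sqrt(2/pi) * (x + 0.044715 * x^3)
        y = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)
        return 0.5 * x * (1.0 + np.tanh(y))

    x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
    print(gelu_tanh_approx(x))  # roughly [-0.1588, 0., 0.8412], close to x * Phi(x)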

@@ -23,7 +23,7 @@ HALF = 0.5
 def expand_gelugrad(expand_info):
-    """GeluGrad expander"""
+    """GeLUGrad expander"""
     # cal formula are:
     # gelu_grad(dy, x) is dy * y'
     # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
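
One way to sanity-check this gradient formula is to compare it against a finite-difference derivative of the forward formula. The sketch below is illustrative only and assumes tanh_para denotes the inner argument sqrt(2/pi) * (x + 0.044715 * x^3), mul_right its derivative, and the product of tanh terms a squared tanh:

    import numpy as np

    C = np.sqrt(2.0 / np.pi)

    def gelu(x):
        return 0.5 * x * (1.0 + np.tanh(C * (x + 0.044715 * x ** 3)))

    def gelu_grad(dy, x):
        tanh_para = C * (x + 0.044715 * x ** 3)          # inner argument of tanh
        mul_right = C * (1.0 + 3.0 * 0.044715 * x ** 2)  # d(tanh_para)/dx
        t = np.tanh(tanh_para)
        return dy * (0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * mul_right)

    x = np.linspace(-2.0, 2.0, 9)
    eps = 1e-5
    numeric = (gelu(x + eps) - gelu(x - eps)) / (2.0 * eps)
    print(np.allclose(gelu_grad(np.ones_like(x), x), numeric, atol=1e-4))  # True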

@@ -128,7 +128,7 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = FLOOR;
   } else if (kernel_name == prim::kPrimReciprocal->name()) {
     operate_type_ = RECIPROCAL;
-  } else if (kernel_name == prim::kPrimGelu->name()) {
+  } else if (kernel_name == prim::kPrimGeLU->name()) {
     operate_type_ = GELU;
   } else if (kernel_name == prim::kPrimAsin->name()) {
     operate_type_ = ASIN;

@@ -66,7 +66,7 @@ MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                   ArithmeticSelfCPUKernel);

@@ -147,7 +147,7 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = TANHGRAD;
   } else if (kernel_name == "SqrtGrad") {
     operate_type_ = SQRTGRAD;
-  } else if (kernel_name == "GeluGrad") {
+  } else if (kernel_name == "GeLUGrad") {
     operate_type_ = GELUGRAD;
   } else if (kernel_name == "AsinGrad") {
     operate_type_ = ASINGRAD;

@@ -88,7 +88,7 @@ MS_REG_CPU_KERNEL(
   TanhGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(GeluGrad,
+MS_REG_CPU_KERNEL(GeLUGrad,
                   KernelAttr()
                     .AddInputAttr(kNumberTypeFloat32)
                     .AddInputAttr(kNumberTypeFloat32)

@@ -18,14 +18,14 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(GeluGrad,
+MS_REG_GPU_KERNEL_ONE(GeLUGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       GeLUGpuGradKernel, float)
-MS_REG_GPU_KERNEL_ONE(GeluGrad,
+MS_REG_GPU_KERNEL_ONE(GeLUGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeFloat16)

@@ -18,9 +18,9 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+MS_REG_GPU_KERNEL_ONE(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       GeluGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+MS_REG_GPU_KERNEL_ONE(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       GeluGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore

@@ -701,7 +701,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
-    prim::kPrimGeluGrad,
+    prim::kPrimGeLUGrad,
 #if ENABLE_D
     prim::kPrimTile,
     prim::kPrimSqrtGrad,
@@ -709,7 +709,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
 #elif ENABLE_GPU
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
-    prim::kPrimGelu,
+    prim::kPrimGeLU,
     prim::kPrimFusedAdam,
     prim::kPrimFusedAdamWeightDecay,
     prim::kPrimReduceMean,

@@ -77,7 +77,7 @@ class RegisterAction {
 // operator register
 REGISTER(MatMulInfo);
-REGISTER(GeluInfo);
+REGISTER(GeLUInfo);
 REGISTER(VirtualDatasetInfo);
 REGISTER(BatchParallelInfo);
 REGISTER(TanhInfo);

@@ -82,12 +82,12 @@ class ActivationOther : public Activation {
   Status GetAttrs() override;
 };
-class GeluInfo : public ActivationOther {
+class GeLUInfo : public ActivationOther {
  public:
-  GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+  GeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {}
-  ~GeluInfo() override = default;
+  ~GeLUInfo() override = default;
 };
 class TanhInfo : public ActivationOther {

@@ -187,7 +187,7 @@ constexpr char CONCAT[] = "Concat";
 constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
 constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
 constexpr char MATMUL[] = "MatMul";
-constexpr char GELU[] = "Gelu";
+constexpr char GELU[] = "GeLU";
 constexpr char TANH[] = "Tanh";
 constexpr char RECEIVE[] = "Receive";
 constexpr char SEND[] = "Send";

@@ -459,7 +459,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
   op_prim->EndRecordAddAttr();
 }
 void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
   MS_EXCEPTION_IF_NULL(op_run_info);
   PrimitivePtr op_prim = op_run_info->py_primitive;
@@ -479,7 +478,6 @@ void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
   }
 }
 BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
   if (utils::isa<VectorRef>(base_ref)) {
     auto ref_list = utils::cast<VectorRef>(base_ref);

@@ -101,27 +101,27 @@ ATTR_MAP(TanhGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(TanhGrad) = {{0, OUTPUT_DESC(z)}};
 REG_ADPT_DESC(TanhGrad, prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad))
-// Gelu
+// GeLU
 INPUT_MAP(Gelu) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Gelu) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Gelu) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(Gelu, prim::kPrimGelu->name(), ADPT_DESC(Gelu))
+REG_ADPT_DESC(Gelu, prim::kPrimGeLU->name(), ADPT_DESC(Gelu))
-// GeluGrad
+// GeLUGrad
 INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y)}};
 ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}};
-REG_ADPT_DESC(GeluGrad, prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad))
+REG_ADPT_DESC(GeluGrad, prim::kPrimGeLUGrad->name(), ADPT_DESC(GeluGrad))
-// FastGelu
+// FastGeLU
 INPUT_MAP(FastGelu) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(FastGelu) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(FastGelu) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(FastGelu, prim::kPrimFastGelu->name(), ADPT_DESC(FastGelu))
+REG_ADPT_DESC(FastGelu, prim::kPrimFastGeLU->name(), ADPT_DESC(FastGelu))
-// FastGeluGrad
+// FastGeLUGrad
 INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
 ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
-REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeluGrad->name(), ADPT_DESC(FastGeluGrad))
+REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeLUGrad->name(), ADPT_DESC(FastGeluGrad))
 }  // namespace mindspore::transform

@@ -65,13 +65,13 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &
                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplFastGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplFastGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplFastGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplFastGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);

@@ -43,6 +43,10 @@ constexpr auto kScalarUsub = "ScalarUsub";
 constexpr auto kStack = "Stack";
 constexpr auto kUnstack = "Unstack";
 constexpr auto kTupleGetItem = "TupleGetItem";
+constexpr auto kGeLU = "GeLU";
+constexpr auto kGeLUGrad = "GeLUGrad";
+constexpr auto kFastGeLU = "FastGeLU";
+constexpr auto kFastGeLUGrad = "FastGeLUGrad";
 // Here list all primitives used in backend or some special primitives used by core.
 // Arithmetic
@@ -257,11 +261,10 @@ inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout");
 inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal");
 inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal");
 inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
-inline const PrimitivePtr kPrimGeLU = std::make_shared<Primitive>("Gelu");
-inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
-inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
-inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu");
-inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad");
+inline const PrimitivePtr kPrimGeLU = std::make_shared<Primitive>(kGeLU);
+inline const PrimitivePtr kPrimGeLUGrad = std::make_shared<Primitive>(kGeLUGrad);
+inline const PrimitivePtr kPrimFastGeLU = std::make_shared<Primitive>(kFastGeLU);
+inline const PrimitivePtr kPrimFastGeLUGrad = std::make_shared<Primitive>(kFastGeLUGrad);
 inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
 inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
 inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");

@@ -31,6 +31,7 @@ from ... import context
 env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
 @bprop_getters.register(P.BiasAdd)
 def get_bprop_bias_add(self):
     """Grad definition for `BiasAdd` operation."""
@@ -681,10 +682,10 @@ def get_bprop_tanh_grad(self):
     return bprop
-@bprop_getters.register(P.Gelu)
+@bprop_getters.register(P.GeLU)
 def get_bprop_gelu(self):
-    """Grad definition for `Gelu` operation."""
-    input_grad = G.GeluGrad()
+    """Grad definition for `GeLU` operation."""
+    input_grad = G.GeLUGrad()
     def bprop(x, out, dout):
         dx = input_grad(dout, x, out)
@@ -693,10 +694,34 @@ def get_bprop_gelu(self):
     return bprop
-@bprop_getters.register(P.FastGelu)
+@bprop_getters.register(P.Gelu)
+def get_bprop_gelu_2(self):
+    """Grad definition for `GeLU` operation."""
+    input_grad = G.GeLUGrad()
+    def bprop(x, out, dout):
+        dx = input_grad(dout, x, out)
+        return (dx,)
+    return bprop
+@bprop_getters.register(P.FastGeLU)
 def get_bprop_fast_gelu(self):
-    """Grad definition for `FastGelu` operation."""
-    input_grad = G.FastGeluGrad()
+    """Grad definition for `FastGeLU` operation."""
+    input_grad = G.FastGeLUGrad()
+    def bprop(x, out, dout):
+        dx = input_grad(dout, x)
+        return (dx,)
+    return bprop
+@bprop_getters.register(P.FastGelu)
+def get_bprop_fast_gelu_2(self):
+    """Grad definition for `FastGeLU` operation."""
+    input_grad = G.FastGeLUGrad()
     def bprop(x, out, dout):
         dx = input_grad(dout, x)
@@ -713,6 +738,7 @@ def get_bprop_fused_batch_norm(self):
     if self.target == "CPU":
         input_grad = G.FusedBatchNormGradCPU(self.epsilon, self.momentum)
         target_cpu = True
     def bprop(x, scale, b, mean, variance, out, dout):
         saved_mean = out[3]
         saved_variance = out[4]
@@ -897,6 +923,7 @@ def _range_op(start, limit, delta, dtype):
     output_tensor = Tensor(list(range(start, limit, delta)), dtype)
     return output_tensor
 @constexpr
 def _get_1d_shape(in_shape):
     """helper function for Grad TopK"""
@@ -905,6 +932,7 @@ def _get_1d_shape(in_shape):
         out_shape *= i
     return (out_shape,)
 @bprop_getters.register(P.TopK)
 def get_bprop_top_kv2(self):
     """Grad definition for `TopK` operation."""
@@ -915,7 +943,6 @@ def get_bprop_top_kv2(self):
     dtype = P.DType()
     def bprop(input_x, k, out, dout):
         in_shape = shape_op(input_x)
         in_lastdim = in_shape[-1]
@@ -976,6 +1003,7 @@ def get_bprop_rnnt_loss(self):
     def bprop(acts, labels, act_lens, label_lens, out, dout):
         grad = out[1]
         return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
     return bprop
@@ -1064,6 +1092,7 @@ def get_bprop_dynamic_rnn(self):
         dh_prev = expand_dims(dh_prev, 0)
         dc_prev = expand_dims(dc_prev, 0)
         return dx, dw, db, (0), dh_prev, dc_prev
     return bprop
@@ -1082,6 +1111,7 @@ def get_bprop_dynamic_gru_v2(self):
                                           out_h, dy, dout_h[-1], update,
                                           reset, new, hidden_new, None, None)
         return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev
     return bprop
@@ -1181,6 +1211,7 @@ def get_bprop_binary_cross_entropy(self):
     return bprop
 @bprop_getters.register(P.KLDivLoss)
 def get_bprop_kl_div_loss(self):
     """Grad definition for `KLDivLoss` operation."""
@@ -1239,6 +1270,7 @@ def get_bprop_basic_lstm_cell(self):
         dxt, dht = basic_lstm_cell_input_grad(dgate, w)
         dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
         return dxt, dht, dct_1, dw, db
     return bprop

@@ -13,10 +13,10 @@
 # limitations under the License.
 # ============================================================================
-"""FastGelu op"""
+"""FastGeLU op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-fast_gelu_op_info = TBERegOp("FastGelu") \
+fast_gelu_op_info = TBERegOp("FastGeLU") \
     .fusion_type("ELEMWISE") \
     .async_flag(False) \
     .binfile_name("fast_gelu.so") \
@@ -33,5 +33,5 @@ fast_gelu_op_info = TBERegOp("FastGelu") \
 @op_info_register(fast_gelu_op_info)
 def _fast_gelu_tbe():
-    """FastGelu TBE register"""
+    """FastGeLU TBE register"""
     return

@@ -13,10 +13,10 @@
 # limitations under the License.
 # ============================================================================
-"""FastGeluGrad op"""
+"""FastGeLUGrad op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-fast_gelu_grad_op_info = TBERegOp("FastGeluGrad") \
+fast_gelu_grad_op_info = TBERegOp("FastGeLUGrad") \
     .fusion_type("ELEMWISE") \
     .async_flag(False) \
     .binfile_name("fast_gelu_grad.so") \
@@ -37,5 +37,5 @@ fast_gelu_grad_op_info = TBERegOp("FastGeluGrad") \
 @op_info_register(fast_gelu_grad_op_info)
 def _fast_gelu_grad_tbe():
-    """FastGeluGrad TBE register"""
+    """FastGeLUGrad TBE register"""
     return

@@ -13,10 +13,10 @@
 # limitations under the License.
 # ============================================================================
-"""Gelu op"""
+"""GeLU op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-gelu_op_info = TBERegOp("Gelu") \
+gelu_op_info = TBERegOp("GeLU") \
     .fusion_type("ELEMWISE") \
     .async_flag(False) \
     .binfile_name("gelu.so") \
@@ -33,5 +33,5 @@ gelu_op_info = TBERegOp("Gelu") \
 @op_info_register(gelu_op_info)
 def _gelu_tbe():
-    """Gelu TBE register"""
+    """GeLU TBE register"""
     return

@@ -13,10 +13,10 @@
 # limitations under the License.
 # ============================================================================
-"""GeluGrad op"""
+"""GeLUGrad op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-gelu_grad_op_info = TBERegOp("GeluGrad") \
+gelu_grad_op_info = TBERegOp("GeLUGrad") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("gelu_grad.so") \
@@ -38,5 +38,5 @@ gelu_grad_op_info = TBERegOp("GeluGrad") \
 @op_info_register(gelu_grad_op_info)
 def _gelu_grad_tbe():
-    """GeluGrad TBE register"""
+    """GeLUGrad TBE register"""
     return

@@ -43,7 +43,8 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
 from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
-from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
+from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
+                       BitwiseAnd, BitwiseOr,
                        BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
                        Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
@@ -65,7 +66,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                      DepthwiseConv2dNative,
                      DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten,
                      FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
-                     Gelu, FastGelu, Elu,
+                     GeLU, Gelu, FastGeLU, FastGelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
                      LogSoftmax,
                      MaxPool, DataFormatDimMap,
@@ -93,7 +95,8 @@ from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg
                         CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient, Cholesky, CholeskyTrsm, DetTriangle,
                         ProdForceSeA)
 from .sparse_ops import SparseToDense
-from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
+from ._embedding_cache_ops import (CacheSwapHashmap, SearchCacheIdx, CacheSwapTable, UpdateCache, MapCacheIdx,
+                                   SubAndFilter,
                                    MapUniform, DynamicAssign, PadAndShift)
 __all__ = [
@@ -174,7 +177,9 @@ __all__ = [
     'Unstack',
     'Tile',
     'BiasAdd',
+    'GeLU',
     'Gelu',
+    'FastGeLU',
     'FastGelu',
     'Minimum',
     'Maximum',

@@ -790,12 +790,12 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer):
         return (batch_mean, batch_variance)
-class GeluGrad(PrimitiveWithInfer):
-    """Gradients of Gelu operation."""
+class GeLUGrad(PrimitiveWithInfer):
+    """Gradients of GeLU operation."""
     @prim_attr_register
     def __init__(self):
-        """Initialize GeluGrad"""
+        """Initialize GeLUGrad"""
     def infer_shape(self, y_backprop_shape, x_shape, y_shape):
         return x_shape
@@ -808,12 +808,12 @@ class GeluGrad(PrimitiveWithInfer):
         return x_dtype
-class FastGeluGrad(PrimitiveWithInfer):
-    """Gradients of FastGelu operation."""
+class FastGeLUGrad(PrimitiveWithInfer):
+    """Gradients of FastGeLU operation."""
     @prim_attr_register
     def __init__(self):
-        """init FastGeluGrad"""
+        """init FastGeLUGrad"""
     def infer_shape(self, y_backprop_shape, x_shape):
         return x_shape

Some files were not shown because too many files have changed in this diff.
