Change TensorAdd to Add, from r1.1 to master

pull/11915/head
l00591931 4 years ago
parent 36f3c4740d
commit 9ec100d069

File diff suppressed because one or more lines are too long

@ -36,24 +36,24 @@ def expand_biasadd(expand_info):
'ExpandDims', [input_y], attrs={'axis': 1})
input_y_expand = graph_builder.emit(
'ExpandDims', [input_y_expand], attrs={'axis': 2})
result = graph_builder.emit('TensorAdd', [input_x, input_y_expand])
result = graph_builder.emit('Add', [input_x, input_y_expand])
elif input_x.data_format == "DefaultFormat":
if len(input_x.shape) == 2:
result = graph_builder.emit('TensorAdd', [input_x, input_y])
result = graph_builder.emit('Add', [input_x, input_y])
elif len(input_x.shape) == 3:
input_y_expand = graph_builder.emit(
'ExpandDims', [input_y], attrs={'axis': 1})
result = graph_builder.emit(
'TensorAdd', [input_x, input_y_expand])
'Add', [input_x, input_y_expand])
else:
input_y_expand = graph_builder.emit(
'ExpandDims', [input_y], attrs={'axis': 1})
input_y_expand = graph_builder.emit(
'ExpandDims', [input_y_expand], attrs={'axis': 2})
result = graph_builder.emit(
'TensorAdd', [input_x, input_y_expand])
'Add', [input_x, input_y_expand])
else:
result = graph_builder.emit('TensorAdd', [input_x, input_y])
result = graph_builder.emit('Add', [input_x, input_y])
# set graph output.
graph_scope.set_output(result)

@ -49,13 +49,13 @@ def expand_fusedadam(expand_info):
# compute result
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
grad_square = graph_builder.emit('Mul', [gradient, gradient])
one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
update_with_lr = graph_builder.emit('Mul', [lr, update])
next_para = graph_builder.emit('Sub', [param, update_with_lr])

@ -52,16 +52,16 @@ def expand_fusedadamweightdecay(expand_info):
# compute result
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
grad_square = graph_builder.emit('Mul', [gradient, gradient])
one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param])
update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay])
update = graph_builder.emit('Add', [update, param_with_weight_decay])
update_with_lr = graph_builder.emit('Mul', [lr, update])
next_para = graph_builder.emit('Sub', [param, update_with_lr])

@ -42,7 +42,7 @@ def expand_gelu(expand_info):
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
tanh_res = graph_builder.emit('Add', [input_x, mul_1])
const_csvalue_sqrt_two_div_pi = graph_builder.value(
tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
@ -51,7 +51,7 @@ def expand_gelu(expand_info):
tanh_y = graph_builder.emit('Tanh', [y])
const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one])
tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
result = graph_builder.emit('Mul', [const_half, mul_x])

@ -55,18 +55,18 @@ def expand_gelugrad(expand_info):
# cal mul_right
mul_double = graph_builder.emit('Mul', [input_x, input_x])
mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
# cal tanh_para
mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
# cal 0.5 * (1.0 + tanh(tanh_para))
tanh_res = graph_builder.emit('Tanh', [tanh_para])
tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
tanh_res_add_one = graph_builder.emit('Add', [const_one, tanh_res])
half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])
# cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
@ -77,7 +77,7 @@ def expand_gelugrad(expand_info):
mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])
# cal result
result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final])
result = graph_builder.emit('Mul', [input_dy, result_tmp])
# set graph output.

@ -68,13 +68,13 @@ def expand_layernorm(expand_info):
# Calculate normalize
normalize_sub = graph_builder.emit('Sub', [input_x, mean])
epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v])
normalize_add = graph_builder.emit('Add', [variance, epsilon_v])
normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
# Calculate scale and translate
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
res = graph_builder.emit('TensorAdd', [scale_mul, input_beta])
res = graph_builder.emit('Add', [scale_mul, input_beta])
# set graph output.
graph_scope.set_output(res, mean, variance)

@ -66,7 +66,7 @@ def expand_layernormgrad(expand_info):
mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size), x.data_format)
# cal dg db
var_eps = graph_builder.emit('TensorAdd', [variance, eps])
var_eps = graph_builder.emit('Add', [variance, eps])
sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
x_sub_mean = graph_builder.emit('Sub', [x, mean])
@ -100,10 +100,10 @@ def expand_layernormgrad(expand_info):
neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
add_tmp = graph_builder.emit('TensorAdd', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
dx_tmp = graph_builder.emit('TensorAdd', [dx_1, dx_2])
dx = graph_builder.emit('TensorAdd', [dx_tmp, dx_3])
dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
dx = graph_builder.emit('Add', [dx_tmp, dx_3])
# set graph output.
graph_scope.set_output(dx, dg, db)

@ -131,7 +131,7 @@ class PrimLib:
]
primtives = {
'TensorAdd': Prim(ELEMWISE),
'Add': Prim(ELEMWISE),
'Abs': Prim(ELEMWISE),
'Neg': Prim(ELEMWISE),
'Mul': Prim(ELEMWISE),

@ -238,7 +238,7 @@ void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out,
void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == prim::kPrimTensorAdd->name()) {
if (kernel_name == prim::kPrimAdd->name()) {
operate_type_ = ADD;
} else if (kernel_name == prim::kPrimSub->name()) {
operate_type_ = SUB;

@ -37,8 +37,7 @@ class TensorAddCPUKernel : public MKLCPUKernel {
};
MS_REG_CPU_KERNEL(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorAddCPUKernel);
} // namespace kernel
} // namespace mindspore

@ -51,8 +51,7 @@ MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
@ -103,8 +102,7 @@ MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
Add, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
@ -133,7 +131,7 @@ MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
@ -171,7 +169,7 @@ MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
Add, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),

@ -145,7 +145,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
{"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
{"TensorAdd", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN},
};

@ -1063,7 +1063,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
static std::map<std::string, std::string> buffer_fussion_op_map = {
{parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}};
{parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}};
string result = origin_type;
auto iter = buffer_fussion_op_map.find(origin_type);
if (iter != buffer_fussion_op_map.end()) {

@ -99,7 +99,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
}

@ -28,7 +28,7 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
}
@ -41,7 +41,7 @@ const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
}
@ -54,7 +54,7 @@ const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
}
@ -67,7 +67,7 @@ const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
}
@ -80,7 +80,7 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
}
@ -94,7 +94,7 @@ const BaseRef AdamApplyOneAssignFusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@ -114,7 +114,7 @@ const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@ -134,7 +134,7 @@ const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@ -154,7 +154,7 @@ const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, sqrt0, add2_y_})});
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
@ -174,7 +174,7 @@ const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimAdd, add2_y_, sqrt0})});
VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});

@ -38,8 +38,8 @@ class AdamApplyOneFusion : public PatternProcessPass {
mul_x_input_vars_.push_back(std::make_shared<Var>());
}
add2_y_ = std::make_shared<Var>();
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
}

@ -59,10 +59,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, input4_, add3});
VectorRef sub0({prim::kPrimSub, input3_, mul5});
return sub0;
@ -79,10 +79,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({prim::kPrimSub, input3_, mul5});
return sub0;
@ -99,10 +99,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({prim::kPrimSub, input3_, mul5});
return sub0;
@ -119,10 +119,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({prim::kPrimSub, input3_, mul5});
return sub0;
@ -139,10 +139,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({prim::kPrimSub, input3_, mul5});
return sub0;
@ -159,10 +159,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, input4_, add3});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@ -184,10 +184,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@ -209,10 +209,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@ -234,10 +234,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
@ -259,10 +259,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const {
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});

@ -38,8 +38,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
mul3_x_ = std::make_shared<Var>();
mul4_x_ = std::make_shared<Var>();
add2_y_ = std::make_shared<Var>();
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
}
~AdamApplyOneWithDecayRule() override = default;

@ -130,11 +130,11 @@ const BaseRef LambNextMVRuleCond1::DefinePattern() const {
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
return VectorRef({prim::kPrimAdd, mul4, real_div2});
}
BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
@ -147,7 +147,7 @@ BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1});
VectorRef add4 = VectorRef({prim::kPrimAdd, add2_y_, sqrt1});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}
@ -166,11 +166,11 @@ const BaseRef LambNextMVRuleCond2::DefinePattern() const {
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
auto add2 = VectorRef({prim::kPrimAdd, add2_y_, real_div1});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
return VectorRef({prim::kPrimAdd, mul4, real_div2});
}
BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
@ -183,7 +183,7 @@ BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}
@ -202,11 +202,11 @@ const BaseRef LambNextMVRuleCond3::DefinePattern() const {
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
return VectorRef({prim::kPrimAdd, mul4, real_div2});
}
BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
@ -219,7 +219,7 @@ BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}
@ -238,11 +238,11 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const {
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
auto add2 = VectorRef({prim::kPrimAdd, real_div1, add2_y_});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
return VectorRef({prim::kPrimAdd, real_div2, mul4});
}
BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
@ -255,7 +255,7 @@ BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, add2_y_});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}

@ -49,8 +49,8 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass {
real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
}
~LambNextMVRule() override = default;
const BaseRef DefinePattern() const override = 0;

@ -124,10 +124,10 @@ BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
return add3;
}
@ -141,14 +141,14 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
return add5;
}
@ -165,10 +165,10 @@ BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
VectorRef add2 = VectorRef({prim::kPrimAdd, constant_add2_y_, real_div1});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
return add3;
}
@ -182,14 +182,14 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1});
VectorRef add4 = VectorRef({prim::kPrimAdd, constant_add2_y_, sqrt1});
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
return add5;
}
@ -206,10 +206,10 @@ BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
VectorRef add3 = VectorRef({prim::kPrimAdd, mul4, real_div2});
return add3;
}
@ -223,14 +223,14 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
VectorRef add5 = VectorRef({prim::kPrimAdd, mul4, real_div4});
return add5;
}
@ -248,10 +248,10 @@ BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
VectorRef add2 = VectorRef({prim::kPrimAdd, real_div1, constant_add2_y_});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
VectorRef add3 = VectorRef({prim::kPrimAdd, real_div2, mul4});
return add3;
}
@ -265,14 +265,14 @@ const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef add4 = VectorRef({prim::kPrimAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
VectorRef add5 = VectorRef({prim::kPrimAdd, real_div4, mul4});
return add5;
}
} // namespace opt

@ -38,8 +38,8 @@ class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
}
~LambNextMVWithDecayRule() override = default;

@ -66,7 +66,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false;
}
auto add5 = node->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) {
if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) {
return false;
}
auto real_div4_anf = add5->input(1);
@ -82,7 +82,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false;
}
auto add4 = add4_anf->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) {
if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) {
return false;
}
auto sqrt1_anf = add4->input(1);
@ -140,17 +140,17 @@ const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const {
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_});
VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
VectorRef add1({prim::kPrimTensorAdd, mul2, mul3});
VectorRef add1({prim::kPrimAdd, mul2, mul3});
VectorRef real_div1({prim_real_div, add1, input2_});
VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_});
VectorRef add2({prim::kPrimAdd, real_div1, add2_y_});
VectorRef mul0({prim::kPrimMul, mul0_x_, input4_});
VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_});
VectorRef sqrt0({prim_rsqrt, add2});
VectorRef add0({prim::kPrimTensorAdd, mul0, mul1});
VectorRef add0({prim::kPrimAdd, mul0, mul1});
VectorRef real_div0({prim_real_div, add0, input5_});
VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input6_});
VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4});
VectorRef add3({prim::kPrimAdd, real_div2, mul4});
return add3;
}

@ -54,7 +54,7 @@ const BaseRef LambNextRightRule::DefinePattern() const {
VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})});
VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3});
return VectorRef(
{prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_});
{prim::kPrimAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_});
}
const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save