|
|
|
@ -1,21 +1,21 @@
|
|
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import paddle.fluid.framework as framework
|
|
|
|
|
from paddle.fluid.optimizer import Optimizer
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
import numpy as np
|
|
|
|
|
from paddle.distributed import fleet
|
|
|
|
|
|
|
|
|
|
registerd_op = {
|
|
|
|
|
"elementwise_add": "AddParser",
|
|
|
|
@ -555,7 +555,8 @@ class AllReduceParser(AscendParserBase):
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
reduction = self.reduction
|
|
|
|
|
group = "hccl_world_group" #self.op.attr("group")
|
|
|
|
|
ring_id = self.op.attr("ring_id")
|
|
|
|
|
group = "hcom_group_" + str(ring_id)
|
|
|
|
|
fusion = None #self.op.attr("fusion")
|
|
|
|
|
fusion_id = None #self.op.attr("fusion_id")
|
|
|
|
|
|
|
|
|
@ -658,12 +659,13 @@ class ReceiveParser(AscendParserBase):
|
|
|
|
|
"shape", shape).set_attr_int32("dtype", dtype)
|
|
|
|
|
return [receive], [[0]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScaleParser(AscendParserBase):
|
|
|
|
|
def __init__(self, graph, var2geop):
|
|
|
|
|
super(ScaleParser, self).__init__(graph, var2geop)
|
|
|
|
|
self.parser_name = "scale"
|
|
|
|
|
|
|
|
|
|
def _apply(self):
|
|
|
|
|
def _apply(self):
|
|
|
|
|
x = self._get_ge_input(self.op.input_arg_names[0])
|
|
|
|
|
scale = self.op.attr("scale") #self.get_ge_input(self.op.input_arg_names[1])
|
|
|
|
|
bias = self.op.attr("bias")
|
|
|
|
@ -672,9 +674,9 @@ class ScaleParser(AscendParserBase):
|
|
|
|
|
scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", bias)
|
|
|
|
|
else:
|
|
|
|
|
x_add_bias = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", bias) #set_input("x2", bias)
|
|
|
|
|
scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0)
|
|
|
|
|
scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0)
|
|
|
|
|
#tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x)
|
|
|
|
|
#bias_ = self.create_ge_tensor([1], 5, bias)
|
|
|
|
|
#bias_ = self.create_ge_tensor([1], 5, bias)
|
|
|
|
|
#const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias)
|
|
|
|
|
return [scale_value],[[0]]
|
|
|
|
|
|
|
|
|
@ -695,5 +697,7 @@ class ReshapeParser(AscendParserBase):
|
|
|
|
|
tensor = self._create_ge_tensor([len(shape)], 2, shape)
|
|
|
|
|
const_shape = core.GEOperatorFactory.create_operator("shape" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor)
|
|
|
|
|
reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return [reshape, reshape], [[0],[1]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|