!8648 PyNative Performance Optimization

From: @jojobugfree
Reviewed-by: 
Signed-off-by:
pull/8648/MERGE
Committed by mindspore-ci-bot via Gitee
commit 3f75f13556

@@ -24,6 +24,7 @@
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace opt {
@@ -41,6 +42,11 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
+auto ms_context = MsContext::GetInstance();
+MS_EXCEPTION_IF_NULL(ms_context);
+if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
+return nullptr;
+}
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
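
Note: this guard makes BatchNormReluGradFusion a no-op whenever the context is in PyNative mode, where kernels are launched op by op and graph-level fusion does not apply. A minimal user-side sketch (not part of this patch) of selecting that mode, assuming the standard mindspore.context API:

    # Minimal user-side sketch: select PyNative mode on GPU; with this patch,
    # graph-level fusions such as BatchNormReluGradFusion return early there.
    import mindspore.context as context

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")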

@@ -248,7 +248,8 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
}
}
if (need_sync) {
-if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
+if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) ||
+ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address);
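
Note: GPUSession::LoadInputData now keeps the device address on the host tensor not only for parameter weights but for any input when executing in PyNative mode, so consecutive single-op launches can skip the redundant host-to-device copy. A hedged Python mirror of the widened predicate (the helper name below is illustrative, not the real C++ API):

    # Illustrative mirror of the widened condition in LoadInputData.
    def should_cache_device_address(is_parameter_weight: bool,
                                    is_pynative_mode: bool) -> bool:
        # Weights were cached before this patch; PyNative inputs now are too.
        return is_parameter_weight or is_pynative_mode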

@@ -91,6 +91,7 @@ class LARS(Optimizer):
self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
self.decay_flags = optimizer.decay_flags
self.reciprocal_scale = optimizer.reciprocal_scale
+self.need_scale = optimizer.need_scale
self.hyper_map = C.HyperMap()
self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
self.cast = P.Cast()
@@ -136,7 +137,7 @@ class LARS(Optimizer):
else:
lr = self.learning_rate
-if self.reciprocal_scale != 1.0:
+if self.need_scale:
gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
if self.is_group:
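
Note: because reciprocal_scale becomes a Tensor in the Optimizer changes below, LARS.construct can no longer cheaply compare it against the Python float 1.0; the boolean need_scale, computed once in __init__, carries that decision instead. A hedged usage sketch with illustrative hyper-parameters:

    # Hedged usage sketch (illustrative values): with loss_scale != 1.0 the
    # wrapper takes the need_scale branch and un-scales gradients with the
    # Tensor-valued reciprocal_scale.
    import mindspore.nn as nn

    net = nn.Dense(16, 10)
    sgd = nn.Momentum(net.trainable_params(), learning_rate=0.1,
                      momentum=0.9, loss_scale=128.0)
    optimizer = nn.LARS(sgd, epsilon=1e-5, coefficient=0.001)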

@@ -138,12 +138,14 @@ class Optimizer(Cell):
if self.is_group:
self.parameters = ParameterTuple(self.group_params)
self.weight_decay = tuple(self.group_weight_decay)
+self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
decay_filter = lambda x: x > 0
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
self.exec_weight_decay = any(self.decay_flags)
else:
self.parameters = ParameterTuple(parameters)
self.weight_decay = weight_decay * loss_scale
+self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.exec_weight_decay = self.weight_decay > 0
@@ -154,7 +156,8 @@ class Optimizer(Cell):
break
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
-self.reciprocal_scale = 1.0 / loss_scale
+self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
+self.need_scale = loss_scale != 1.0
self.param_length = len(self.parameters)
self.map_ = C.Map()
if context.get_auto_parallel_context("enable_parallel_optimizer"):
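
Note: the pattern in this hunk and the previous one is to materialize the constant scalars as Tensors once in __init__ (weight_decay_tensor, weight_decay_tensor_tuple, reciprocal_scale) and keep only a plain Python bool (need_scale) for control flow, so construct stops feeding raw Python floats into the mapped graphs on every PyNative step. A minimal sketch of the same hoisting, with illustrative values:

    # Minimal sketch of the hoisting pattern (illustrative values): constants
    # become Tensors once; the branch condition stays a Python bool.
    from mindspore import Tensor
    import mindspore.common.dtype as mstype

    loss_scale = 1024.0
    reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)  # built once
    need_scale = loss_scale != 1.0  # checked cheaply in construct
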
@@ -222,10 +225,10 @@ class Optimizer(Cell):
if self.exec_weight_decay:
params = self.parameters
if self.is_group:
-gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
params, gradients)
else:
-gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
params, gradients)
return gradients
@@ -245,7 +248,7 @@ class Optimizer(Cell):
tuple[Tensor], The gradients after loss scale.
"""
-if self.reciprocal_scale != 1.0:
+if self.need_scale:
gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)
return gradients
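
Note: scale_grad (and decay_weight above) still go through map_ with F.partial, but the partially applied argument is now the Tensor-valued reciprocal_scale (or weight_decay_tensor), which dispatches to the Tensor overloads registered below. A hedged element-wise equivalent of the mapped call for dense gradients, with illustrative values:

    # Hedged element-wise equivalent of
    #   self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)
    # for dense gradients (illustrative values).
    from mindspore import Tensor
    import mindspore.common.dtype as mstype
    import mindspore.ops.operations as P

    op_mul = P.Mul()
    reciprocal_scale = Tensor(1.0 / 128.0, mstype.float32)
    gradients = (Tensor([2.0, 4.0], mstype.float32),
                 Tensor([8.0], mstype.float32))
    scaled = tuple(op_mul(g, reciprocal_scale) for g in gradients)
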
@@ -529,11 +532,12 @@ class Optimizer(Cell):
op_add = P.AddN()
op_gather = P.GatherV2()
+op_mul = P.Mul()
_apply_decay = C.MultitypeFuncGraph("apply_decay")
@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
@@ -544,11 +548,11 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
return gradient
@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
-return op_add((weight * weight_decay, gradient))
+return op_add((op_mul(weight, weight_decay), gradient))
return gradient
@@ -560,14 +564,16 @@ def tensor_grad_scale(scale, grad):
"""Get grad with scale."""
if scale == 1.0:
return grad
-return grad * scale
+return op_mul(grad, scale)
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale_with_tensor(scale, grad):
+"""Get grad with scale."""
+return op_mul(grad, scale)
-@_grad_scale.register("Number", "RowTensor")
+@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
"""Get grad with scale."""
if scale == 1.0:
return grad
return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
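
Note: the register() tags name the abstract types of the call-site arguments, so switching the first tag from "Number" to "Tensor" (and adding the new "Tensor"/"Tensor" overload of _grad_scale) is what routes the Tensor-valued weight_decay and reciprocal_scale to these bodies. A minimal, self-contained registration sketch; the names demo_scale and _demo_scale_* are illustrative, not part of this patch:

    # Minimal sketch of MultitypeFuncGraph dispatch (illustrative names): the
    # overload is selected by the abstract types of the actual arguments.
    import mindspore.ops.composite as C

    demo_scale = C.MultitypeFuncGraph("demo_scale")

    @demo_scale.register("Number", "Tensor")
    def _demo_scale_number(scale, grad):
        """Selected while the scale is still a Python number."""
        return grad * scale

    @demo_scale.register("Tensor", "Tensor")
    def _demo_scale_tensor(scale, grad):
        """Selected once the scale has been hoisted into a Tensor."""
        return grad * scale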
