|
|
|
@ -17,7 +17,7 @@
|
|
|
|
|
import inspect
|
|
|
|
|
import copy
|
|
|
|
|
from mindspore.common.api import _wrap_func
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import context, log as logger
|
|
|
|
|
from .._c_expression import Primitive_, real_run_op, prim_type
|
|
|
|
|
from .._checkparam import Validator
|
|
|
|
|
from . import signature as sig
|
|
|
|
@ -141,6 +141,10 @@ class Primitive(Primitive_):
|
|
|
|
|
Args:
|
|
|
|
|
strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
|
|
|
|
|
"""
|
|
|
|
|
if context.get_auto_parallel_context("parallel_mode") not in [context.ParallelMode.AUTO_PARALLEL,
|
|
|
|
|
context.ParallelMode.SEMI_AUTO_PARALLEL]:
|
|
|
|
|
logger.warning("Shard strategy is not valid in ", context.get_auto_parallel_context("parallel_mode"),
|
|
|
|
|
" mode. Please use semi auto or auto parallel mode.")
|
|
|
|
|
self.add_prim_attr("strategy", strategy)
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|