syn code 505

pull/928/head
changzherui 5 years ago
commit ffcbcc8dcb

@ -1 +1 @@
Subproject commit 43f5d24337bf785251eefae2d810c7d5684194d6
Subproject commit 63cb729373ae8b1b14bc14176c14dac6d18d0e4d

@ -320,6 +320,224 @@ class Validator:
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@staticmethod
def equal(arg_name, arg_value, cond_str, cond):
"""Judging valid value."""
if not cond:
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
"""This method is only used for check int values, since when compare float values,
we need consider float error."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod
def check_integer(arg_name, arg_value, value, rel):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_shape_length(arg_name, arg_value, value, rel):
"""Shape length judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""Is it necessary to consider error when comparing float values."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_args_tensor(args):
"""Check whether args are all tensor."""
if not isinstance(args, dict):
raise TypeError("The args should be a dict.")
for arg, value in args.items():
ParamValidator.check_subclass(arg, value, mstype.tensor)
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
@staticmethod
def check_typename(arg_name, arg_type, valid_types):
"""Does it contain the _name_ attribute."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
@staticmethod
def check_string(arg_name, arg_value, valid_values):
"""String type judgment."""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_type_same(args, valid_values):
"""Determine whether the types are the same."""
name = list(args.keys())[0]
value = list(args.values())[0]
if isinstance(value, type(mstype.tensor)):
value = value.element_type()
for arg_name, arg_value in args.items():
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if arg_value not in valid_values:
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
f' but `{arg_name}` is {arg_value}.')
if arg_value != value:
raise TypeError(f'`{arg_name}` should be same as `{name}`,'
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
@staticmethod
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
"""Determine whether the types of two variables are the same."""
if arg1_type != arg2_type:
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
@staticmethod
def check_value_on_integer(arg_name, arg_value, value, rel):
"""Judging integer type."""
rel_fn = Rel.get_fns(rel)
type_match = isinstance(arg_value, int)
if type_match and (not rel_fn(arg_value, value)):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
"""Judging the equality of parameters."""
if param1_value != param2_value:
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
f" but got `{param1_name}` = {param1_value},"
f" `{param2_name}` = {param2_value}.")
@staticmethod
def check_const_input(arg_name, arg_value):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_float_positive(arg_name, arg_value):
"""Float type judgment."""
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"`{arg_name}` must be float!")
@staticmethod
def check_pad_value_by_mode(op_name, pad_mode, padding):
"""Validate value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_empty_shape_input(arg_name, arg_value):
"""Check zeros value."""
if 0 in arg_value:
raise ValueError(f"Input `{arg_name}` cannot be empty.")
@staticmethod
def check_scalar_shape_input(arg_name, arg_value):
"""Check scalar shape input."""
if arg_value != []:
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
def check_int(input_param):
"""Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool):

@ -201,6 +201,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
if (AnfAlgo::GetCNodeName(kernel) == "ApplyMomentum") {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
AnfAlgo::SetOutputAddr(device_address, 0, kernel.get());
AnfAlgo::SetOutputAddr(device_address, 1, kernel.get());
return;
}

@ -27,7 +27,6 @@ namespace kernel {
constexpr auto kInitDataSetQueue = "InitDataSetQueue";
constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
constexpr auto kDropoutGenMask = "DropoutGenMask";
constexpr auto kPrint = "Print";
constexpr auto kOutputTypes = "output_types";

@ -55,7 +55,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(BatchNorm,
@ -69,7 +68,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
FusedBatchNormGpuKernel, half)
} // namespace kernel

@ -157,9 +157,6 @@ class FusedBatchNormGpuKernel : public GpuKernel {
output_size_list_.push_back(para_size); // running variance
output_size_list_.push_back(para_size); // save mean
output_size_list_.push_back(para_size); // save variance
if (!is_train_) {
output_size_list_.push_back(para_size); // reserve
}
return;
}

@ -30,6 +30,9 @@ namespace mindspore {
namespace kernel {
namespace tbe {
static std::map<string, string> tbe_func_adapter_map = {
{"softmax", "softmax_v2"},
{"log_softmax", "log_softmax_v2"},
{"apply_momentum", "apply_momentum_d"},
{"re_lu6", "relu6"},
{"re_lu6_grad", "relu6_grad"},
{"re_lu", "relu"},

@ -344,8 +344,23 @@ bool IsNopNode(const AnfNodePtr &node) {
return true;
}
bool IsAllNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &cnode : execution_order) {
MS_EXCEPTION_IF_NULL(cnode);
if (!IsNopNode(cnode)) {
return false;
}
}
return true;
}
void HideNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
if (IsAllNopNode(graph) == true) {
return;
}
auto execution_order = graph->execution_order();
MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
std::vector<CNodePtr> new_nodes;
@ -361,6 +376,9 @@ void HideNopNode(session::KernelGraph *const graph) {
void RemoveNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
if (IsAllNopNode(graph) == true) {
return;
}
bool changed = true;
while (changed) {
changed = false;

@ -177,6 +177,7 @@ const char kNameAbsGrad[] = "AbsGrad";
const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy";
const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad";
const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad";
const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD";
const char kNameAcosh[] = "Acosh";
const char kNameAcoshGrad[] = "AcoshGrad";
const char kNameFloorMod[] = "FloorMod";
@ -206,7 +207,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameMaxPool), ADPT_DESC(MaxPool)},
{string(kNameAvgPool), ADPT_DESC(AvgPool)},
{string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
{string(kNameTopK), ADPT_DESC(TopKV2)},
{string(kNameTopK), ADPT_DESC(TopK)},
{string(kNamePack), ADPT_DESC(Pack)},
{string(kNameUnpack), ADPT_DESC(Unpack)},
{string(kNameSplitD), ADPT_DESC(SplitD)},
@ -240,15 +241,15 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameSquare), ADPT_DESC(Square)},
{prim::kPrimTanh->name(), ADPT_DESC(Tanh)},
{prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)},
{string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborD)},
{string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborGrad)},
{string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)},
{string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)},
{string(kNameApplyAdam), ADPT_DESC(ApplyAdam)},
{string(kNameReLU6), ADPT_DESC(Relu6)},
{string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)},
{string(kNameElu), ADPT_DESC(Elu)},
{string(kNameEluGrad), ADPT_DESC(EluGrad)},
{string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearGrad)},
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearD)},
{string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)},
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
{string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
@ -329,7 +330,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimMinimum->name(), ADPT_DESC(Minimum)},
{prim::kPrimSelect->name(), ADPT_DESC(Select)},
{string(kNameLessEqual), ADPT_DESC(LessEqual)},
{prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmax)},
{prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)},
{string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)},
{string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)},
{prim::kPrimGelu->name(), ADPT_DESC(Gelu)},
@ -363,7 +364,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{prim::kPrimMatMul->name(), ADPT_DESC(MatMul)},
{string(kNameConst), ADPT_DESC(Constant, Const)},
{string(kNameSoftmax), ADPT_DESC(Softmax)},
{string(kNameSoftmax), ADPT_DESC(SoftmaxV2)},
{string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)},
{string(kNameParam), ADPT_DESC(Data)},
{string(kNameROIAlign), ADPT_DESC(ROIAlign)},
@ -373,6 +374,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)},
{string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
{string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
{string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)},
{string(kNameAcosh), ADPT_DESC(Acosh)},
{string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
{string(kNameFloorMod), ADPT_DESC(FloorMod)},
@ -390,6 +392,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
#endif
return adpt_map;
}
@ -1127,8 +1130,8 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
if (desc == nullptr) {
MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
} else {
(void)std::static_pointer_cast<Data>(op)->update_input_desc_data(*desc);
(void)std::static_pointer_cast<Data>(op)->update_output_desc_out(*desc);
(void)std::static_pointer_cast<Data>(op)->update_input_desc_x(*desc);
(void)std::static_pointer_cast<Data>(op)->update_output_desc_y(*desc);
}
}

@ -736,7 +736,7 @@ class OpAdapter : public BaseOpAdapter {
return static_cast<int64_t>(GetValue<int>(value));
}
// specialization for int to Vector
// specialization for int or tuple broadcast to Vector
static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name,
const AnyTraits<std::vector<int64_t>> anyTraitsInt) {
return ConvertAnyUtil(value, name, anyTraitsInt);

@ -35,15 +35,21 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor
std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
const AnyTraits<std::vector<int64_t>>) {
int64_t data = GetValue<int>(value);
MS_EXCEPTION_IF_NULL(value);
std::vector<int64_t> list;
int size = 2; // 2 int in list
if (name == "pad") {
size = 4; // 4 int in list
list = TransformUtil::ConvertIntToList(data, size);
if (!value->isa<ValueSequeue>()) {
MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name();
}
auto vec = value->cast<ValueSequeuePtr>();
list.resize(vec->value().size() + 2);
list[0] = 1;
list[1] = 1;
(void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2,
[](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int>(val)); });
} else {
int64_t data = GetValue<int>(value);
int size = 2; // 2 int in list
list = TransformUtil::ConvertIntToList(data, size);
}

File diff suppressed because it is too large Load Diff

@ -114,20 +114,22 @@ DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape)
DECLARE_OP_ADAPTER(Iou)
DECLARE_OP_USE_OUTPUT(Iou)
DECLARE_OP_ADAPTER(ResizeNearestNeighborD)
DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborD)
DECLARE_OP_ADAPTER(ResizeNearestNeighborGrad)
DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborGrad)
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D)
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
DECLARE_OP_ADAPTER(ApplyAdam)
DECLARE_OP_USE_OUTPUT(ApplyAdam)
DECLARE_OP_ADAPTER(ApplyAdamD)
DECLARE_OP_USE_OUTPUT(ApplyAdamD)
DECLARE_OP_ADAPTER(Relu6)
DECLARE_OP_USE_OUTPUT(Relu6)
DECLARE_OP_ADAPTER(Relu6Grad)
DECLARE_OP_USE_OUTPUT(Relu6Grad)
DECLARE_OP_ADAPTER(ResizeBilinearD)
DECLARE_OP_USE_OUTPUT(ResizeBilinearD)
DECLARE_OP_ADAPTER(ResizeBilinearGrad)
DECLARE_OP_USE_OUTPUT(ResizeBilinearGrad)
DECLARE_OP_ADAPTER(ResizeBilinearV2D)
DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D)
DECLARE_OP_ADAPTER(ResizeBilinearV2Grad)
DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad)
DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike)
@ -213,8 +215,8 @@ DECLARE_OP_USE_OUTPUT(Merge)
DECLARE_OP_ADAPTER(Switch)
DECLARE_OP_USE_OUTPUT(Switch)
DECLARE_OP_ADAPTER(TopKV2)
DECLARE_OP_USE_OUTPUT(TopKV2)
DECLARE_OP_ADAPTER(TopK)
DECLARE_OP_USE_OUTPUT(TopK)
DECLARE_OP_ADAPTER(RealDiv)
DECLARE_OP_USE_OUTPUT(RealDiv)
@ -264,8 +266,8 @@ DECLARE_OP_ADAPTER(Select)
DECLARE_OP_USE_OUTPUT(Select)
DECLARE_OP_ADAPTER(LessEqual)
DECLARE_OP_USE_OUTPUT(LessEqual)
DECLARE_OP_ADAPTER(LogSoftmax)
DECLARE_OP_USE_OUTPUT(LogSoftmax)
DECLARE_OP_ADAPTER(LogSoftmaxV2)
DECLARE_OP_USE_OUTPUT(LogSoftmaxV2)
DECLARE_OP_ADAPTER(TruncatedNormal)
DECLARE_OP_USE_OUTPUT(TruncatedNormal)
DECLARE_OP_ADAPTER(StridedSliceGrad)
@ -400,8 +402,8 @@ DECLARE_OP_ADAPTER(Sigmoid)
DECLARE_OP_USE_OUTPUT(Sigmoid)
DECLARE_OP_ADAPTER(SigmoidGrad)
DECLARE_OP_USE_OUTPUT(SigmoidGrad)
DECLARE_OP_ADAPTER(Softmax)
DECLARE_OP_USE_OUTPUT(Softmax)
DECLARE_OP_ADAPTER(SoftmaxV2)
DECLARE_OP_USE_OUTPUT(SoftmaxV2)
DECLARE_OP_ADAPTER(SoftmaxGrad)
DECLARE_OP_USE_OUTPUT(SoftmaxGrad)
DECLARE_OP_ADAPTER(Greater)
@ -444,6 +446,8 @@ DECLARE_OP_ADAPTER(Round)
DECLARE_OP_USE_OUTPUT(Round)
DECLARE_OP_ADAPTER(ApplyFtrl)
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
DECLARE_OP_ADAPTER(SparseApplyFtrlD)
DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
DECLARE_OP_ADAPTER(Diag)
DECLARE_OP_USE_OUTPUT(Diag)
DECLARE_OP_ADAPTER(DiagPart)

@ -31,6 +31,7 @@
namespace mindspore {
const char kShapeSeperator[] = ",";
const char kShapeScalar[] = "[0]";
static std::map<std::string, TypeId> print_type_map = {
{"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8},
{"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16},
@ -81,6 +82,73 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
return true;
}
template <typename T>
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) {
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
buf_scalar << *data_ptr;
std::cout << buf_scalar.str() << std::endl;
}
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) {
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
std::ostringstream buf_scalar;
buf_scalar << "Tensor shape :1 " << tensor_type;
buf_scalar << "\nval:";
if (*data_ptr == true) {
buf_scalar << "True";
} else {
buf_scalar << "False";
}
std::cout << buf_scalar.str() << std::endl;
}
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) {
auto type_iter = print_type_map.find(tensor_type);
auto type_id = type_iter->second;
if (type_id == TypeId::kNumberTypeBool) {
PrintScalarToBoolString(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt8) {
PrintScalarToString<int8_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt8) {
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt16) {
PrintScalarToString<int16_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt16) {
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt32) {
PrintScalarToString<int32_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt32) {
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeInt64) {
PrintScalarToString<int64_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeUInt64) {
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat16) {
PrintScalarToString<float16>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat32) {
PrintScalarToString<float>(str_data_ptr, tensor_type);
} else if (type_id == TypeId::kNumberTypeFloat64) {
PrintScalarToString<double>(str_data_ptr, tensor_type);
} else {
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
}
} // namespace mindspore
bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
auto type_iter = type_size_map.find(tensor_type);
if (type_iter == type_size_map.end()) {
MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
}
if (str_len != type_iter->second) {
return false;
}
return true;
}
#ifndef NO_DLIB
bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
// Acquire Python GIL
@ -92,14 +160,22 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
ret_end_sequence = true;
break;
}
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
MS_EXCEPTION_IF_NULL(str_data_ptr);
if (item.tensorShape_ == kShapeScalar) {
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
}
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_);
continue;
}
std::vector<int> tensor_shape;
size_t totaldims = 1;
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) {
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
continue;
}
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
MS_EXCEPTION_IF_NULL(str_data_ptr);
if (item.tensorType_ == "string") {
std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_);

@ -377,12 +377,10 @@ def get_bprop_batch_norm(self):
if is_training:
saved_reserve_1 = out[3]
saved_reserve_2 = out[4]
saved_reserve_3 = out[5]
else:
saved_reserve_1 = mean
saved_reserve_2 = variance
saved_reserve_3 = variance
out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3)
out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
dx = out[0]
dscale = out[1]
dbias = out[2]

@ -18,3 +18,8 @@ from .dropout_genmask import _dropout_genmask_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu
from .topk import _top_k_aicpu
from .is_finite import _is_finite_aicpu
from .reshape import _reshape_aicpu
from .flatten import _flatten_aicpu
from .squeeze import _squeeze_aicpu
from .expand_dims import _expand_dims_aicpu

@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExpandDims op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
expand_dims_op_info = AiCPURegOp("ExpandDims") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()
@op_info_register(expand_dims_op_info)
def _expand_dims_aicpu():
"""ExpandDims AiCPU register"""
return

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Flatten op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
flatten_op_info = AiCPURegOp("Flatten") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.get_op_info()
@op_info_register(flatten_op_info)
def _flatten_aicpu():
"""Flatten AiCPU register"""
return

@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""IsFinite op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
is_finite_op_info = AiCPURegOp("IsFinite") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.BOOL_NCHW) \
.get_op_info()
@op_info_register(is_finite_op_info)
def _is_finite_aicpu():
"""IsFinite AiCPU register"""
return

@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reshape op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
reshape_op_info = AiCPURegOp("Reshape") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()
@op_info_register(reshape_op_info)
def _reshape_aicpu():
"""Rpeshape AiCPU register"""
return

@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Squeeze op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
squeeze_op_info = AiCPURegOp("Squeeze") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()
@op_info_register(squeeze_op_info)
def _squeeze_aicpu():
"""Squeeze AiCPU register"""
return

@ -61,9 +61,6 @@ from .reduce_mean_d import _reduce_mean_d_tbe
from .scatter_nd import _scatter_nd_tbe
from .scatter_nd_d import _scatter_nd_d_tbe
from .reduce_mean import _reduce_mean_tbe
from .reshape import _reshape_tbe
from .expand_dims import _expand_dims_tbe
from .squeeze import _squeeze_tbe
from .tile import _tile_tbe
from .atomic_addr_clean import _atomic_addr_clean_tbe
from .gather_v2 import _gather_v2_tbe

@ -30,22 +30,23 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
.input(3, "grad", False, "required", "all") \
.input(4, "momentum", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "accum", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
DataType.F16_Default, DataType.F16_5HD) \
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
DataType.F16_Default, DataType.F16_C1HWNCoC0) \
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
DataType.F16_Default, DataType.F16_FracZ) \
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
DataType.F32_Default, DataType.F32_5HD) \
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
DataType.F32_Default, DataType.F32_C1HWNCoC0) \
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
DataType.F32_Default, DataType.F32_FracZ) \
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()

@ -36,19 +36,18 @@ batch_norm_op_info = TBERegOp("BatchNorm") \
.output(2, "batch_variance", False, "required", "all") \
.output(3, "reserve_space_1", False, "optional", "all") \
.output(4, "reserve_space_2", False, "optional", "all") \
.output(5, "reserve_space_3", False, "optional", "all") \
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@ -599,4 +599,13 @@ class DataType:
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")
F64_5HD = ("float64", "NC1HWC0")
F64_FracZ = ("float64", "FracZ")
F64_FracNZ = ("float64", "FRACTAL_NZ")
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN")

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save