From 96229b7358b18da7be54e2d162ead04ffee67f5e Mon Sep 17 00:00:00 2001 From: chujinjin Date: Thu, 4 Feb 2021 15:33:19 +0800 Subject: [PATCH] support swtichlayer for pynative --- mindspore/common/tensor.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 5ac9e05626..225b5c6d2d 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -161,6 +161,20 @@ class Tensor(Tensor_): return bool(data[0]) raise ValueError("The truth value of an array with several elements is ambiguous.") + def __index__(self): + data = self.asnumpy() + if not (data.dtype == "int8" + or data.dtype == "int16" + or data.dtype == "int32" + or data.dtype == "int64" + or data.dtype == "bool"): + raise ValueError("Only integer tensors of a single element can be converted to an index.") + if data.shape == (): + return int(data) + if data.shape == (1,): + return int(data[0]) + raise ValueError("Only integer tensors of a single element can be converted to an index.") + def __pos__(self): return self