|
|
|
@ -22,25 +22,58 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
using Tensor = mindspore::tensor::Tensor;
|
|
|
|
|
using TensorPtr = mindspore::tensor::TensorPtr;
|
|
|
|
|
using AbstractTensor = mindspore::abstract::AbstractTensor;
|
|
|
|
|
using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr;
|
|
|
|
|
using CheckSupportFun = bool (*)(const CNodePtr &cnode);
|
|
|
|
|
|
|
|
|
|
constexpr char kAttrStrides[] = "strides";
|
|
|
|
|
constexpr char kAttrShrinkAxisMask[] = "shrink_axis_mask";
|
|
|
|
|
|
|
|
|
|
static bool CheckStridedSlice(const CNodePtr &cnode) {
|
|
|
|
|
// check stride[-1] != 1 TODO
|
|
|
|
|
// check stride[-1] != 1
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrStrides, cnode)) {
|
|
|
|
|
auto strides = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrStrides);
|
|
|
|
|
if (!strides.empty() && strides[strides.size() - 1] == 1) {
|
|
|
|
|
return true;
|
|
|
|
|
if (!strides.empty() && strides[strides.size() - 1] != 1) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// check reduction on the last dimension
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) {
|
|
|
|
|
auto shrink_axis_mask = AnfAlgo::GetNodeAttr<int>(cnode, kAttrShrinkAxisMask);
|
|
|
|
|
AnfNodePtr input = cnode->input(1);
|
|
|
|
|
int input_dims = 0;
|
|
|
|
|
if (input->isa<ValueNode>()) {
|
|
|
|
|
ValuePtr input_value = input->cast<ValueNodePtr>()->value();
|
|
|
|
|
if (!input_value->isa<Tensor>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
|
|
|
|
|
<< input_value->ToString();
|
|
|
|
|
}
|
|
|
|
|
input_dims = SizeToInt(input_value->cast<TensorPtr>()->shape().size());
|
|
|
|
|
} else if (input->isa<CNode>() || input->isa<Parameter>()) {
|
|
|
|
|
AbstractBasePtr input_abstract = input->abstract();
|
|
|
|
|
if (!input_abstract->isa<AbstractTensor>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input value should be a tensor, but got "
|
|
|
|
|
<< input_abstract->ToString();
|
|
|
|
|
}
|
|
|
|
|
input_dims = SizeToInt(input_abstract->cast<AbstractTensorPtr>()->shape()->shape().size());
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "For 'StrideSlice', the first input node should be a 'ValueNode' or a 'CNode', but got "
|
|
|
|
|
<< input->ToString();
|
|
|
|
|
}
|
|
|
|
|
int base_number = 2;
|
|
|
|
|
if (shrink_axis_mask >= std::pow<int, int>(base_number, input_dims - 1)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// last tensor TODO
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool TbePropertyChecker::CheckTbeProperties(const mindspore::CNodePtr &cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
static std::map<std::string, CheckSupportFun> tbe_property_checker = {{parallel::KStridedSlice, CheckStridedSlice}};
|
|
|
|
|
static std::map<std::string, CheckSupportFun> tbe_property_checker = {{kStridedSliceOpName, CheckStridedSlice},
|
|
|
|
|
{kStridedSliceGradOpName, CheckStridedSlice}};
|
|
|
|
|
auto cnode_type = AnfAlgo::GetCNodeName(cnode);
|
|
|
|
|
auto find_iter = tbe_property_checker.find(cnode_type);
|
|
|
|
|
if (find_iter != tbe_property_checker.end()) {
|
|
|
|
|