|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -34,35 +34,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
|
|
|
|
|
return abs_base;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
// Inputs: two tensors.
|
|
|
|
|
const std::string op_name = primitive->name();
|
|
|
|
|
CheckArgsSize(op_name, args_spec_list, 2);
|
|
|
|
|
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
|
|
|
|
AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
|
|
|
|
|
|
|
|
|
ShapePtr x_shp = input_x->shape();
|
|
|
|
|
auto x_shp_value = x_shp->shape();
|
|
|
|
|
ShapePtr y_shp = input_y->shape();
|
|
|
|
|
auto y_shp_value = y_shp->shape();
|
|
|
|
|
// Should be matrix which shape size is 2.
|
|
|
|
|
if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
|
|
|
|
|
<< x_shp_value.size() << ", " << y_shp_value.size() << " ";
|
|
|
|
|
}
|
|
|
|
|
if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto x_element = input_x->element();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(x_element);
|
|
|
|
|
(void)x_element->Join(input_y->element());
|
|
|
|
|
auto param = {x_shp_value[0], y_shp_value[1]};
|
|
|
|
|
|
|
|
|
|
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim,
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
// Inputs: condition, true branch, false branch
|
|
|
|
|