|
|
@ -23,6 +23,7 @@
|
|
|
|
#include <mutex>
|
|
|
|
#include <mutex>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
#include <utility>
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
|
|
|
|
|
|
|
#include "frontend/operator/cc_implementations.h"
|
|
|
|
#include "frontend/operator/cc_implementations.h"
|
|
|
|
#include "frontend/operator/ops.h"
|
|
|
|
#include "frontend/operator/ops.h"
|
|
|
@ -62,6 +63,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|
|
|
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
|
|
|
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
|
|
|
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
|
|
|
|
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
|
|
|
|
{prim::kPrimPack, {InferImplPack, true}},
|
|
|
|
{prim::kPrimPack, {InferImplPack, true}},
|
|
|
|
|
|
|
|
{prim::kPrimUnique, {InferImplUnique, true}},
|
|
|
|
|
|
|
|
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
|
|
|
// Structure
|
|
|
|
// Structure
|
|
|
|
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
|
|
|
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
|
|
|
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
|
|
|
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
|
|
@ -389,6 +392,14 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|
|
|
if (abs_base->isa<AbstractTensor>()) {
|
|
|
|
if (abs_base->isa<AbstractTensor>()) {
|
|
|
|
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
|
|
|
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
|
|
|
|
dic["shape"] = arg_tensor->shape()->shape();
|
|
|
|
dic["shape"] = arg_tensor->shape()->shape();
|
|
|
|
|
|
|
|
if (MsContext::GetInstance()->execution_mode() == kGraphMode) {
|
|
|
|
|
|
|
|
const auto &min_shape = arg_tensor->shape()->min_shape();
|
|
|
|
|
|
|
|
const auto &max_shape = arg_tensor->shape()->max_shape();
|
|
|
|
|
|
|
|
if (!min_shape.empty() && !max_shape.empty()) {
|
|
|
|
|
|
|
|
dic["min_shape"] = min_shape;
|
|
|
|
|
|
|
|
dic["max_shape"] = max_shape;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
dic["dtype"] = arg_tensor->BuildType();
|
|
|
|
dic["dtype"] = arg_tensor->BuildType();
|
|
|
|
dic["value"] = BuildValue(arg_tensor->BuildValue());
|
|
|
|
dic["value"] = BuildValue(arg_tensor->BuildValue());
|
|
|
|
} else if (abs_base->isa<AbstractIndexedSlices>()) {
|
|
|
|
} else if (abs_base->isa<AbstractIndexedSlices>()) {
|
|
|
@ -503,7 +514,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
|
|
|
if (output["value"].is_none()) {
|
|
|
|
if (output["value"].is_none()) {
|
|
|
|
auto out_shape = output["shape"];
|
|
|
|
auto out_shape = output["shape"];
|
|
|
|
auto out_dtype = output["dtype"];
|
|
|
|
auto out_dtype = output["dtype"];
|
|
|
|
return PyListDtype2AbstractTensor(out_shape, out_dtype);
|
|
|
|
py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
|
|
|
|
|
|
|
|
py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Convert pyobject to Value, then to AbstractValue
|
|
|
|
// Convert pyobject to Value, then to AbstractValue
|
|
|
|
ValuePtr converted_ret = nullptr;
|
|
|
|
ValuePtr converted_ret = nullptr;
|
|
|
|