|
|
@ -319,6 +319,27 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tenso
|
|
|
|
input_tensors->push_back(tensor_ptr);
|
|
|
|
input_tensors->push_back(tensor_ptr);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
|
|
|
|
|
|
|
|
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_prim);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_tensors);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_mask);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!py::isinstance<py::tuple>(input_object)) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "The input should be a tuple!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto tuple_inputs = py::cast<py::tuple>(input_object);
|
|
|
|
|
|
|
|
if (tuple_inputs.size() == 0) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
|
|
|
|
|
|
|
|
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
ConvertValueTupleToTensor(input_object, input_tensors);
|
|
|
|
|
|
|
|
*tensor_mask = kValueNodeTensorMask;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
|
|
|
|
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
|
|
|
|
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
|
|
|
|
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
|
|
|
|
MS_EXCEPTION_IF_NULL(op_prim);
|
|
|
|
MS_EXCEPTION_IF_NULL(op_prim);
|
|
|
@ -333,20 +354,20 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
|
|
|
|
} else if (py::isinstance<py::int_>(input_object)) {
|
|
|
|
} else if (py::isinstance<py::int_>(input_object)) {
|
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
|
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
|
|
|
|
*tensor_mask = kValueNodeTensorMask;
|
|
|
|
*tensor_mask = kValueNodeTensorMask;
|
|
|
|
} else if (py::isinstance<py::list>(input_object)) {
|
|
|
|
|
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr);
|
|
|
|
|
|
|
|
} else if (py::isinstance<py::array>(input_object)) {
|
|
|
|
} else if (py::isinstance<py::array>(input_object)) {
|
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
|
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
|
|
|
|
} else if (py::isinstance<py::none>(input_object)) {
|
|
|
|
} else if (py::isinstance<py::list>(input_object)) {
|
|
|
|
|
|
|
|
auto list_inputs = py::cast<py::list>(input_object);
|
|
|
|
|
|
|
|
py::tuple tuple_inputs(list_inputs.size());
|
|
|
|
|
|
|
|
for (size_t i = 0; i < tuple_inputs.size(); ++i) {
|
|
|
|
|
|
|
|
tuple_inputs[i] = list_inputs[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
} else if (py::isinstance<py::tuple>(input_object)) {
|
|
|
|
} else if (py::isinstance<py::tuple>(input_object)) {
|
|
|
|
auto tuple_inputs = py::cast<py::tuple>(input_object);
|
|
|
|
ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
|
|
|
|
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
|
|
|
|
return;
|
|
|
|
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
|
|
|
|
} else if (py::isinstance<py::none>(input_object)) {
|
|
|
|
} else {
|
|
|
|
|
|
|
|
ConvertValueTupleToTensor(input_object, input_tensors);
|
|
|
|
|
|
|
|
*tensor_mask = kValueNodeTensorMask;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
|
|
|
|
MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
|
|
|
|