pynative-support-list-input

pull/1765/head
lvliang 5 years ago
parent 95ef02af7d
commit e046e6dd52

@ -319,6 +319,27 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tenso
input_tensors->push_back(tensor_ptr); input_tensors->push_back(tensor_ptr);
} }
void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
MS_EXCEPTION_IF_NULL(op_prim);
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(tensor_mask);
if (!py::isinstance<py::tuple>(input_object)) {
MS_LOG(EXCEPTION) << "The input should be a tuple!";
}
auto tuple_inputs = py::cast<py::tuple>(input_object);
if (tuple_inputs.size() == 0) {
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
}
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
} else {
ConvertValueTupleToTensor(input_object, input_tensors);
*tensor_mask = kValueNodeTensorMask;
}
}
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) { std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
MS_EXCEPTION_IF_NULL(op_prim); MS_EXCEPTION_IF_NULL(op_prim);
@ -333,20 +354,20 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
} else if (py::isinstance<py::int_>(input_object)) { } else if (py::isinstance<py::int_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32); tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
*tensor_mask = kValueNodeTensorMask; *tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::list>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr);
} else if (py::isinstance<py::array>(input_object)) { } else if (py::isinstance<py::array>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr); tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
} else if (py::isinstance<py::none>(input_object)) { } else if (py::isinstance<py::list>(input_object)) {
auto list_inputs = py::cast<py::list>(input_object);
py::tuple tuple_inputs(list_inputs.size());
for (size_t i = 0; i < tuple_inputs.size(); ++i) {
tuple_inputs[i] = list_inputs[i];
}
ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
return; return;
} else if (py::isinstance<py::tuple>(input_object)) { } else if (py::isinstance<py::tuple>(input_object)) {
auto tuple_inputs = py::cast<py::tuple>(input_object); ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) { return;
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); } else if (py::isinstance<py::none>(input_object)) {
} else {
ConvertValueTupleToTensor(input_object, input_tensors);
*tensor_mask = kValueNodeTensorMask;
}
return; return;
} else { } else {
MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";

Loading…
Cancel
Save