|
|
|
@ -580,77 +580,73 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|
|
|
|
|
|
|
|
|
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
|
|
|
|
|
std::shared_ptr<Tensor> append) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
|
|
|
|
|
|
|
|
|
|
axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Tensor> out;
|
|
|
|
|
TensorShape t = TensorShape::CreateScalar();
|
|
|
|
|
|
|
|
|
|
DataType first_dtype = input[0]->type();
|
|
|
|
|
|
|
|
|
|
TensorRow tensor_list;
|
|
|
|
|
|
|
|
|
|
if (prepend != nullptr) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == prepend->type(), "Tensor types do not match");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
|
|
|
|
|
RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0]));
|
|
|
|
|
} else {
|
|
|
|
|
out = input[0];
|
|
|
|
|
tensor_list.emplace_back(prepend);
|
|
|
|
|
}
|
|
|
|
|
for (dsize_t i = 1; i < input.size(); i++) {
|
|
|
|
|
std::shared_ptr<Tensor> out_t;
|
|
|
|
|
|
|
|
|
|
for (dsize_t i = 0; i < input.size(); i++) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == input[i]->type(), "Tensor types do not match");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
|
|
|
|
|
RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i]));
|
|
|
|
|
out = out_t;
|
|
|
|
|
tensor_list.emplace_back(input[i]);
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<Tensor> out_t;
|
|
|
|
|
|
|
|
|
|
if (append != nullptr) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(first_dtype == append->type(), "Tensor types do not match");
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
|
|
|
|
|
RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append));
|
|
|
|
|
} else {
|
|
|
|
|
out_t = out;
|
|
|
|
|
tensor_list.emplace_back(append);
|
|
|
|
|
}
|
|
|
|
|
output->push_back(out_t);
|
|
|
|
|
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
|
|
|
|
|
std::shared_ptr<Tensor> append) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match");
|
|
|
|
|
|
|
|
|
|
TensorShape t({});
|
|
|
|
|
|
|
|
|
|
for (dsize_t i = 0; i < input->shape().Rank(); i++) {
|
|
|
|
|
// create final shape
|
|
|
|
|
for (dsize_t i = 0; i < tensor_list[0]->shape().Rank(); i++) {
|
|
|
|
|
if (i != axis) {
|
|
|
|
|
t = t.AppendDim(input->shape()[i]);
|
|
|
|
|
t = t.AppendDim(tensor_list[0]->shape()[i]);
|
|
|
|
|
} else {
|
|
|
|
|
dsize_t new_shape = input->shape()[i] + append->shape()[i];
|
|
|
|
|
|
|
|
|
|
dsize_t new_shape = 0;
|
|
|
|
|
for (dsize_t j = 0; j < tensor_list.size(); j++) {
|
|
|
|
|
new_shape = tensor_list[j]->shape()[i] + new_shape;
|
|
|
|
|
}
|
|
|
|
|
t = t.AppendDim(new_shape);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Tensor> out;
|
|
|
|
|
|
|
|
|
|
if (input->type().IsNumeric()) {
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, input->type(), &out));
|
|
|
|
|
if (input[0]->type().IsNumeric()) {
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(t, tensor_list[0]->type(), &out));
|
|
|
|
|
std::vector<dsize_t> index(axis + 1, 0);
|
|
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(out->Concatenate({0}, input));
|
|
|
|
|
RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append));
|
|
|
|
|
*output = out;
|
|
|
|
|
int n = index.size() - 1;
|
|
|
|
|
for (dsize_t i = 0; i < tensor_list.size(); i++) {
|
|
|
|
|
RETURN_IF_NOT_OK(out->InsertTensor({index}, tensor_list[i], true));
|
|
|
|
|
index[n] = index[n] + tensor_list[i]->shape()[axis];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<std::string> strings;
|
|
|
|
|
|
|
|
|
|
auto itr = input->begin<std::string_view>();
|
|
|
|
|
for (; itr != input->end<std::string_view>(); itr++) {
|
|
|
|
|
strings.emplace_back(*itr);
|
|
|
|
|
}
|
|
|
|
|
itr = append->begin<std::string_view>();
|
|
|
|
|
for (; itr != append->end<std::string_view>(); itr++) {
|
|
|
|
|
strings.emplace_back(*itr);
|
|
|
|
|
for (dsize_t i = 0; i < tensor_list.size(); i++) {
|
|
|
|
|
auto itr = tensor_list[i]->begin<std::string_view>();
|
|
|
|
|
for (; itr != tensor_list[i]->end<std::string_view>(); itr++) {
|
|
|
|
|
strings.emplace_back(*itr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateFromVector(strings, t, &out));
|
|
|
|
|
|
|
|
|
|
*output = out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
output->push_back(out);
|
|
|
|
|
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace dataset
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|