parent
ea87b6c443
commit
87aa9c8f7a
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/kernels/data/pad_end_op.h"
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status PadEndOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
Status s = PadEnd(input, output, output_shape_.AsVector(), pad_val_);
|
||||
return s;
|
||||
}
|
||||
|
||||
Status PadEndOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
||||
outputs.clear();
|
||||
for (auto s : inputs) {
|
||||
outputs.emplace_back(TensorShape(output_shape_.AsVector()));
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "Input has a wrong shape");
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_
|
||||
#define DATASET_KERNELS_DATA_PAD_END_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class PadEndOp : public TensorOp {
|
||||
public:
|
||||
explicit PadEndOp(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
|
||||
: output_shape_(pad_shape), pad_val_(pad_value) {}
|
||||
|
||||
~PadEndOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override { out << "PadEndOp"; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||
|
||||
private:
|
||||
TensorShape output_shape_;
|
||||
std::shared_ptr<Tensor> pad_val_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_KERNELS_DATA_PAD_END_OP_H_
|
@ -0,0 +1,140 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "dataset/kernels/data/pad_end_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestPadEndOp : public UT::Common {
|
||||
protected:
|
||||
MindDataTestPadEndOp() {}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPadEndOp, TestOp) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPadEndOp.";
|
||||
|
||||
// first set of testunits for numeric values
|
||||
|
||||
TensorShape pad_data_shape({1});
|
||||
|
||||
// prepare input tensor
|
||||
float_t orig1[4] = {1, 1, 1, 1};
|
||||
TensorShape input_shape1({2, 2});
|
||||
std::vector<TensorShape> input_shape1_vector = {input_shape1};
|
||||
std::shared_ptr<Tensor> input1 =
|
||||
std::make_shared<Tensor>(input_shape1, DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(orig1));
|
||||
|
||||
// pad_shape
|
||||
TensorShape pad_shape1[3] = {TensorShape({3, 3}), TensorShape({2, 4}), TensorShape({4, 2})};
|
||||
|
||||
// value to pad
|
||||
float_t pad_data1[3][1] = {0, 3.5, 3.5};
|
||||
|
||||
std::shared_ptr<Tensor> expected1[3];
|
||||
|
||||
// expected tensor output for testunit 1
|
||||
float_t out1[9] = {1, 1, 0, 1, 1, 0, 0, 0, 0};
|
||||
|
||||
expected1[0] =
|
||||
std::make_shared<Tensor>(pad_shape1[0], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out1));
|
||||
|
||||
// expected tensor output for testunit 2
|
||||
float_t out2[8] = {1, 1, 3.5, 3.5, 1, 1, 3.5, 3.5};
|
||||
|
||||
expected1[1] =
|
||||
std::make_shared<Tensor>(pad_shape1[1], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out2));
|
||||
|
||||
// expected tensor output for testunit 3
|
||||
float_t out3[8] = {1, 1, 1, 1, 3.5, 3.5, 3.5, 3.5};
|
||||
|
||||
expected1[2] =
|
||||
std::make_shared<Tensor>(pad_shape1[2], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out3));
|
||||
|
||||
// run the PadEndOp
|
||||
for (auto i = 0; i < 3; i++) {
|
||||
std::shared_ptr<Tensor> output;
|
||||
std::vector<TensorShape> output_shape = {TensorShape({})};
|
||||
std::shared_ptr<Tensor> pad_value1 = std::make_shared<Tensor>(pad_data_shape, DataType(DataType::DE_FLOAT32),
|
||||
reinterpret_cast<unsigned char *>(pad_data1[i]));
|
||||
std::unique_ptr<PadEndOp> op(new PadEndOp(pad_shape1[i], pad_value1));
|
||||
Status s = op->Compute(input1, &output);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(output->shape() == expected1[i]->shape());
|
||||
ASSERT_TRUE(output->type() == expected1[i]->type());
|
||||
MS_LOG(DEBUG) << *output << std::endl;
|
||||
MS_LOG(DEBUG) << *expected1[i] << std::endl;
|
||||
ASSERT_TRUE(*output == *expected1[i]);
|
||||
|
||||
s = op->OutputShape(input_shape1_vector, output_shape);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(output_shape.size() == 1);
|
||||
ASSERT_TRUE(output->shape() == output_shape[0]);
|
||||
}
|
||||
|
||||
// second set of testunits for string
|
||||
|
||||
// input tensor
|
||||
std::vector<std::string> orig2 = {"this", "is"};
|
||||
TensorShape input_shape2({2});
|
||||
std::vector<TensorShape> input_shape2_vector = {input_shape2};
|
||||
std::shared_ptr<Tensor> input2;
|
||||
Tensor::CreateTensor(&input2, orig2, input_shape2);
|
||||
|
||||
// pad_shape
|
||||
TensorShape pad_shape2[3] = {TensorShape({5}), TensorShape({2}), TensorShape({10})};
|
||||
|
||||
// pad value
|
||||
std::vector<std::string> pad_data2[3] = {{""}, {"P"}, {" "}};
|
||||
std::shared_ptr<Tensor> pad_value2[3];
|
||||
|
||||
// expected output for 3 testunits
|
||||
std::shared_ptr<Tensor> expected2[3];
|
||||
std::vector<std::string> outstring[3] = {
|
||||
{"this", "is", "", "", ""}, {"this", "is"}, {"this", "is", " ", " ", " ", " ", " ", " ", " ", " "}};
|
||||
|
||||
for (auto i = 0; i < 3; i++) {
|
||||
// pad value
|
||||
Tensor::CreateTensor(&pad_value2[i], pad_data2[i], pad_data_shape);
|
||||
|
||||
std::shared_ptr<Tensor> output;
|
||||
std::vector<TensorShape> output_shape = {TensorShape({})};
|
||||
|
||||
std::unique_ptr<PadEndOp> op(new PadEndOp(pad_shape2[i], pad_value2[i]));
|
||||
|
||||
Status s = op->Compute(input2, &output);
|
||||
|
||||
Tensor::CreateTensor(&expected2[i], outstring[i], pad_shape2[i]);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(output->shape() == expected2[i]->shape());
|
||||
ASSERT_TRUE(output->type() == expected2[i]->type());
|
||||
MS_LOG(DEBUG) << *output << std::endl;
|
||||
MS_LOG(DEBUG) << *expected2[i] << std::endl;
|
||||
ASSERT_TRUE(*output == *expected2[i]);
|
||||
|
||||
s = op->OutputShape(input_shape2_vector, output_shape);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
ASSERT_TRUE(output_shape.size() == 1);
|
||||
ASSERT_TRUE(output->shape() == output_shape[0]);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "MindDataTestPadEndOp end.";
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing PadEnd op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
|
||||
|
||||
def pad_compare(array, pad_shape, pad_value, res):
|
||||
data = ds.NumpySlicesDataset([array])
|
||||
if pad_value is not None:
|
||||
data = data.map(operations=ops.PadEnd(pad_shape, pad_value))
|
||||
else:
|
||||
data = data.map(operations=ops.PadEnd(pad_shape))
|
||||
for d in data:
|
||||
np.testing.assert_array_equal(res, d[0])
|
||||
|
||||
|
||||
# Extensive testing of PadEnd is already done in batch with Pad test cases
|
||||
|
||||
def test_pad_end_basics():
|
||||
pad_compare([1, 2], [3], -1, [1, 2, -1])
|
||||
pad_compare([1, 2, 3], [3], -1, [1, 2, 3])
|
||||
pad_compare([1, 2, 3], [2], -1, [1, 2])
|
||||
pad_compare([1, 2, 3], [5], None, [1, 2, 3, 0, 0])
|
||||
|
||||
|
||||
def test_pad_end_str():
|
||||
pad_compare([b"1", b"2"], [3], b"-1", [b"1", b"2", b"-1"])
|
||||
pad_compare([b"1", b"2", b"3"], [3], b"-1", [b"1", b"2", b"3"])
|
||||
pad_compare([b"1", b"2", b"3"], [2], b"-1", [b"1", b"2"])
|
||||
pad_compare([b"1", b"2", b"3"], [5], None, [b"1", b"2", b"3", b"", b""])
|
||||
|
||||
|
||||
def test_pad_end_exceptions():
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
pad_compare([1, 2], [3], "-1", [])
|
||||
assert "Source and pad_value tensors are not of the same type." in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
pad_compare([b"1", b"2", b"3", b"4", b"5"], [2], 1, [])
|
||||
assert "Source and pad_value tensors are not of the same type." in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pad_end_basics()
|
||||
test_pad_end_str()
|
||||
test_pad_end_exceptions()
|
Loading…
Reference in new issue