!4473 Implementing Posterize Op
Merge pull request !4473 from islam_amin/posterize_oppull/4473/MERGE
commit
56bd92b88f
@ -0,0 +1,50 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/posterize_op.h"
|
||||
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const uint8_t PosterizeOp::kBit = 8;
|
||||
|
||||
PosterizeOp::PosterizeOp(uint8_t bit) : bit_(bit) {}
|
||||
|
||||
Status PosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
uint8_t mask_value = ~((uint8_t)(1 << (8 - bit_)) - 1);
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
|
||||
}
|
||||
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
|
||||
RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of <H,W,C> or <H,W>");
|
||||
}
|
||||
std::vector<uint8_t> lut_vector;
|
||||
for (std::size_t i = 0; i < 256; i++) {
|
||||
lut_vector.push_back(i & mask_value);
|
||||
}
|
||||
cv::Mat in_image = input_cv->mat();
|
||||
cv::Mat output_img;
|
||||
cv::LUT(in_image, lut_vector, output_img);
|
||||
std::shared_ptr<CVTensor> result_tensor;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(output_img, &result_tensor));
|
||||
*output = std::static_pointer_cast<Tensor>(result_tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,56 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class PosterizeOp : public TensorOp {
|
||||
public:
|
||||
/// Default values
|
||||
static const uint8_t kBit;
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param[in] bit: bits to use
|
||||
explicit PosterizeOp(uint8_t bit = kBit);
|
||||
|
||||
~PosterizeOp() override = default;
|
||||
|
||||
std::string Name() const override { return kPosterizeOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
/// Member variables
|
||||
private:
|
||||
std::string kPosterizeOp = "PosterizeOp";
|
||||
|
||||
protected:
|
||||
uint8_t bit_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_POSTERIZE_OP_H_
|
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/random_posterize_op.h"
|
||||
|
||||
#include <random>
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const uint8_t RandomPosterizeOp::kMinBit = 8;
|
||||
const uint8_t RandomPosterizeOp::kMaxBit = 8;
|
||||
|
||||
RandomPosterizeOp::RandomPosterizeOp(uint8_t min_bit, uint8_t max_bit)
|
||||
: PosterizeOp(min_bit), min_bit_(min_bit), max_bit_(max_bit) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
Status RandomPosterizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
bit_ = (min_bit_ == max_bit_) ? min_bit_ : std::uniform_int_distribution<uint8_t>(min_bit_, max_bit_)(rnd_);
|
||||
return PosterizeOp::Compute(input, output);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,55 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/image/posterize_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class RandomPosterizeOp : public PosterizeOp {
|
||||
public:
|
||||
/// Default values
|
||||
static const uint8_t kMinBit;
|
||||
static const uint8_t kMaxBit;
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param[in] min_bit: Minimum bit in range
|
||||
/// \param[in] max_bit: Maximum bit in range
|
||||
explicit RandomPosterizeOp(uint8_t min_bit = kMinBit, uint8_t max_bit = kMaxBit);
|
||||
|
||||
~RandomPosterizeOp() override = default;
|
||||
|
||||
std::string Name() const override { return kRandomPosterizeOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
/// Member variables
|
||||
private:
|
||||
std::string kRandomPosterizeOp = "RandomPosterizeOp";
|
||||
uint8_t min_bit_;
|
||||
uint8_t max_bit_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_POSTERIZE_OP_H_
|
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "minddata/dataset/kernels/image/random_posterize_op.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestRandomPosterizeOp : public UT::CVOP::CVOpCommon {
|
||||
public:
|
||||
MindDataTestRandomPosterizeOp() : CVOpCommon() {}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestRandomPosterizeOp, TestOp1) {
|
||||
MS_LOG(INFO) << "Doing testRandomPosterize.";
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
std::unique_ptr<RandomPosterizeOp> op(new RandomPosterizeOp(1, 1));
|
||||
EXPECT_TRUE(op->OneToOne());
|
||||
Status s = op->Compute(input_tensor_, &output_tensor);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
CheckImageShapeAndData(output_tensor, kRandomPosterize);
|
||||
}
|
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 380 KiB |
@ -0,0 +1,149 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing RandomPosterize op in DE
|
||||
"""
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as c_vision
|
||||
from mindspore import log as logger
|
||||
from util import visualize_list, save_and_check_md5, \
|
||||
config_get_set_seed, config_get_set_num_parallel_workers
|
||||
|
||||
GENERATE_GOLDEN = False
|
||||
|
||||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
|
||||
|
||||
def test_random_posterize_op_c(plot=False, run_golden=True):
|
||||
"""
|
||||
Test RandomPosterize in C transformations
|
||||
"""
|
||||
logger.info("test_random_posterize_op_c")
|
||||
|
||||
original_seed = config_get_set_seed(55)
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
|
||||
# define map operations
|
||||
transforms1 = [
|
||||
c_vision.Decode(),
|
||||
c_vision.RandomPosterize((1, 8))
|
||||
]
|
||||
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data1 = data1.map(input_columns=["image"], operations=transforms1)
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode()])
|
||||
|
||||
image_posterize = []
|
||||
image_original = []
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
image1 = item1["image"]
|
||||
image2 = item2["image"]
|
||||
image_posterize.append(image1)
|
||||
image_original.append(image2)
|
||||
|
||||
if run_golden:
|
||||
# check results with md5 comparison
|
||||
filename = "random_posterize_01_result_c.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
if plot:
|
||||
visualize_list(image_original, image_posterize)
|
||||
|
||||
# Restore configuration
|
||||
ds.config.set_seed(original_seed)
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_random_posterize_op_fixed_point_c(plot=False, run_golden=True):
|
||||
"""
|
||||
Test RandomPosterize in C transformations with fixed point
|
||||
"""
|
||||
logger.info("test_random_posterize_op_c")
|
||||
|
||||
# define map operations
|
||||
transforms1 = [
|
||||
c_vision.Decode(),
|
||||
c_vision.RandomPosterize(1)
|
||||
]
|
||||
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data1 = data1.map(input_columns=["image"], operations=transforms1)
|
||||
# Second dataset
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode()])
|
||||
|
||||
image_posterize = []
|
||||
image_original = []
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
image1 = item1["image"]
|
||||
image2 = item2["image"]
|
||||
image_posterize.append(image1)
|
||||
image_original.append(image2)
|
||||
|
||||
if run_golden:
|
||||
# check results with md5 comparison
|
||||
filename = "random_posterize_fixed_point_01_result_c.npz"
|
||||
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
if plot:
|
||||
visualize_list(image_original, image_posterize)
|
||||
|
||||
|
||||
def test_random_posterize_exception_bit():
|
||||
"""
|
||||
Test RandomPosterize: out of range input bits and invalid type
|
||||
"""
|
||||
logger.info("test_random_posterize_exception_bit")
|
||||
# Test max > 8
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((1, 9))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required interval of (1 to 8)."
|
||||
# Test min < 1
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((0, 7))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required interval of (1 to 8)."
|
||||
# Test max < min
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((8, 1))
|
||||
except ValueError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Input is not within the required interval of (1 to 8)."
|
||||
# Test wrong type (not uint8)
|
||||
try:
|
||||
_ = c_vision.RandomPosterize(1.1)
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Argument bits with value 1.1 is not of type (<class 'list'>, <class 'tuple'>, <class 'int'>)."
|
||||
# Test wrong number of bits
|
||||
try:
|
||||
_ = c_vision.RandomPosterize((1, 1, 1))
|
||||
except TypeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert str(e) == "Size of bits should be a single integer or a list/tuple (min, max) of length 2."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_posterize_op_c(plot=True)
|
||||
test_random_posterize_op_fixed_point_c(plot=True)
|
||||
test_random_posterize_exception_bit()
|
Loading…
Reference in new issue