parent
729d847dd4
commit
1896950ae5
@ -0,0 +1,108 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Stores the mixing parameter and seeds the private random engine with the value
// returned by GetSeed().
MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha), rnd_(GetSeed()) {}
|
||||
|
||||
// Applies MixUp to one batched row and appends the mixed images and labels to `output`.
//
// Column 0 of `input` must be a 4-D image batch and column 1 a 2-D label batch whose first
// dimension equals the image batch size; anything else is rejected with an error status.
// One mixing weight `lam` is drawn per call and every sample i is blended with the sample
// found at a shuffled index:
//   out[i] = lam * in[i] + (1 - lam) * in[rand_indx[i]]
// for both images and labels. The image output keeps the shape and type of column 0; the
// label output is float32.
//
// Returns Status::OK() on success, otherwise the first failing status from the input checks
// or from one of the tensor helpers.
Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
  if (input.size() < 2) {
    RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation");
  }

  const std::shared_ptr<Tensor> &image_batch = input.at(0);
  const std::shared_ptr<Tensor> &label_batch = input.at(1);

  std::vector<std::shared_ptr<CVTensor>> images;
  const std::vector<int64_t> image_shape = image_batch->shape().AsVector();
  const std::vector<int64_t> label_shape = label_batch->shape().AsVector();

  // Check inputs
  if (label_shape.size() != 2 || image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
    RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch");
  }

  // Accept the batch when either dimension 1 or dimension 3 has extent 1 or 3.
  const auto looks_like_channels = [](int64_t extent) { return extent == 1 || extent == 3; };
  if (!looks_like_channels(image_shape[1]) && !looks_like_channels(image_shape[3])) {
    RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW");
  }

  // Move images into a vector of CVTensors
  RETURN_IF_NOT_OK(BatchTensorToCVTensorVector(image_batch, &images));

  // Calculating lambda
  // If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1)
  // then x = x1 / (x1+x2) is a random variable from Beta(a1, a2)
  std::gamma_distribution<float> distribution(alpha_, 1);
  const float x1 = distribution(rnd_);
  const float x2 = distribution(rnd_);
  // Both draws can be zero (e.g. a very small alpha underflowing in float), which would make
  // the ratio 0/0 = NaN and poison every output value. Fall back to lam = 1 (no mixing).
  const float gamma_sum = x1 + x2;
  const float lam = gamma_sum > 0 ? x1 / gamma_sum : 1.0f;

  // Calculate random labels: a shuffled permutation of the sample indices
  const auto num_images = static_cast<int64_t>(images.size());
  std::vector<int64_t> rand_indx;
  rand_indx.reserve(images.size());
  for (int64_t i = 0; i < num_images; i++) rand_indx.push_back(i);
  std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_);

  // Compute labels. The source values are always read from the untouched input column, so a
  // row that was already overwritten in out_labels never feeds into a later row.
  std::shared_ptr<Tensor> out_labels;
  RETURN_IF_NOT_OK(TypeCast(label_batch, &out_labels, DataType("float32")));
  for (int64_t i = 0; i < label_shape[0]; i++) {
    for (int64_t j = 0; j < label_shape[1]; j++) {
      uint64_t first_value, second_value;
      RETURN_IF_NOT_OK(label_batch->GetItemAt(&first_value, {i, j}));
      RETURN_IF_NOT_OK(label_batch->GetItemAt(&second_value, {rand_indx[i], j}));
      RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
    }
  }

  // Compute images. The partner image is rebuilt from the input batch memory on every
  // iteration rather than taken from `images`, whose entries are overwritten as we go.
  const TensorShape single_image_shape({image_shape[1], image_shape[2], image_shape[3]});
  for (int64_t i = 0; i < num_images; i++) {
    TensorShape remaining({-1});
    uchar *start_addr_of_index = nullptr;
    std::shared_ptr<Tensor> out;
    RETURN_IF_NOT_OK(image_batch->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining));
    RETURN_IF_NOT_OK(
      image_batch->CreateFromMemory(single_image_shape, image_batch->type(), start_addr_of_index, &out));
    std::shared_ptr<CVTensor> rand_image = CVTensor::AsCVTensor(std::move(out));
    if (!rand_image->mat().data) {
      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
    }
    images[i]->mat() = lam * images[i]->mat() + (1 - lam) * rand_image->mat();
  }

  // Move the output into a TensorRow
  std::shared_ptr<Tensor> output_image;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(image_batch->shape(), image_batch->type(), &output_image));
  for (int64_t i = 0; i < num_images; i++) {
    RETURN_IF_NOT_OK(output_image->InsertTensor({i}, images[i]));
  }
  output->push_back(output_image);
  output->push_back(out_labels);

  return Status::OK();
}
|
||||
|
||||
// Writes a one-line description of this op (its label followed by alpha) to `out`.
void MixUpBatchOp::Print(std::ostream &out) const {
  out << "MixUpBatchOp: ";
  out << "alpha: " << alpha_;
  out << "\n";
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class MixUpBatchOp : public TensorOp {
|
||||
public:
|
||||
// Default values, also used by python_bindings.cc
|
||||
|
||||
explicit MixUpBatchOp(float alpha);
|
||||
|
||||
~MixUpBatchOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||
|
||||
std::string Name() const override { return kMixUpBatchOp; }
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
|
@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "common/cvop_common.h"
|
||||
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
class MindDataTestMixUpBatchOp : public UT::CVOP::CVOpCommon {
|
||||
protected:
|
||||
MindDataTestMixUpBatchOp() : CVOpCommon() {}
|
||||
|
||||
std::shared_ptr<Tensor> output_tensor_;
|
||||
};
|
||||
|
||||
// Happy path: a batch of two identical images plus a 2x2 label tensor must be accepted, and
// both outputs must keep the shapes of the corresponding inputs.
TEST_F(MindDataTestMixUpBatchOp, TestSuccess) {
  MS_LOG(INFO) << "Doing MindDataTestMixUpBatchOp success case";

  // Build the image batch by stacking the fixture image twice. Every setup step is asserted
  // so that a failure is reported here instead of as a null dereference further down.
  const TensorShape image_shape = input_tensor_->shape();
  std::shared_ptr<Tensor> batched_tensor;
  ASSERT_TRUE(Tensor::CreateEmpty(TensorShape({2, image_shape[0], image_shape[1], image_shape[2]}),
                                  input_tensor_->type(), &batched_tensor)
                .IsOk());
  for (int i = 0; i < 2; i++) {
    ASSERT_TRUE(batched_tensor->InsertTensor({i}, input_tensor_).IsOk());
  }

  std::shared_ptr<Tensor> batched_labels;
  ASSERT_TRUE(
    Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels).IsOk());

  std::shared_ptr<MixUpBatchOp> op = std::make_shared<MixUpBatchOp>(1);
  TensorRow in;
  in.push_back(batched_tensor);
  in.push_back(batched_labels);
  TensorRow out;
  ASSERT_TRUE(op->Compute(in, &out).IsOk());

  // The op emits exactly the image tensor followed by the label tensor.
  ASSERT_EQ(out.size(), 2u);

  // Image batch keeps all four dimensions.
  for (int dim = 0; dim < 4; dim++) {
    EXPECT_EQ(in.at(0)->shape()[dim], out.at(0)->shape()[dim]);
  }

  // Label batch keeps both dimensions.
  for (int dim = 0; dim < 2; dim++) {
    EXPECT_EQ(in.at(1)->shape()[dim], out.at(1)->shape()[dim]);
  }
}
|
||||
|
||||
// Failure path: the op must reject a row whose columns were not batched.
TEST_F(MindDataTestMixUpBatchOp, TestFail) {
  // This is a fail case because our labels are not batched and are 1-dimensional
  MS_LOG(INFO) << "Doing MindDataTestMixUpBatchOp fail case";

  // Assert the setup so the test cannot succeed merely because the label tensor was never built.
  std::shared_ptr<Tensor> labels;
  ASSERT_TRUE(Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({4}), &labels).IsOk());

  std::shared_ptr<MixUpBatchOp> op = std::make_shared<MixUpBatchOp>(1);
  TensorRow in;
  in.push_back(input_tensor_);
  in.push_back(labels);
  TensorRow out;
  ASSERT_FALSE(op->Compute(in, &out).IsOk());
}
|
Binary file not shown.
@ -0,0 +1,247 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing the MixUpBatch op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
import mindspore.dataset.transforms.c_transforms as data_trans
|
||||
from mindspore import log as logger
|
||||
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
|
||||
config_get_set_num_parallel_workers
|
||||
|
||||
# Dataset directory (relative path) handed to ds.Cifar10Dataset by the tests in this file.
DATA_DIR = "../data/dataset/testCifar10Data"
|
||||
|
||||
# Forwarded as generate_golden to save_and_check_md5 in test_mixup_batch_md5.
# NOTE(review): presumably True rewrites the golden file instead of comparing against it - confirm in util.
GENERATE_GOLDEN = False
|
||||
|
||||
def test_mixup_batch_success1(plot=False):
    """
    Run MixUpBatch with an explicitly given alpha (2) on ten samples batched in fives.

    The mean MSE between the mixed images and the untouched originals is logged only;
    nothing is asserted on it. With plot=True both image sets are also visualized.
    """
    logger.info("test_mixup_batch_success1")

    # Reference pipeline: same samples, batched but not mixed
    reference = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False).batch(5, drop_remainder=True)

    images_original = None
    for batch_no, (batch_images, _) in enumerate(reference):
        images_original = batch_images if batch_no == 0 else np.append(images_original, batch_images, axis=0)

    # MixUp pipeline: one-hot the labels, batch, then mix images and labels together
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    data1 = data1.map(input_columns=["label"], operations=data_trans.OneHot(num_classes=10))
    mixup_batch_op = vision.MixUpBatch(2)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)

    images_mixup = None
    for batch_no, (batch_images, _) in enumerate(data1):
        images_mixup = batch_images if batch_no == 0 else np.append(images_mixup, batch_images, axis=0)

    if plot:
        visualize_list(images_original, images_mixup)

    # Per-sample MSE between mixed and original images
    num_samples = images_original.shape[0]
    mse = np.array([diff_mse(images_mixup[k], images_original[k]) for k in range(num_samples)],
                   dtype=np.float64)
    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))
|
||||
|
||||
|
||||
def test_mixup_batch_success2(plot=False):
    """
    Run MixUpBatch without passing alpha, so the op's default value is used.

    The mean MSE between the mixed images and the untouched originals is logged only;
    nothing is asserted on it. With plot=True both image sets are also visualized.
    """
    logger.info("test_mixup_batch_success2")

    # Reference pipeline: same samples, batched but not mixed
    reference = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False).batch(5, drop_remainder=True)

    images_original = None
    for batch_no, (batch_images, _) in enumerate(reference):
        images_original = batch_images if batch_no == 0 else np.append(images_original, batch_images, axis=0)

    # MixUp pipeline: one-hot the labels, batch, then mix images and labels together
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    data1 = data1.map(input_columns=["label"], operations=data_trans.OneHot(num_classes=10))
    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)

    images_mixup = np.array([])
    for batch_no, (batch_images, _) in enumerate(data1):
        images_mixup = batch_images if batch_no == 0 else np.append(images_mixup, batch_images, axis=0)

    if plot:
        visualize_list(images_original, images_mixup)

    # Per-sample MSE between mixed and original images
    num_samples = images_original.shape[0]
    mse = np.array([diff_mse(images_mixup[k], images_original[k]) for k in range(num_samples)],
                   dtype=np.float64)
    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))
|
||||
|
||||
|
||||
def test_mixup_batch_md5():
    """
    Test MixUpBatch with MD5: run the default-alpha pipeline under a fixed seed and a single
    parallel worker, then hand the result to save_and_check_md5.

    The global seed and worker count are restored even when the check fails, so a failing
    run cannot leak its configuration into the tests executed afterwards.
    """
    logger.info("test_mixup_batch_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # MixUp Images
        data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(input_columns=["label"], operations=one_hot_op)
        mixup_batch_op = vision.MixUpBatch()
        data = data.batch(5, drop_remainder=True)
        data = data.map(input_columns=["image", "label"], operations=mixup_batch_op)

        filename = "mixup_batch_c_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
|
||||
|
||||
def test_mixup_batch_fail1():
    """
    Test MixUpBatch Fail 1
    We expect this to fail because the images and labels are not batched
    """
    logger.info("test_mixup_batch_fail1")

    # MixUp pipeline WITHOUT a batch step in front of the op
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
    mixup_batch_op = vision.MixUpBatch(0.1)

    with pytest.raises(RuntimeError) as error:
        data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
        # Pulling rows is what runs the op; the rows themselves are not needed.
        for _ in data1:
            pass

    # Checked after the block: statements placed after the raising call inside the
    # `with` body are never executed, and error.value is only available once it exits.
    error_message = "You must batch before calling MixUp"
    assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_mixup_batch_fail2():
    """
    Test MixUpBatch Fail 2
    We expect this to fail because alpha is negative
    """
    logger.info("test_mixup_batch_fail2")

    # The argument is validated when the op is constructed, so no dataset pipeline is needed.
    with pytest.raises(ValueError) as error:
        vision.MixUpBatch(-1)

    # Checked after the block: statements placed after the raising call inside the
    # `with` body are never executed, and error.value is only available once it exits.
    # Only the tail of the expected text is matched so the check does not depend on how the
    # argument name is rendered after "Input".
    # NOTE(review): confirm the exact validator wording and tighten this if desired.
    error_message = "is not within the required interval"
    assert error_message in str(error.value)
|
||||
|
||||
|
||||
def test_mixup_batch_fail3():
    """
    Test MixUpBatch op
    We expect this to fail because label column is not passed to mixup_batch
    """
    logger.info("test_mixup_batch_fail3")

    # MixUp pipeline that hands only the image column to the op
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(input_columns=["label"], operations=one_hot_op)
    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(input_columns=["image"], operations=mixup_batch_op)

    with pytest.raises(RuntimeError) as error:
        # Pulling rows is what runs the op; the rows themselves are not needed.
        for _ in data1:
            pass

    # Checked after the block: statements placed after the raising call inside the
    # `with` body are never executed, and error.value is only available once it exits.
    error_message = "Both images and labels columns are required"
    assert error_message in str(error.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The two success cases are run with plotting switched on ...
    for visual_case in (test_mixup_batch_success1, test_mixup_batch_success2):
        visual_case(plot=True)
    # ... the remaining cases take no arguments.
    for plain_case in (test_mixup_batch_md5,
                       test_mixup_batch_fail1,
                       test_mixup_batch_fail2,
                       test_mixup_batch_fail3):
        plain_case()
|
Loading…
Reference in new issue