parent
d79bcc923e
commit
048fc49aed
@ -0,0 +1,90 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh"
|
||||
|
||||
/**
 * Draws categorical samples by inverting a per-row cumulative distribution.
 *
 * For every output element (row, col) the kernel scales the uniform value
 * dev_rand[row][col] by the last entry of dev_cdf[row] and writes the index of
 * the first CDF entry that is not smaller than that scaled value.
 *
 * Launch layout: 1-D grid, 1-D blocks; any grid size is valid because each
 * thread strides over the flattened [batch_size x num_samples] output.
 *
 * @param num_samples  number of samples drawn per row
 * @param dev_rand     device table of batch_size pointers, each to num_samples doubles
 * @param dev_cdf      device table of batch_size pointers, each to num_classes doubles
 * @param batch_size   number of rows
 * @param num_classes  number of CDF entries per row
 * @param output_addr  device buffer of batch_size * num_samples elements, row-major
 */
template <typename S>
__global__ void RandomCategorical(int num_samples, double** dev_rand, double** dev_cdf,
                                  int batch_size, int num_classes, S *output_addr) {
  // Nothing to sample from / into; also avoids indexing cdf[-1] below.
  if (num_samples <= 0 || batch_size <= 0 || num_classes <= 0) {
    return;
  }
  // 64-bit indexing: the element count and the grid stride may exceed INT_MAX.
  const size_t total = static_cast<size_t>(num_samples) * static_cast<size_t>(batch_size);
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  const int last_class = num_classes - 1;

  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < total; pos += stride) {
    const size_t cur_row = pos / static_cast<size_t>(num_samples);
    const size_t cur_col = pos % static_cast<size_t>(num_samples);
    const double *cdf_row = dev_cdf[cur_row];
    const double to_find = cdf_row[last_class] * dev_rand[cur_row][cur_col];

    // Linear scan for the first entry >= to_find. The explicit upper bound keeps
    // the scan inside the row even if the CDF or the random value is unexpected.
    int idx = 0;
    while (idx < last_class && cdf_row[idx] < to_find) {
      ++idx;
    }
    output_addr[pos] = static_cast<S>(idx);
  }
}
|
||||
|
||||
/**
 * Builds an unnormalised cumulative distribution for every row of logits.
 *
 * For each row the maximum logit is subtracted before exponentiation, and
 * dev_cdf[row][i] receives the running sum of exp(logit[j] - max) for j <= i.
 *
 * Launch layout: 1-D grid, 1-D blocks. One thread handles one whole row and
 * strides over rows, so the result is correct for any grid size. (The previous
 * form returned from inside the stride loop on the first non-row-start element,
 * which silently skipped rows whenever a thread had more than one element.)
 *
 * @param logits_addr  device buffer of batch_size * num_classes logits, row-major
 * @param dev_cdf      device table of batch_size pointers, each to num_classes doubles
 * @param batch_size   number of rows
 * @param num_classes  number of logits per row
 */
template <typename T>
__global__ void GetCdf(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes) {
  if (batch_size <= 0 || num_classes <= 0) {
    return;
  }
  const size_t row_count = static_cast<size_t>(batch_size);
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;

  for (size_t row = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; row < row_count; row += stride) {
    const T *row_logits = logits_addr + row * static_cast<size_t>(num_classes);
    double *row_cdf = dev_cdf[row];

    // Row maximum, used to keep the exponentials from overflowing.
    T max_of_row = row_logits[0];
    for (int i = 1; i < num_classes; i++) {
      if (row_logits[i] > max_of_row) {
        max_of_row = row_logits[i];
      }
    }

    // Running sum of shifted exponentials; the subtraction is done in T, as before.
    double running = exp(static_cast<double>(row_logits[0] - max_of_row));
    row_cdf[0] = running;
    for (int i = 1; i < num_classes; i++) {
      running += exp(static_cast<double>(row_logits[i] - max_of_row));
      row_cdf[i] = running;
    }
  }
}
|
||||
|
||||
/**
 * Host-side launcher for the RandomCategorical kernel.
 *
 * Enqueues the kernel on cuda_stream with a launch configuration derived from
 * the number of output elements (num_samples * batch_size) via the
 * GET_BLOCKS / GET_THREADS macros. All other arguments are forwarded unchanged.
 */
template <typename S>
void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf, int batch_size,
                             int num_classes, S *output_addr, cudaStream_t cuda_stream) {
  const int total_outputs = num_samples * batch_size;
  RandomCategorical<<<GET_BLOCKS(total_outputs), GET_THREADS, 0, cuda_stream>>>(
    num_samples, dev_rand, dev_cdf, batch_size, num_classes, output_addr);
}
|
||||
|
||||
/**
 * Host-side launcher for the GetCdf kernel.
 *
 * Enqueues the kernel on cuda_stream with a launch configuration derived from
 * the number of logits (num_classes * batch_size) via the
 * GET_BLOCKS / GET_THREADS macros. All other arguments are forwarded unchanged.
 */
template <typename T>
void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes,
                  cudaStream_t cuda_stream) {
  const int total_logits = num_classes * batch_size;
  GetCdf<<<GET_BLOCKS(total_logits), GET_THREADS, 0, cuda_stream>>>(
    logits_addr, dev_cdf, batch_size, num_classes);
}
|
||||
|
||||
// Explicit instantiations of the CDF launcher, one per logits element type.
template void GetCdfKernel<half>(const half *logits_addr, double** dev_cdf,
                                 const int batch_size, const int num_classes,
                                 cudaStream_t cuda_stream);
template void GetCdfKernel<float>(const float *logits_addr, double** dev_cdf,
                                  const int batch_size, const int num_classes,
                                  cudaStream_t cuda_stream);
template void GetCdfKernel<double>(const double *logits_addr, double** dev_cdf,
                                   const int batch_size, const int num_classes,
                                   cudaStream_t cuda_stream);

// Explicit instantiations of the sampling launcher, one per output index type.
template void RandomCategoricalKernel<int16_t>(int num_samples, double** dev_rand, double** dev_cdf,
                                               int batch_size, int num_classes, int16_t *output_addr,
                                               cudaStream_t cuda_stream);
template void RandomCategoricalKernel<int>(int num_samples, double** dev_rand, double** dev_cdf,
                                           int batch_size, int num_classes, int *output_addr,
                                           cudaStream_t cuda_stream);
template void RandomCategoricalKernel<int64_t>(int num_samples, double** dev_rand, double** dev_cdf,
                                               int batch_size, int num_classes, int64_t *output_addr,
                                               cudaStream_t cuda_stream);
|
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes,
|
||||
cudaStream_t cuda_stream);
|
||||
template <typename S>
|
||||
void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf,
|
||||
int batch_size, int num_classes, S *output_addr,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_
|
@ -0,0 +1,85 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Registrations of RandomCategoricalGpuKernel<T, S> for every supported
// (logits type T, output index type S) pair. Each signature has three inputs
// (logits plus two int32 inputs) and one integer output.

// float16 logits
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt16),
  RandomCategoricalGpuKernel, half, int16_t)
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt32),
  RandomCategoricalGpuKernel, half, int32_t)
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt64),
  RandomCategoricalGpuKernel, half, int64_t)

// float32 logits
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt16),
  RandomCategoricalGpuKernel, float, int16_t)
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt32),
  RandomCategoricalGpuKernel, float, int32_t)
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt64),
  RandomCategoricalGpuKernel, float, int64_t)

// float64 logits
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt16),
  RandomCategoricalGpuKernel, double, int16_t)
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt32),
  RandomCategoricalGpuKernel, double, int32_t)
MS_REG_GPU_KERNEL_TWO(
  RandomCategorical,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32)
    .AddOutputAttr(kNumberTypeInt64),
  RandomCategoricalGpuKernel, double, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,141 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class RandomCategoricalGpuKernel : public GpuKernel {
|
||||
public:
|
||||
RandomCategoricalGpuKernel() : batch_size_(0), num_classes_(0), num_samples_(0), seed_(0) {}
|
||||
~RandomCategoricalGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *logits_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
S *output_addr = GetDeviceAddress<S>(outputs, 0);
|
||||
|
||||
std::unique_ptr<double *[]> host_cdf;
|
||||
host_cdf = std::make_unique<double *[]>(batch_size_);
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
host_cdf[i] = GetDeviceAddress<double>(workspaces, i);
|
||||
}
|
||||
double **dev_cdf = GetDeviceAddress<double *>(workspaces, batch_size_);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_cdf, host_cdf.get(), sizeof(double *) * batch_size_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Random_categorica cudaMemcpyAsync dev_cdf failed");
|
||||
|
||||
std::unique_ptr<double *[]> host_rand;
|
||||
host_rand = std::make_unique<double *[]>(batch_size_);
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
host_rand[i] = GetDeviceAddress<double>(workspaces, batch_size_ + 1 + i);
|
||||
}
|
||||
|
||||
double **dev_rand = GetDeviceAddress<double *>(workspaces, batch_size_ * 2 + 1);
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
double *host_1d_rand = new double[num_samples_];
|
||||
std::default_random_engine rng(seed_);
|
||||
std::uniform_real_distribution<> dist(0, 1);
|
||||
for (int j = 0; j < num_samples_; j++) {
|
||||
host_1d_rand[j] = dist(rng);
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(host_rand[i], host_1d_rand, sizeof(double) * num_samples_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Random_categorica cudaMemcpyAsync host_1d_rand failed");
|
||||
delete[] host_1d_rand;
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_rand, host_rand.get(), sizeof(double *) * batch_size_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Random_categorica cudaMemcpyAsync dev_rand failed");
|
||||
|
||||
GetCdfKernel(logits_addr, dev_cdf, batch_size_, num_classes_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
RandomCategoricalKernel(num_samples_, dev_rand, dev_cdf, batch_size_, num_classes_, output_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomCategorical needs 3 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomCategorical has 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (logits_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "logits's dims is " << logits_shape.size() << ", but it should be only 2-D.";
|
||||
return false;
|
||||
}
|
||||
batch_size_ = SizeToInt(logits_shape[0]);
|
||||
num_classes_ = SizeToInt(logits_shape[1]);
|
||||
|
||||
num_samples_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_samples"));
|
||||
seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {}
|
||||
void InitSizeLists() override {
|
||||
// init memory
|
||||
input_size_list_.push_back(sizeof(T) * batch_size_ * num_classes_);
|
||||
input_size_list_.push_back(sizeof(int) * 2);
|
||||
|
||||
output_size_list_.push_back(sizeof(S) * batch_size_ * num_samples_);
|
||||
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
workspace_size_list_.push_back(sizeof(double) * num_classes_);
|
||||
}
|
||||
workspace_size_list_.push_back(sizeof(double *) * batch_size_);
|
||||
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
workspace_size_list_.push_back(sizeof(double) * num_samples_);
|
||||
}
|
||||
workspace_size_list_.push_back(sizeof(double *) * batch_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
int batch_size_;
|
||||
int num_classes_;
|
||||
int num_samples_;
|
||||
int seed_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_
|
@ -0,0 +1,180 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class RCnet(nn.Cell):
    """Minimal cell that wraps P.RandomCategorical with a fixed sample count and seed."""

    def __init__(self, num_sample, seed=0, dtype=ms.int64):
        super().__init__()
        self.rc = P.RandomCategorical(dtype)
        self.num_sample = num_sample
        self.seed = seed

    def construct(self, logits):
        sampled = self.rc(logits, self.num_sample, self.seed)
        return sampled
|
||||
|
||||
# Device backend passed as device_target to context.set_context by the tests in this file.
TARGET = "GPU"
|
||||
|
||||
def test_rc_graph_fp16_int64():
    """Graph mode, float16 logits, int64 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.GRAPH_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    out_type = ms.int64
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64)

    net = RCnet(10, 5, out_type)
    result = net(logits).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_graph_fp32_int64():
    """Graph mode, float32 logits, int64 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.GRAPH_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float32)
    out_type = ms.int64
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64)

    net = RCnet(10, 5, out_type)
    result = net(logits).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_graph_fp64_int64():
    """Graph mode, float64 logits, int64 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.GRAPH_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float64)
    out_type = ms.int64
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64)

    net = RCnet(10, 5, out_type)
    result = net(logits).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_graph_fp16_int16():
    """Graph mode, float16 logits, int16 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.GRAPH_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    out_type = ms.int16
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int16)

    net = RCnet(10, 5, out_type)
    result = net(logits).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_graph_fp16_int32():
    """Graph mode, float16 logits, int32 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.GRAPH_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    out_type = ms.int32
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int32)

    net = RCnet(10, 5, out_type)
    result = net(logits).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_pynative_fp16_int64():
    """PyNative mode, float16 logits, int64 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    out_type = ms.int64
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64)

    sampler = P.RandomCategorical(out_type)
    result = sampler(logits, 10, 5).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_pynative_fp32_int64():
    """PyNative mode, float32 logits, int64 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float32)
    out_type = ms.int64
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64)

    sampler = P.RandomCategorical(out_type)
    result = sampler(logits, 10, 5).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_pynative_fp64_int64():
    """PyNative mode, float64 logits, int64 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float64)
    out_type = ms.int64
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64)

    sampler = P.RandomCategorical(out_type)
    result = sampler(logits, 10, 5).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_pynative_fp16_int16():
    """PyNative mode, float16 logits, int16 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    out_type = ms.int16
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int16)

    sampler = P.RandomCategorical(out_type)
    result = sampler(logits, 10, 5).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
||||
|
||||
def test_rc_pynative_fp16_int32():
    """PyNative mode, float16 logits, int32 output: samples must equal the pinned values for seed 5."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET)

    logits = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16)
    out_type = ms.int32
    expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int32)

    sampler = P.RandomCategorical(out_type)
    result = sampler(logits, 10, 5).asnumpy()
    delta = result - expect
    assert expect.dtype == result.dtype
    assert not np.any(delta)
|
Loading…
Reference in new issue