add fix to GPU-RandomChoiceWithMask (bitonicsort & testcase)

pull/4368/head
TFbunny 5 years ago
parent ad8a786b07
commit 17d01e838f

@ -134,7 +134,7 @@ template <typename T>
__global__ void Sort(const int ceil_power2, T *rank_buff) { __global__ void Sort(const int ceil_power2, T *rank_buff) {
for (size_t i = 2; i <= ceil_power2; i <<= 1) { for (size_t i = 2; i <= ceil_power2; i <<= 1) {
for (size_t j = (i >> 1); j > 0; j >>= 1) { for (size_t j = (i >> 1); j > 0; j >>= 1) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) { for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
size_t tid_comp = tid ^ j; size_t tid_comp = tid ^ j;
if (tid_comp > tid) { if (tid_comp > tid) {
if ((tid & i) == 0) { if ((tid & i) == 0) {
@ -165,7 +165,7 @@ __global__ void Shuffle(const int ceil_power2, curandState *globalState, T *rank
int value; int value;
for (size_t i = 2; i <= ceil_power2; i <<= 1) { for (size_t i = 2; i <= ceil_power2; i <<= 1) {
for (size_t j = (i >> 1); j > 0; j >>= 1) { for (size_t j = (i >> 1); j > 0; j >>= 1) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) { for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
size_t tid_comp = tid ^ j; size_t tid_comp = tid ^ j;
if (tid_comp > tid) { if (tid_comp > tid) {
value = static_cast<int>(curand(&globalState[tid])); value = static_cast<int>(curand(&globalState[tid]));
@ -249,10 +249,10 @@ void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size,
Reshape2Index<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input_shape_size, d1, d2, d3, d4, d5, Reshape2Index<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input_shape_size, d1, d2, d3, d4, d5,
input, index_buff); input, index_buff);
Sort<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, rank_buff); Sort<<<1, GET_THREADS, 0, stream>>>(ceil_power2, rank_buff);
SrandInit<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, globalState, seedc); SrandInit<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, globalState, seedc);
Shuffle<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, globalState, rank_buff); Shuffle<<<1, GET_THREADS, 0, stream>>>(ceil_power2, globalState, rank_buff);
MoveToOutput<<<GET_BLOCKS(count), GET_THREADS, 0, stream>>>(input_shape_size, count, input, output_index, output_mask, MoveToOutput<<<GET_BLOCKS(count), GET_THREADS, 0, stream>>>(input_shape_size, count, input, output_index, output_mask,
index_buff, rank_buff, Tnum_buff); index_buff, rank_buff, Tnum_buff);

@ -49,38 +49,37 @@ class RCWM_3D(nn.Cell):
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_RCWM_3D(): def test_RCWM_3D():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_tensor = Tensor(np.ones([3, 4, 5]).astype(np.bool)) input_tensor = Tensor(np.ones([3, 4, 5]).astype(np.bool))
expect1 = [[0, 1, 1], [0, 2, 1], [0, 2, 2], [1, 0, 1], [0, 1, 3], [0, 3, 0], [1, 3, 2], \ expect1 = (10, 3)
[0, 0, 0], [1, 1, 2], [1, 3, 4]] expect2 = (10,)
expect2 = [True, True, True, True, True, True, True, True, True, True]
rcwm = RCWM_3D() rcwm = RCWM_3D()
output1, output2 = rcwm(input_tensor) output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) assert output1.shape == expect1
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) assert output2.shape == expect2
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_RCWM_count_out(): def test_RCWM_count_out():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
expect1 = [[0, 2], [2, 2], [2, 1], [2, 0], [0, 0], [3, 3], [2, 3], [1, 3], [0, 0], [0, 0]] expect1 = (10, 2)
expect2 = [True, True, True, True, True, True, True, True, False, False] expect2 = (10,)
rcwm = RCWM_count_out() rcwm = RCWM_count_out()
output1, output2 = rcwm(input_tensor) output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) assert output1.shape == expect1
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) assert output2.shape == expect2
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_RCWM_count_in(): def test_RCWM_count_in():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
expect1 = [[0, 2], [2, 2], [2, 1], [2, 0]] expect1 = (4, 2)
expect2 = [True, True, True, True] expect2 = (4,)
rcwm = RCWM_count_in() rcwm = RCWM_count_in()
output1, output2 = rcwm(input_tensor) output1, output2 = rcwm(input_tensor)
assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1) assert output1.shape == expect1
assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2) assert output2.shape == expect2

Loading…
Cancel
Save