!2405 change some comment name in the whole project

Merge pull request !2405 from chenzhongming/master
pull/2405/MERGE
mindspore-ci-bot committed 5 years ago via Gitee
commit f1a69de0b6

@ -248,7 +248,7 @@ checkopts()
done
}
checkopts "$@"
echo "---------------- mindspore: build start ----------------"
echo "---------------- MindSpore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then

@ -36,7 +36,7 @@ class Monitor {
~Monitor() = default;
// Functor for Perf Monitor main loop.
// This function will be the entry point of Mindspore::Dataset::Task
// This function will be the entry point of mindspore::Dataset::Task
Status operator()();
int64_t GetSamplingInterval() { return sampling_interval_; }

@ -29,7 +29,7 @@
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
// mindspore namespace is the top level namespace of MindSpore project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {

@ -91,7 +91,7 @@ using mindspore::device::DeviceAddress;
using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
// mindspore namespace is the top level namespace of MindSpore project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
// brief mindspore::tensor namespace

@ -19,7 +19,7 @@
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "fake_quant_per_channel_impl.cuh"
#include "fake_quant_perchannel_impl.cuh"
#include "device/gpu/cuda_common.h"
/**
@ -113,44 +113,6 @@ void CalFakeQuantizePerChannel(const float *input, float *output, const int tota
input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
}
/**
* UpdateInputMinMaxPerChannel or UpdateInputMinMaxPerChannel With EMA.
* @param input_min
* @param input_max
* @param min
* @param max
* @return
*/
__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels,
int per_channel_nums, bool ema, float ema_decay) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
thrust::pair<float *, float *> sum =
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
if (ema) {
input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
} else {
input_min[i] = sum.first[0];
input_max[i] = sum.second[0];
}
input_min[i] = input_min[i] > 0 ? 0 : input_min[i];
input_max[i] = input_max[i] < 0 ? 0 : input_max[i];
}
}
__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max,
const float decay) {
*input_min = decay * (min) + (1 - decay) * (*input_min);
*input_max = decay * (max) + (1 - decay) * (*input_max);
}
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size,
const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
int per_channel_num = total_size / channel_size;
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
}
__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
const int total_size, const int channel_size, const float *nudge_min,
const float *nudge_max) {

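The removed kernel above (and its replacement in minmax_update_impl.cu further down) walks the channel dimension with the usual CUDA grid-stride loop, launched through the GET_BLOCKS/GET_THREADS helpers. A minimal CUDA sketch of that pattern, kept independent of the MindSpore helpers (the 256-thread block size and function names below are illustrative assumptions, not the project's macros):

#include <cuda_runtime.h>

__global__ void ScaleInPlace(float *data, int n, float factor) {
  // Grid-stride loop: each thread covers i, i + stride, i + 2 * stride, ...,
  // so any grid size can process any problem size.
  int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
    data[i] *= factor;
  }
}

void LaunchScaleInPlace(float *d_data, int n, float factor, cudaStream_t stream) {
  const int threads = 256;                       // illustrative block size, not GET_THREADS
  const int blocks = (n + threads - 1) / threads;
  ScaleInPlace<<<blocks, threads, 0, stream>>>(d_data, n, factor);
}
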
@ -18,7 +18,7 @@
#include <thrust/device_vector.h>
#include <thrust/pair.h>
#include "device/gpu/cuda_common.h"
#include "fake_quant_impl.cuh"
#include "fake_quant_perlayer_impl.cuh"
__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
const float *nudge_max, const float *scale) {

@ -0,0 +1,104 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "minmax_update_impl.cuh"
#include "device/gpu/cuda_common.h"
__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
float *output_max, const float min, const float max, const float decay,
const float symmetric) {
output_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
output_min[0] = input_min[0] > 0 ? 0 : input_min[0];
output_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
output_max[0] = input_max[0] < 0 ? 0 : input_max[0];
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max,
const float symmetric) {
output_min[0] = min > 0 ? 0 : min;
output_max[0] = max < 0 ? 0 : max;
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
float *output_max, int channels, int per_channel_nums, bool ema,
float ema_decay, bool symmetric) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
thrust::pair<float *, float *> sum =
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
if (ema) {
output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
} else {
output_min[i] = sum.first[0];
output_max[i] = sum.second[0];
}
output_min[i] = input_min[i] > 0 ? 0 : input_min[i];
output_max[i] = input_max[i] < 0 ? 0 : input_max[i];
if (symmetric) {
output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i];
output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i];
}
}
return;
}
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream) {
int per_channel_num = total_num / channel_num;
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric);
return;
}
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream) {
float minel = 0.f;
float maxel = 0.f;
auto policy = thrust::cuda::par.on(cuda_stream);
thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
tuple =
thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num);
minel = tuple.first[0];
maxel = tuple.second[0];
if (ema) {
UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
maxel, ema_decay, symmetric);
} else {
UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric);
}
return;
}

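UpdateInputMinMaxPerLayerWithEMA above folds the freshly observed min/max into the running range with an exponential moving average, forces the range to include zero, and optionally widens it to a symmetric range. A small host-side C++ sketch of that arithmetic, simplified to the intended result (function and variable names here are illustrative, not part of the project), handy for checking the kernel by hand:

#include <cmath>
#include <cstdio>

// EMA update of a per-layer min/max pair, mirroring the kernel's arithmetic.
void UpdateMinMaxEmaHost(float observed_min, float observed_max, float decay,
                         bool symmetric, float *running_min, float *running_max) {
  float new_min = decay * observed_min + (1.0f - decay) * (*running_min);
  float new_max = decay * observed_max + (1.0f - decay) * (*running_max);
  new_min = new_min > 0.0f ? 0.0f : new_min;    // range must always contain zero
  new_max = new_max < 0.0f ? 0.0f : new_max;
  if (symmetric) {                              // widen to a zero-centered range
    float bound = std::fmax(std::fabs(new_min), new_max);
    new_min = -bound;
    new_max = bound;
  }
  *running_min = new_min;
  *running_max = new_max;
}

int main() {
  float mn = -0.8f, mx = 1.2f;
  UpdateMinMaxEmaHost(-1.0f, 2.0f, 0.1f, false, &mn, &mx);
  std::printf("min=%f max=%f\n", mn, mx);  // min=-0.82 max=1.28
  return 0;
}
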
@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#include "device/gpu/cuda_common.h"
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream);
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int size, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
@ -25,21 +25,15 @@ namespace mindspore {
namespace kernel {
FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_channels_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
global_step_(0),
training_(false),
channel_out_(0),
symmetric_(false),
narrow_range_(false),
symmetric_(false) {}
quant_delay_(0),
quant_min_(0),
quant_max_(0),
global_step_(0) {}
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
@ -60,90 +54,56 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
return false;
}
// get attribute
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
return false;
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less than 0; it should be at least 0.";
return false;
}
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// shape info for gpu
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input in tensor
input_size_list_.push_back(min_size_); // min one scalar
input_size_list_.push_back(max_size_); // max on scalar
output_size_list_.push_back(output_size_); // output in tensor
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, void *stream_ptr) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(
cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memory failed.");
}
global_step_++;
input_size_list_.push_back(input_size_); // input in tensor
input_size_list_.push_back(sizeof(float) * num_channels_); // min one scalar
input_size_list_.push_back(sizeof(float) * num_channels_); // max on scalar
output_size_list_.push_back(input_size_); // output in tensor
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, void *stream_ptr) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale,
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
}
@ -155,9 +115,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
float *d_scale = GetDeviceAddress<float>(workspace, 0);
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
@ -167,9 +127,16 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
}
if (training_) {
CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
if (global_step_ >= quant_delay_) {
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memory failed.");
}
global_step_++;
} else {
CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
}
return true;

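With this change the Init path derives a single unsigned integer grid from num_bits_ and narrow_range_ instead of branching on symmetric_: the range is [0, 2^num_bits - 1], shifted to start at 1 when narrow_range is set. A short C++ sketch with a concrete value (num_bits = 8 is assumed here purely for illustration):

#include <cstdio>

int main() {
  int num_bits = 8;             // illustrative; the kernel reads this from the node attrs
  bool narrow_range = true;
  float quant_min = 0;
  float quant_max = (1 << num_bits) - 1;  // 255
  if (narrow_range) {
    quant_min++;                          // 1, reserving one code as in the diff
  }
  // The old symmetric branch would instead have produced [-128, 127]:
  //   quant_min = -(1 << (num_bits - 1)); quant_max = (1 << (num_bits - 1)) - 1;
  std::printf("quant range: [%g, %g]\n", quant_min, quant_max);
  return 0;
}
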
@ -39,31 +39,23 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
void InitSizeLists() override;
private:
void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, void *stream_ptr);
void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, void *stream_ptr);
void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min,
float *nudge_max, float *scale, void *stream_ptr);
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_channels_;
int num_bits_;
bool training_;
bool symmetric_;
bool narrow_range_;
int quant_delay_;
float quant_min_;
float quant_max_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
bool training_;
int channel_out_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore

@ -14,21 +14,17 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
channel_out_(0),
num_channels_(0),
quant_delay_(0),
global_step_(0),
narrow_range_(false),
@ -64,42 +60,34 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float) * num_channels_); // min
input_size_list_.push_back(sizeof(float) * num_channels_); // max
output_size_list_.push_back(input_size_); // output
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
}
bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
@ -111,9 +99,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
float *d_scale = GetDeviceAddress<float>(workspace, 0);
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";
@ -130,9 +118,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
int total_size = input_size_ / sizeof(float);
if (global_step_ >= quant_delay_) {
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max,
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,

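InitSizeLists above derives every buffer size from the input shape: gradient, input and output take input_size_ bytes, while the per-channel min/max inputs and the three workspace buffers take sizeof(float) * num_channels_ each. A tiny host-side C++ sketch of that bookkeeping for a made-up shape (the shape and names are assumptions for illustration only):

#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> input_shape = {16, 3, 3};   // hypothetical [C, H, W]
  size_t num_channels = input_shape[0];
  size_t input_size = sizeof(float);
  for (size_t d : input_shape) input_size *= d;   // 16 * 3 * 3 * 4 = 576 bytes
  size_t per_channel_bytes = sizeof(float) * num_channels;  // 64 bytes
  std::printf("gradient/input/output: %zu bytes each\n", input_size);
  std::printf("min/max/scale/nudge buffers: %zu bytes each\n", per_channel_bytes);
  return 0;
}
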
@ -40,10 +40,6 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
@ -51,7 +47,7 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
int num_bits_;
float quant_min_;
float quant_max_;
int channel_out_;
int num_channels_;
int quant_delay_;
int global_step_;
bool narrow_range_;

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
@ -23,31 +23,25 @@
namespace mindspore {
namespace kernel {
FakeQuantGpuKernel::FakeQuantGpuKernel()
FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
quant_num_(1),
global_step_(0),
num_bits_(0),
quant_delay_(0),
training_(false),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
@ -59,95 +53,73 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less than 0; it should be at least 0.";
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
if (quant_num_ == 0) {
quant_num_ = 1;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
void FakeQuantPerLayerGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // x
input_size_list_.push_back(sizeof(float)); // min
input_size_list_.push_back(sizeof(float)); // max
output_size_list_.push_back(input_size_); // y
workspace_size_list_.push_back(sizeof(float)); // scale
workspace_size_list_.push_back(sizeof(float)); // nudge_min
workspace_size_list_.push_back(sizeof(float)); // nudge_max
}
bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output = GetDeviceAddress<float>(outputs, 0);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null.";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null.";
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null.";
}
// Allocate space for device copies
int size = sizeof(float);
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
if (training_) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
@ -157,20 +129,15 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
global_step_++;
} else {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
return true;
}
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel)
} // namespace kernel
} // namespace mindspore

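The per-layer kernel now takes scale, nudge_min and nudge_max from preallocated workspace buffers rather than calling cudaMalloc/cudaFree on every launch, and then runs CalNudge followed by CalFakeQuantize. Neither helper's body appears in this diff; the host-side C++ sketch below follows the standard fake-quant nudging scheme (fit the float range onto the integer grid with a zero point that keeps 0.0 exactly representable, then round within the nudged range) and is an assumption about their behaviour, not the project's exact code:

#include <cmath>
#include <cstdio>

// Nudge [min, max] so that the integer grid [quant_min, quant_max] maps onto it
// with a zero point that lands exactly on an integer code.
void NudgeMinMax(float min, float max, float quant_min, float quant_max,
                 float *nudge_min, float *nudge_max, float *scale) {
  *scale = (max - min) / (quant_max - quant_min);
  float zp_from_min = quant_min - min / *scale;
  float zp = zp_from_min < quant_min ? quant_min
           : zp_from_min > quant_max ? quant_max
           : std::round(zp_from_min);
  *nudge_min = (quant_min - zp) * (*scale);
  *nudge_max = (quant_max - zp) * (*scale);
}

// Clamp to the nudged range, then round to the nearest representable value.
float FakeQuantizeValue(float x, float nudge_min, float nudge_max, float scale) {
  float clamped = std::fmin(std::fmax(x, nudge_min), nudge_max);
  return std::round((clamped - nudge_min) / scale) * scale + nudge_min;
}

int main() {
  float nmin = 0.0f, nmax = 0.0f, scale = 0.0f;
  NudgeMinMax(-1.0f, 1.0f, 0.0f, 255.0f, &nmin, &nmax, &scale);
  std::printf("x=0.37 -> %f\n", FakeQuantizeValue(0.37f, nmin, nmax, scale));
  return 0;
}
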
@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
@ -23,10 +23,10 @@
namespace mindspore {
namespace kernel {
class FakeQuantGpuKernel : public GpuKernel {
class FakeQuantPerLayerGpuKernel : public GpuKernel {
public:
FakeQuantGpuKernel();
~FakeQuantGpuKernel() = default;
FakeQuantPerLayerGpuKernel();
~FakeQuantPerLayerGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
@ -40,22 +40,16 @@ class FakeQuantGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
int num_bits_;
int quant_delay_;
bool training_;
bool narrow_range_;
bool symmetric_;
@ -63,4 +57,4 @@ class FakeQuantGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_

@ -14,33 +14,30 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_size_(0),
quant_num_(1),
quant_delay_(0),
global_step_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
@ -62,87 +59,66 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
if (quant_size_ == 0) {
quant_size_ = 1;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_size_ *= SizeToInt(input_shape[i]);
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
void FakeQuantPerLayerGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float)); // min
input_size_list_.push_back(sizeof(float)); // max
output_size_list_.push_back(input_size_); // output
workspace_size_list_.push_back(sizeof(float)); // scale
workspace_size_list_.push_back(sizeof(float)); // nudge_min
workspace_size_list_.push_back(sizeof(float)); // nudge_max
}
bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output = GetDeviceAddress<float>(outputs, 0);
float *gradient = GetDeviceAddress<float>(inputs, 0);
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null";
}
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null.";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null.";
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null.";
}
if (global_step_ >= quant_delay_) {
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
int size = sizeof(float);
// Allocate space for device copies
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max,
CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
@ -152,6 +128,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return true;
}
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel)
} // namespace kernel
} // namespace mindspore

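Both the per-layer and per-channel training/grad kernels gate the fake-quant path on global_step_ >= quant_delay_, copying the input (or gradient) straight through until that step is reached. A minimal C++ sketch of that control flow, with an assumed quant_delay of 3 purely for illustration:

#include <cstdio>

int main() {
  int quant_delay = 3;   // illustrative value; the kernel reads it from the node attrs
  for (int global_step = 0; global_step < 6; ++global_step) {
    bool fake_quant_active = global_step >= quant_delay;
    std::printf("step %d: %s\n", global_step,
                fake_quant_active ? "fake-quant path" : "pass-through copy");
  }
  return 0;
}
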
@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
@ -23,10 +23,10 @@
namespace mindspore {
namespace kernel {
class FakeQuantGradGpuKernel : public GpuKernel {
class FakeQuantPerLayerGradGpuKernel : public GpuKernel {
public:
FakeQuantGradGpuKernel();
~FakeQuantGradGpuKernel() = default;
FakeQuantPerLayerGradGpuKernel();
~FakeQuantPerLayerGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
@ -40,9 +40,6 @@ class FakeQuantGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
@ -51,7 +48,7 @@ class FakeQuantGradGpuKernel : public GpuKernel {
int num_bits_;
float quant_min_;
float quant_max_;
int quant_size_;
int quant_num_;
int quant_delay_;
int global_step_;
bool narrow_range_;
@ -60,4 +57,4 @@ class FakeQuantGradGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_

@ -0,0 +1,119 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
num_channels_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const {
return workspace_size_list_;
}
bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
num_channels_ = SizeToInt(input_shape[0]);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
InitSizeLists();
return true;
}
void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float) * num_channels_); // min
input_size_list_.push_back(sizeof(float) * num_channels_); // max
output_size_list_.push_back(sizeof(float) * num_channels_); // output min
output_size_list_.push_back(sizeof(float) * num_channels_); // output max
}
bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output_min = GetDeviceAddress<float>(outputs, 0);
float *output_max = GetDeviceAddress<float>(outputs, 1);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null.";
}
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null.";
}
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel)
} // namespace kernel
} // namespace mindspore

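MinMaxUpdatePerChannel treats the first input dimension as the channel axis, so each channel owns total_num / channel_num contiguous elements and receives its own min/max, clamped so the range always contains zero. A small host-side C++ sketch of that reduction without the EMA path (the data here is made up for illustration):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // 2 channels, 4 values per channel, laid out channel-major as the kernel assumes.
  std::vector<float> input = {0.1f, -0.5f, 2.0f, 0.3f,    // channel 0
                              -1.2f, 0.7f, 0.0f, 3.1f};   // channel 1
  int channels = 2;
  int per_channel = static_cast<int>(input.size()) / channels;
  for (int c = 0; c < channels; ++c) {
    auto begin = input.begin() + c * per_channel;
    auto mm = std::minmax_element(begin, begin + per_channel);
    float mn = std::min(*mm.first, 0.0f);   // range is forced to include zero
    float mx = std::max(*mm.second, 0.0f);
    std::printf("channel %d: min=%g max=%g\n", c, mn, mx);
  }
  return 0;
}
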
@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
public:
MinMaxUpdatePerChannelGpuKernel();
~MinMaxUpdatePerChannelGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
int num_channels_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_

@ -0,0 +1,115 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
InitSizeLists();
return true;
}
void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float)); // input min
input_size_list_.push_back(sizeof(float)); // input max
output_size_list_.push_back(sizeof(float)); // output min
output_size_list_.push_back(sizeof(float)); // output max
}
bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output_min = GetDeviceAddress<float>(outputs, 0);
float *output_max = GetDeviceAddress<float>(outputs, 1);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null.";
}
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
}
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
public:
MinMaxUpdatePerLayerGpuKernel();
~MinMaxUpdatePerLayerGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_

@ -28,7 +28,7 @@ message LineageEvent {
oneof what {
// An event file was started, with the specified version.
// Now version is "Mindspore.Event:1"
// Now version is "MindSpore.Event:1"
string version = 3;
// Train lineage

@ -32,7 +32,7 @@ message Event {
oneof what {
// An event file was started, with the specified version.
// Now version is "Mindspore.Event:1"
// Now version is "MindSpore.Event:1"
string version = 3;
// GraphDef.

@ -32,7 +32,7 @@
#include "vm/segment_runner.h"
#include "vm/backend.h"
// mindspore namespace is the top level namespace of Mindsporeession project.
// mindspore namespace is the top level namespace of MindSpore project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
extern const char kMsVm[];
