|
|
|
@ -46,9 +46,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
|
|
|
|
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
|
|
|
|
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
|
|
|
|
|
|
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
|
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
|
|
|
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
|
|
|
|
(void)workspace;
|
|
|
|
|
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
|
|
|
|
|
T *d_batch_mean = GetDeviceAddress<T>(inputs, 0);
|
|
|
|
|
T *d_batch_std = GetDeviceAddress<T>(inputs, 1);
|
|
|
|
@ -139,11 +138,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
|
|
|
|
|
input_size_list_.push_back(channel_size_);
|
|
|
|
|
input_size_list_.push_back(channel_size_);
|
|
|
|
|
input_size_list_.push_back(sizeof(int));
|
|
|
|
|
|
|
|
|
|
// 'dx'
|
|
|
|
|
output_size_list_.push_back(input_size_);
|
|
|
|
|
|
|
|
|
|
workspace_size_list_.push_back(workspace_size_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|