Paddle/paddle/fluid/operators/dropout_op.h

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
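
// CPU kernels for the dropout operator: CPUDropoutKernel implements the
// forward pass (random element-wise masking, with optional train-time
// upscaling) and DropoutGradKernel implements the backward pass using the
// mask saved during the forward run.
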
#include <cstring>
#include <random>
#include <string>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
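
// Forward dropout on the CPU. During training the element-wise keep/drop
// decisions are written to the uint8 "Mask" output so the gradient kernel
// can replay them.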
template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    const auto* x_data = x->data<T>();
    auto* y_data = y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");
    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
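    // The dropout_implementation attribute selects the scaling scheme:
    // "upscale_in_train" scales kept elements by 1 / (1 - dropout_prob) at
    // train time and passes inputs through unchanged at inference, while the
    // default behaviour keeps train-time values as-is and scales the whole
    // input by (1 - dropout_prob) at inference.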
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
    if (!context.Attr<bool>("is_test")) {
      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
      size_t size = framework::product(mask->dims());

      // Special case when dropout_prob is 1.0
      if (dropout_prob == 1.0f) {
        std::memset(y_data, 0, size * sizeof(*y_data));        // NOLINT
        std::memset(mask_data, 0, size * sizeof(*mask_data));  // NOLINT
        return;
      }
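
      // If the global framework::Generator was initialized from the Python
      // side, its CPU engine drives the random draws below; otherwise a local
      // std::minstd_rand is seeded from the optional "Seed" input, from the
      // "seed" attribute when "fix_seed" is set, or from std::random_device.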
      bool init_generator_py = framework::Generator::GetInstance()->is_init_py;
      // NOTE: fixed seed should only be used in unittest or for debug.
      // Guarantee to use random seed in training.
      std::random_device rnd;
      std::minstd_rand engine;
      int seed_data;
      if (seed) {
        seed_data = *(seed->data<int>());
      } else {
        seed_data =
            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
      }
      engine.seed(seed_data);
      std::uniform_real_distribution<float> dist(0, 1);
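
      // Drop each element independently with probability dropout_prob and
      // record the decision in the mask. Example with dropout_prob = 0.5 in
      // upscale_in_train mode: a kept element x becomes x / (1 - 0.5) = 2 * x,
      // so the expected output equals x and no rescaling is needed at
      // inference.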
      for (size_t i = 0; i < size; ++i) {
        float cur_random =
            init_generator_py
                ? dist(framework::Generator::GetInstance()->GetCPUEngine())
                : dist(engine);
        if (cur_random < dropout_prob) {
          mask_data[i] = 0;
          y_data[i] = 0;
        } else {
          mask_data[i] = 1;
          if (upscale_in_train) {
            y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
          } else {
            y_data[i] = x_data[i];
          }
        }
      }
    } else {
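      // Inference path (is_test == true): upscale_in_train simply copies the
      // input, while the default mode scales the whole input by
      // (1 - dropout_prob).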
      if (upscale_in_train) {
        const auto* X_data = x->data<T>();
        auto* Y_data = y->mutable_data<T>(context.GetPlace());
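        // Element-wise copy of X into Out, parallelized with OpenMP when
        // Paddle is built with MKLML.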
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
        for (int i = 0; i < x->numel(); i++) {
          Y_data[i] = X_data[i];
        }
      } else {
        auto X = EigenMatrix<T>::Reshape(*x, 1);
        auto Y = EigenMatrix<T>::Reshape(*y, 1);
        auto& place =
            *context.template device_context<DeviceContext>().eigen_device();
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};
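
// Backward dropout: dX = dY * Mask, additionally divided by
// (1 - dropout_prob) when dropout_implementation is "upscale_in_train",
// mirroring the scaling applied in the forward pass.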
template <typename DeviceContext, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
                      platform::errors::PreconditionNotMet(
                          "GradOp is only callable when is_test is false"));
    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* mask = context.Input<Tensor>("Mask");
    grad_x->mutable_data<T>(context.GetPlace());
    auto M = EigenMatrix<uint8_t>::Reshape(*mask, 1);
    auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
    auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    if (dropout_implementation == "upscale_in_train") {
      float dropout_prob = context.Attr<float>("dropout_prob");
      if (dropout_prob == 1.0f) {
        dX.device(place) = static_cast<T>(0) * dY;
      } else {
        dX.device(place) =
            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
      }
    } else {
      dX.device(place) = dY * M.cast<T>();
    }
  }
};
} // namespace operators
} // namespace paddle