Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_googlenet_bug_with_rule

test=develop
revert-14324-fix_vlog
nhzlx 7 years ago
commit 9d98ca0424

@ -18,7 +18,7 @@ function(copy TARGET)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DSTS DEPS) set(multiValueArgs SRCS DSTS DEPS)
cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(inference_lib_dist_dep ${TARGET} ${inference_lib_dist_dep} PARENT_SCOPE) set(fluid_lib_dist_dep ${TARGET} ${fluid_lib_dist_dep} PARENT_SCOPE)
list(LENGTH copy_lib_SRCS copy_lib_SRCS_len) list(LENGTH copy_lib_SRCS copy_lib_SRCS_len)
list(LENGTH copy_lib_DSTS copy_lib_DSTS_len) list(LENGTH copy_lib_DSTS copy_lib_DSTS_len)
@ -185,7 +185,8 @@ copy(cmake_cache
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
DSTS ${FLUID_INSTALL_DIR}) DSTS ${FLUID_INSTALL_DIR})
add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep}) # This command generates a complete fluid library for both train and inference
add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep})
# paddle fluid version # paddle fluid version
execute_process( execute_process(

@ -75,7 +75,8 @@ paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'outp
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.sequence_expand_as ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_expand_as ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.sequence_unpad ArgSpec(args=['x', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None)) paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None)) paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
@ -127,6 +128,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.margin_rank_loss ArgSpec(args=['label', 'left', 'right', 'margin', 'name'], varargs=None, keywords=None, defaults=(0.1, None))
paddle.fluid.layers.elu ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None)) paddle.fluid.layers.elu ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None))
paddle.fluid.layers.relu6 ArgSpec(args=['x', 'threshold', 'name'], varargs=None, keywords=None, defaults=(6.0, None)) paddle.fluid.layers.relu6 ArgSpec(args=['x', 'threshold', 'name'], varargs=None, keywords=None, defaults=(6.0, None))
paddle.fluid.layers.pow ArgSpec(args=['x', 'factor', 'name'], varargs=None, keywords=None, defaults=(1.0, None)) paddle.fluid.layers.pow ArgSpec(args=['x', 'factor', 'name'], varargs=None, keywords=None, defaults=(1.0, None))

@ -12,6 +12,5 @@ endif(NOT WIN32)
if(WITH_INFERENCE) if(WITH_INFERENCE)
# NOTE: please add subdirectory inference at last. # NOTE: please add subdirectory inference at last.
add_subdirectory(inference) add_subdirectory(inference)
add_subdirectory(train)
endif() endif()
add_subdirectory(train)

@ -64,7 +64,8 @@ class OpHandleBase {
virtual bool IsMultiDeviceTransfer() { return false; } virtual bool IsMultiDeviceTransfer() { return false; }
const platform::DeviceContext *DeviceContext(platform::Place place) { const platform::DeviceContext *DeviceContext(platform::Place place) {
return dev_ctxes_[place]; auto it = dev_ctxes_.find(place);
return it != dev_ctxes_.end() ? it->second : nullptr;
} }
void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) { void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {

@ -46,6 +46,41 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(5) << "destroy ExecutorPrepareContext";
} }
template <typename RefCntMap>
static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
GarbageCollector<Tensor>* gc,
RefCntMap* ref_cnts) {
std::unordered_set<Tensor*> erase_tensors;
auto handler = [&](const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto it = ref_cnts->find(name);
if (it == ref_cnts->end()) continue;
if ((it->second)-- == 1) {
auto* var = scope.FindVar(name);
if (var != nullptr) {
VLOG(10) << "Erase tensor \'" << name << "\'";
if (var->IsType<LoDTensor>()) {
erase_tensors.insert(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
erase_tensors.insert(
var->GetMutable<SelectedRows>()->mutable_value());
}
}
}
}
}
};
handler(op->Inputs());
handler(op->Outputs());
if (!erase_tensors.empty()) {
gc->Add(erase_tensors);
}
}
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() { void Executor::Close() {
@ -331,9 +366,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc; std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) { // WhileOp would set keep_kids to false
// WhileGradOp would need the scopes created in WhileOp
// Perhaps, we should not perform eager deletion in WhileOp
// The scopes and variables created by WhileOp would be deleted
// in WhileGradOp.
if (max_memory_size >= 0 && !keep_kids) {
ctx->ResetReferenceCount(); ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
@ -352,45 +391,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc != nullptr) { if (gc != nullptr) {
std::vector<std::string> erase_vars; DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
for (auto& input : op->Inputs()) { &(ctx->cur_ref_cnts_));
for (auto& input_name : input.second) {
auto it = ctx->cur_ref_cnts_.find(input_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name);
ctx->cur_ref_cnts_.erase(input_name);
} else {
--(it->second);
}
}
}
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
auto it = ctx->cur_ref_cnts_.find(output_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) {
erase_vars.emplace_back(output_name);
ctx->cur_ref_cnts_.erase(output_name);
} else {
--(it->second);
}
}
}
if (!erase_vars.empty()) {
std::vector<framework::LoDTensor*> erase_tensors;
for (auto& name : erase_vars) {
auto* var = local_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
erase_tensors.push_back(tensor);
}
}
if (!erase_tensors.empty()) gc->Add(erase_tensors);
}
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {

@ -32,38 +32,32 @@ template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount( std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) { const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id); auto& block = prog.Block(block_id);
std::unordered_set<std::string> ignored_vars;
std::unordered_map<std::string, T> ref_cnts; std::unordered_map<std::string, T> ref_cnts;
for (auto var_desc : block.AllVars()) { auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
auto type = var_desc->Proto()->type().type(); for (auto& name_pair : name_map) {
if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) { for (auto& name : name_pair.second) {
ignored_vars.insert(var_desc->Name()); // ignore persistable vars auto* var_desc = block.FindVar(name);
} if (var_desc == nullptr || var_desc->Persistable()) continue;
} auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
for (auto op_desc : block.AllOps()) { type != proto::VarType::SELECTED_ROWS) {
for (auto& input : op_desc->Inputs()) { continue;
for (auto& input_name : input.second) {
if (!ignored_vars.count(input_name)) {
if (ref_cnts.count(input_name))
++ref_cnts[input_name];
else
ref_cnts[input_name] = 1;
} }
}
}
for (auto& output : op_desc->Outputs()) { auto it = ref_cnts.find(name);
for (auto output_name : output.second) { if (it != ref_cnts.end()) {
if (!ignored_vars.count(output_name)) { ++it->second;
if (ref_cnts.count(output_name)) } else {
++ref_cnts[output_name]; ref_cnts[name] = 1;
else
ref_cnts[output_name] = 1;
} }
} }
} }
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
} }
return ref_cnts; return ref_cnts;
} }

@ -44,89 +44,6 @@ namespace ir {
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \ GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name) GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
template <typename UnaryOperation>
LoDTensor tensor_apply(const LoDTensor& vec, UnaryOperation f) {
LoDTensor vec_y;
vec_y.Resize(vec.dims());
const float* x = vec.data<float>();
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
for (int64_t i = 0; i < vec.numel(); i++) {
y[i] = f(x[i]);
}
return vec_y;
}
void tensor_apply_inplace(LoDTensor* vec, float (*f)(float)) {
float* data = vec->mutable_data<float>(platform::CPUPlace());
for (int64_t i = 0; i < vec->numel(); i++) {
data[i] = f(data[i]);
}
}
template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
LoDTensor vec_y;
vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>();
const float* b = vec_b.data<float>();
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
for (int64_t i = 0; i < vec_a.numel(); i++) {
y[i] = f(a[i], b[i]);
}
return vec_y;
}
template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise_broadcast(const LoDTensor& vec_a,
const LoDTensor& vec_b,
BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims().size(), 2);
PADDLE_ENFORCE_EQ(vec_b.dims().size(), 2);
PADDLE_ENFORCE_EQ(vec_a.dims()[0], vec_b.dims()[0]);
PADDLE_ENFORCE_EQ(vec_b.dims()[1], 1);
LoDTensor vec_y;
vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>();
const float* b = vec_b.data<float>();
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
size_t a_height = vec_a.dims()[0];
size_t a_width = vec_a.dims()[1];
for (size_t h = 0; h < a_height; h++) {
for (size_t w = 0; w < a_width; ++w) {
*(y++) = f(*(a++), b[h]);
}
}
return vec_y;
}
// reshape to two dimensions {A, B * C * ...}
void make_tensor_2d(LoDTensor* tensor_to_reshape) {
auto dims_count = tensor_to_reshape->dims().size();
PADDLE_ENFORCE_GT(dims_count, 0);
int size2 = 1;
for (int i = 1; i < dims_count; i++) {
size2 *= tensor_to_reshape->dims()[i];
}
tensor_to_reshape->Resize(make_ddim({tensor_to_reshape->dims()[0], size2}));
}
void recompute_conv_weights(LoDTensor* weights, LoDTensor* tmp) {
// remember the weights tensor shape {A, B, C, ...}
auto weights_shape = weights->dims();
// reduce the weights to 2d {A, B * C * ...}
make_tensor_2d(weights);
// make tmp tensor 2d by adding 1 as second dim {A, 1}
make_tensor_2d(tmp);
*weights =
tensor_apply_eltwise_broadcast(*weights, *tmp, std::multiplies<float>());
// reshape weights to the original dims {A, B, C, ...}
weights->Resize(weights_shape);
}
void recompute_bias_and_weights(const Scope* scope, void recompute_bias_and_weights(const Scope* scope,
ir::Node* conv_weight, // ir::Node* conv_weight, //
const ir::Node& bn_scale, // const ir::Node& bn_scale, //
@ -135,6 +52,13 @@ void recompute_bias_and_weights(const Scope* scope,
const ir::Node& bn_variance, // const ir::Node& bn_variance, //
LoDTensor* eltwise_y_in_tensor, // LoDTensor* eltwise_y_in_tensor, //
float epsilon) { float epsilon) {
using EigenVectorArrayMap =
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using EigenMatrixArrayMap = Eigen::Map<
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from BN // Re-compute bias of conv2d from BN
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims()); PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims());
@ -143,31 +67,38 @@ void recompute_bias_and_weights(const Scope* scope,
scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>(); scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>(); auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();
auto std_tensor = LoDTensor(); ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
std_tensor.Resize(bn_bias_tensor.dims()); scale_tensor->numel(), 1);
std_tensor = EigenVectorArrayMap variance_array(
tensor_apply(*variance_tensor, [&](float x) { return x + epsilon; }); variance_tensor->mutable_data<float>(platform::CPUPlace()),
variance_tensor->numel(), 1);
ConstEigenVectorArrayMap mean_array(mean_tensor->data<float>(),
mean_tensor->numel(), 1);
ConstEigenVectorArrayMap bn_bias_array(bn_bias_tensor.data<float>(),
bn_bias_tensor.numel(), 1);
using EigenVectorArrayMap = // variance will not be used anymore, so make it std_array and then tmp_array
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>; variance_array += epsilon;
variance_array = variance_array.sqrt();
variance_array = scale_array / variance_array;
EigenVectorArrayMap eltwise_y_in_array(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 1);
EigenVectorArrayMap std_vec( eltwise_y_in_array =
std_tensor.mutable_data<float>(platform::CPUPlace()), std_tensor.numel(), ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
1);
std_vec = std_vec.sqrt();
auto tmp_tensor =
tensor_apply_eltwise(*scale_tensor, std_tensor, std::divides<float>());
auto tensor_minus = tensor_apply_eltwise(*eltwise_y_in_tensor, *mean_tensor,
std::minus<float>());
auto tensor_mul =
tensor_apply_eltwise(tensor_minus, tmp_tensor, std::multiplies<float>());
*eltwise_y_in_tensor =
tensor_apply_eltwise(tensor_mul, bn_bias_tensor, std::plus<float>());
// Re-compute weight of conv2d from BN // Re-compute weight of conv2d from BN
auto* current_param = auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>(); auto weights_shape = weights->dims();
recompute_conv_weights(current_param, &tmp_tensor); auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= variance_array;
} }
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl( std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(

@ -307,6 +307,10 @@ ParallelExecutor::~ParallelExecutor() {
} }
} }
} }
// member_ must be destructed before gcs_ since the destructor of
// ReferenceCountOpHandle use raw pointers of gcs_ inside.
member_.reset();
} }
} // namespace framework } // namespace framework

@ -75,7 +75,7 @@ class ParallelExecutor {
private: private:
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const; void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
ParallelExecutorPrivate *member_; std::unique_ptr<ParallelExecutorPrivate> member_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then // ref_cnts_ is only initialized when ParallelExecutor constructs, and then

@ -49,18 +49,18 @@ int64_t GetEagerDeletionThreshold() {
Scope::~Scope() { DropKids(); } Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this)); kids_.push_back(new Scope(this));
return *kids_.back(); return *kids_.back();
} }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return VarInternal(name); return VarInternal(name);
} }
Variable* Scope::Var(std::string* name) { Variable* Scope::Var(std::string* name) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) { if (name != nullptr) {
*name = new_name; *name = new_name;
@ -69,29 +69,34 @@ Variable* Scope::Var(std::string* name) {
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return FindVarInternal(name); return FindVarInternal(name);
} }
Variable* Scope::FindLocalVar(const std::string& name) const {
std::lock_guard<std::mutex> lock(mutex_);
return FindVarLocally(name);
}
const Scope* Scope::FindScope(const Variable* var) const { const Scope* Scope::FindScope(const Variable* var) const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return FindScopeInternal(var); return FindScopeInternal(var);
} }
void Scope::DropKids() { void Scope::DropKids() {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
kids_.clear(); kids_.clear();
} }
bool Scope::HasKid(const Scope* scope) const { bool Scope::HasKid(const Scope* scope) const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
return it != this->kids_.end(); return it != this->kids_.end();
} }
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> known_vars; std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size()); known_vars.reserve(this->vars_.size());
for (auto& p : vars_) { for (auto& p : vars_) {
@ -101,7 +106,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
} }
void Scope::DeleteScope(Scope* scope) const { void Scope::DeleteScope(Scope* scope) const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it); this->kids_.erase(it);
@ -114,7 +119,7 @@ void Scope::DeleteScope(Scope* scope) const {
} }
void Scope::EraseVars(const std::vector<std::string>& var_names) { void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
std::set<std::string> var_set(var_names.begin(), var_names.end()); std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) { for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) { if (var_set.find(it->first) != var_set.end()) {
@ -127,12 +132,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const { const std::string& new_name) const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
} }
std::string Scope::Rename(const std::string& origin_name) const { std::string Scope::Rename(const std::string& origin_name) const {
std::unique_lock<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
return new_name; return new_name;

@ -63,6 +63,11 @@ class Scope {
/// Caller doesn't own the returned Variable. /// Caller doesn't own the returned Variable.
Variable* FindVar(const std::string& name) const; Variable* FindVar(const std::string& name) const;
/// Find a variable in the current scope.
/// Return nullptr if cannot find.
/// Caller doesn't own the returned Variable.
Variable* FindLocalVar(const std::string& name) const;
const Scope* parent() const { return parent_; } const Scope* parent() const { return parent_; }
/// Find the scope or an ancestor scope that contains the given variable. /// Find the scope or an ancestor scope that contains the given variable.

@ -36,6 +36,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto size = src.numel() * SizeOfType(src.type()); auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size); boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} }
@ -71,6 +76,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
if (platform::is_same_place(src_place, dst_place)) { if (platform::is_same_place(src_place, dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
stream); stream);
} else { } else {
@ -114,6 +124,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto dst_ptr = dst->mutable_data(dst_place, src.type()); auto dst_ptr = dst->mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type()); auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size); boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} }
@ -130,6 +145,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, nullptr); memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, nullptr);
} else if (platform::is_gpu_place(src_place) && } else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) { platform::is_gpu_place(dst_place)) {
if (src_ptr == dst_ptr && platform::is_same_place(src_place, dst_place)) {
VLOG(3) << "Skip copy the same data from " << src_place << " to "
<< dst_place;
return;
}
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place); auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place); auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr); memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr);

@ -41,6 +41,11 @@ TEST(TensorCopy, Tensor) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]); EXPECT_EQ(src_ptr[i], dst_ptr[i]);
} }
TensorCopy(dst_tensor, *cpu_place, &dst_tensor);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout()); EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
Tensor slice_tensor = src_tensor.Slice(1, 2); Tensor slice_tensor = src_tensor.Slice(1, 2);
@ -82,6 +87,15 @@ TEST(TensorCopy, Tensor) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]); EXPECT_EQ(src_ptr[i], dst_ptr[i]);
} }
// Copy the same tensor
TensorCopy(gpu_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
gpu_ctx.Wait();
const int* dst_ptr_tmp = dst_tensor.data<int>();
EXPECT_NE(src_ptr, dst_ptr_tmp);
for (size_t i = 0; i < 9; ++i) {
EXPECT_EQ(src_ptr[i], dst_ptr_tmp[i]);
}
Tensor slice_tensor = src_tensor.Slice(1, 2); Tensor slice_tensor = src_tensor.Slice(1, 2);
// CPU Slice Tensor to GPU Tensor // CPU Slice Tensor to GPU Tensor

@ -70,7 +70,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
auto trt_teller = [&](const Node* node) { auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout"}); "elementwise_add", "dropout"});
if (!node->IsFunction()) return false; if (!node->IsFunction()) return false;

@ -25,9 +25,11 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DECLARE_bool(profile); DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
@ -47,6 +49,9 @@ bool AnalysisPredictor::Init(
} }
#endif #endif
// no matter with or without MKLDNN
paddle::platform::SetNumThreads(FLAGS_paddle_num_threads);
if (config_.use_gpu) { if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device); place_ = paddle::platform::CUDAPlace(config_.device);
LOG(WARNING) << "ir optimize only supports CPU currently, enable_ir_optim " LOG(WARNING) << "ir optimize only supports CPU currently, enable_ir_optim "

@ -23,9 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DEFINE_bool(profile, false, "Turn on profiler for fluid"); DEFINE_bool(profile, false, "Turn on profiler for fluid");
DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
namespace { namespace {
@ -72,6 +74,9 @@ bool NativePaddlePredictor::Init(
} }
#endif #endif
// no matter with or without MKLDNN
paddle::platform::SetNumThreads(FLAGS_paddle_num_threads);
if (config_.use_gpu) { if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device); place_ = paddle::platform::CUDAPlace(config_.device);
} else { } else {

@ -185,3 +185,4 @@ USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm); USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat); USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout); USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);

@ -1,7 +1,7 @@
# Add TRT tests # Add TRT tests
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry) DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@ -26,6 +26,8 @@ nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)

@ -0,0 +1,68 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* PadOp.
*/
class PadOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid transpose op to tensorrt tranpose layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
const std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
const float pad_value = boost::get<float>(op_desc.GetAttr("pad_value"));
nvinfer1::Dims input_shape = input->getDimensions();
int nbDims = input_shape.nbDims;
int pad_size = static_cast<int>(paddings.size());
PADDLE_ENFORCE_GE(nbDims, 2);
PADDLE_ENFORCE_EQ((nbDims + 1) * 2, pad_size);
PADDLE_ENFORCE(pad_value == 0.0, "The pad layer of TRT only support zero.");
nvinfer1::DimsHW pre_pad(paddings[pad_size - 4], paddings[pad_size - 2]);
nvinfer1::DimsHW post_pad(paddings[pad_size - 3], paddings[pad_size - 1]);
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Padding,
*const_cast<nvinfer1::ITensor*>(input),
pre_pad, post_pad);
PADDLE_ENFORCE(layer != nullptr);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
layer->setName(("scale (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(pad, PadOpConverter);

@ -0,0 +1,52 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(PadConverter, main) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("pad-X", nvinfer1::Dims3(3, 2, 2));
validator.DeclOutputVar("pad-Out", nvinfer1::Dims3(3, 3, 5));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("pad");
desc.SetInput("X", {"pad-X"});
desc.SetOutput("Out", {"pad-Out"});
std::vector<int> paddings = {0, 0, 0, 0, 0, 1, 1, 2};
float pad_value = 0.0;
desc.SetAttr("paddings", paddings);
desc.SetAttr("pad_value", pad_value);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(2);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(pad);

@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel { class AdadeltaOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel {
"Input(AvgSquaredGrad) of AdadeltaOp should not be null."); "Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"), PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null."); "Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdadeltaOp should not be null."); "Output(ParamOut) of AdadeltaOp should not be null.");
@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =

@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> { class AdadeltaOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut"); auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor = auto avg_squared_grad_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredGradOut"); ctx.Output<framework::Tensor>("AvgSquaredGradOut");

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
@ -21,25 +22,31 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct SparseAdagradFunctor { struct SparseAdagradFunctor {
void operator()(const DeviceContext& context, void operator()(const DeviceContext &context,
const framework::SelectedRows& grad, const framework::SelectedRows &grad,
const framework::Tensor& learning_rate, T epsilon, const framework::Tensor &learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param); framework::Tensor *moment, framework::Tensor *param);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class AdagradOpKernel : public framework::OpKernel<T> { class AdagradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto* param_out_tensor = ctx.Output<framework::Tensor>("ParamOut"); const auto *param_var = ctx.InputVar("Param");
auto* moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut"); PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace()); param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace()); moment_out_tensor->mutable_data<T>(ctx.GetPlace());
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon")); T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto* grad_var = ctx.InputVar("Grad"); auto *grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) { if (grad_var->IsType<framework::LoDTensor>()) {
auto param = framework::EigenVector<T>::Flatten( auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param")); *ctx.Input<framework::Tensor>("Param"));
@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*ctx.Input<framework::Tensor>("Grad")); *ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten( auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment")); *ctx.Input<framework::Tensor>("Moment"));
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate"); auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor); auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor); auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto* place = ctx.template device_context<DeviceContext>().eigen_device(); auto *place = ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(*place) = moment + grad * grad; moment_out.device(*place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel()); Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
auto* lr = learning_rate->data<T>(); auto *lr = learning_rate->data<T>();
param_out.device(*place) = param_out.device(*place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon); param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else { } else {
@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
} }
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
auto* param_tensor = ctx.Input<framework::Tensor>("Param"); auto *param_tensor = ctx.Input<framework::Tensor>("Param");
PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor); PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
auto* moment_tensor = ctx.Input<framework::Tensor>("Moment"); auto *moment_tensor = ctx.Input<framework::Tensor>("Moment");
PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor); PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
SparseAdagradFunctor<DeviceContext, T> functor; SparseAdagradFunctor<DeviceContext, T> functor;

@ -18,6 +18,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
@ -199,23 +200,9 @@ struct SparseAdamFunctor {
row_numel_(row_numel), row_numel_(row_numel),
row_count_(row_count) {} row_count_(row_count) {}
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
int64_t beg = 0, end = row_count_ - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (rows_[mid] == row)
return mid;
else if (rows_[mid] < row)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
inline HOSTDEVICE void operator()(size_t i) const { inline HOSTDEVICE void operator()(size_t i) const {
int64_t row = i / row_numel_; auto row_idx =
auto row_idx = BinarySearchInRows(row); math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
// The following code is the same as dense // The following code is the same as dense
@ -244,6 +231,12 @@ template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> { class AdamOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
using paddle::framework::LoDTensor; using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref; using paddle::operators::detail::Ref;

@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel {
"Input(LearningRate) of AdamaxOp should not be null."); "Input(LearningRate) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamaxOp should not be null."); "Input(Beta1Pow) of AdamaxOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamaxOp should not be null."); "Output(ParamOut) of AdamaxOp should not be null.");

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save