Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-dist-sparse-decay
test=developrevert-15207-remove_op_handle_lock_and_fix_var
commit
d0e3b24002
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,130 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
|
||||||
|
#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <boost/algorithm/string/predicate.hpp>
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/ir/graph.h"
|
||||||
|
#include "paddle/fluid/framework/ir/pass.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace ir {
|
||||||
|
|
||||||
|
class Node;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Remove the sum op of all gradients of the backward op.
|
||||||
|
* And remove the dependencies of the optimizer related to the
|
||||||
|
* same backward op.
|
||||||
|
*
|
||||||
|
* Before this pass:
|
||||||
|
*
|
||||||
|
* forward_op1 forward_op2
|
||||||
|
* | |
|
||||||
|
* grad_op1 grad_op2
|
||||||
|
* \ /
|
||||||
|
* \ /
|
||||||
|
* sum_op
|
||||||
|
* |
|
||||||
|
* sgd_op
|
||||||
|
*
|
||||||
|
* After this pass:
|
||||||
|
* forward_op1 forward_op2
|
||||||
|
* | |
|
||||||
|
* grad_op1 grad_op2
|
||||||
|
* | |
|
||||||
|
* sgd_op1 sgd_op2
|
||||||
|
*
|
||||||
|
* sgd_op1 and sgd_op2 will update the same weight which holds the same
|
||||||
|
* memory, so we could benefit from the acceleration
|
||||||
|
*/
|
||||||
|
class LockFreeOptimizePass : public Pass {
|
||||||
|
public:
|
||||||
|
virtual ~LockFreeOptimizePass() {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Create a new sgd node via current optimizer node
|
||||||
|
ir::Node* CreateNewSGDNode(ir::Graph* graph, ir::Node* forward_node,
|
||||||
|
ir::Node* backward_node, ir::Node* grad_sum_node,
|
||||||
|
ir::Node* optimize_node) const;
|
||||||
|
|
||||||
|
// Replace the input weight's optimizers
|
||||||
|
void ReplaceUpstreamNode(ir::Node* upstream_node,
|
||||||
|
ir::Node* old_optimizer_node,
|
||||||
|
ir::Node* new_optimizer_node) const;
|
||||||
|
|
||||||
|
// Replace the output weight's optimizers
|
||||||
|
void ReplaceAllDownstreamNode(ir::Node* old_optimizer_node,
|
||||||
|
ir::Node* new_optimizer_node) const;
|
||||||
|
|
||||||
|
// Find all weight variables in graph
|
||||||
|
bool FindAllWeightVars(ir::Graph* graph) const;
|
||||||
|
|
||||||
|
// Find the forward_op node via the backward_op node
|
||||||
|
ir::Node* FindForwardOpViaBackwardOp(ir::Graph* graph,
|
||||||
|
ir::Node* backward_node) const;
|
||||||
|
|
||||||
|
std::vector<ir::Node*> FindConnectedNode(ir::Node* upstream_node,
|
||||||
|
ir::Node* downstream_node) const;
|
||||||
|
|
||||||
|
inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
|
||||||
|
PADDLE_ENFORCE(node);
|
||||||
|
|
||||||
|
return node->NodeType() == Node::Type::kOperation && node->Name() == name;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
|
||||||
|
PADDLE_ENFORCE(node);
|
||||||
|
|
||||||
|
return node->NodeType() == Node::Type::kVariable && node->Name() == name;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
|
||||||
|
PADDLE_ENFORCE(node);
|
||||||
|
|
||||||
|
return node->NodeType() == Node::Type::kVariable &&
|
||||||
|
boost::algorithm::ends_with(node->Name(), name);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
|
||||||
|
PADDLE_ENFORCE(node);
|
||||||
|
|
||||||
|
return node->NodeType() == Node::Type::kVariable &&
|
||||||
|
node->Name().find(name) != std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
|
||||||
|
PADDLE_ENFORCE(ctrl_dep_node);
|
||||||
|
PADDLE_ENFORCE(node);
|
||||||
|
|
||||||
|
return IsControlDepVar(*ctrl_dep_node) &&
|
||||||
|
ctrl_dep_node->inputs.size() >= 1u &&
|
||||||
|
ctrl_dep_node->inputs[0] == node;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ir
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
#endif // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
|
@ -0,0 +1,85 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/jit/gen/seqpool.h"
|
||||||
|
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
|
||||||
|
#include "paddle/fluid/operators/jit/registry.h"
|
||||||
|
#include "paddle/fluid/platform/cpu_info.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace jit {
|
||||||
|
namespace gen {
|
||||||
|
|
||||||
|
void SeqPoolJitCode::genCode() {
|
||||||
|
constexpr int block = YMM_FLOAT_BLOCK;
|
||||||
|
constexpr int max_num_regs = 8;
|
||||||
|
const int num_block = w_ / block;
|
||||||
|
const int num_groups = num_block / max_num_regs;
|
||||||
|
int rest_num_regs = num_block % max_num_regs;
|
||||||
|
mov(reg32_int_h, dword[param_attr]);
|
||||||
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
||||||
|
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
|
||||||
|
vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]);
|
||||||
|
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
|
||||||
|
fild(dword[param_attr]);
|
||||||
|
fstp(dword[reg_tmp]);
|
||||||
|
vmovss(xmm_t(0), ptr[reg_tmp]);
|
||||||
|
if (type_ == SeqPoolType::kSqrt) {
|
||||||
|
vsqrtps(xmm_t(0), xmm_t(0));
|
||||||
|
}
|
||||||
|
vdivps(xmm_t(1), xmm_t(1), xmm_t(0));
|
||||||
|
vmovss(ptr[reg_tmp], xmm_t(1));
|
||||||
|
}
|
||||||
|
const int group_len = max_num_regs * block * sizeof(float);
|
||||||
|
for (int g = 0; g < num_groups; ++g) {
|
||||||
|
pool_height<ymm_t>(g * group_len, block, max_num_regs);
|
||||||
|
}
|
||||||
|
if (rest_num_regs > 0) {
|
||||||
|
pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
|
||||||
|
}
|
||||||
|
// part of rest_w * height
|
||||||
|
const int rest = w_ % block;
|
||||||
|
pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
|
||||||
|
ret();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creator that decides when the generated SeqPool kernel may be used and
// builds it for a given attribute set.
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
 public:
  // Usable whenever platform::MayIUse reports AVX; `attr` is not consulted.
  bool UseMe(const seq_pool_attr_t& attr) const override {
    const bool avx_available = platform::MayIUse(platform::avx);
    return avx_available;
  }

  // Size estimate for the code buffer, derived from the number of YMM blocks
  // in attr.w plus fixed slack.
  size_t CodeSize(const seq_pool_attr_t& attr) const override {
    const auto block_slots = attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */;
    const auto per_slot_total = block_slots * 4 /* load, mul and save */ + 256;
    return 96 + per_slot_total * 8;
  }

  // Builds the JIT code object; both attr.w and attr.h must be positive.
  std::unique_ptr<GenBase> CreateJitCode(
      const seq_pool_attr_t& attr) const override {
    PADDLE_ENFORCE_GT(attr.w, 0);
    PADDLE_ENFORCE_GT(attr.h, 0);
    const size_t buffer_size = CodeSize(attr);
    return make_unique<SeqPoolJitCode>(attr, buffer_size);
  }
};
|
||||||
|
|
||||||
|
} // namespace gen
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace gen = paddle::operators::jit::gen;
|
||||||
|
|
||||||
|
REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator);
|
@ -0,0 +1,214 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "glog/logging.h"
|
||||||
|
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
||||||
|
#include "paddle/fluid/platform/enforce.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace jit {
|
||||||
|
namespace gen {
|
||||||
|
|
||||||
|
class SeqPoolJitCode : public JitCode {
|
||||||
|
public:
|
||||||
|
explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
|
||||||
|
size_t code_size = 256 * 1024,
|
||||||
|
void* code_ptr = nullptr)
|
||||||
|
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
|
||||||
|
if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
|
||||||
|
type_ == SeqPoolType::kSqrt)) {
|
||||||
|
LOG(FATAL) << "Only support sum pool yet ";
|
||||||
|
}
|
||||||
|
fp_h_[0] = 1.f;
|
||||||
|
this->genCode();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const char* name() const {
|
||||||
|
std::string base = "SeqPoolJitCode";
|
||||||
|
if (type_ == SeqPoolType::kSum) {
|
||||||
|
base += "_Sum";
|
||||||
|
} else if (type_ == SeqPoolType::kAvg) {
|
||||||
|
base += "_Avg";
|
||||||
|
} else if (type_ == SeqPoolType::kSqrt) {
|
||||||
|
base += "_Sqrt";
|
||||||
|
}
|
||||||
|
base += ("_W" + std::to_string(w_));
|
||||||
|
return base.c_str();
|
||||||
|
}
|
||||||
|
void genCode() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
template <typename JMM>
|
||||||
|
void pool_height(int w_offset, int block, int max_num_regs) {
|
||||||
|
int offset = w_offset;
|
||||||
|
for (int i = 0; i < max_num_regs; ++i) {
|
||||||
|
vmovups(JMM(i), ptr[param_src + offset]);
|
||||||
|
offset += sizeof(float) * block;
|
||||||
|
}
|
||||||
|
cmp(reg32_int_h, 1);
|
||||||
|
Label l_next_h, l_h_done;
|
||||||
|
jle(l_h_done, T_NEAR);
|
||||||
|
mov(reg_h_i, 1);
|
||||||
|
mov(reg_tmp, param_src);
|
||||||
|
add(reg_tmp, w_ * sizeof(float) + w_offset);
|
||||||
|
L(l_next_h);
|
||||||
|
{
|
||||||
|
mov(reg_ptr_src_i, reg_tmp);
|
||||||
|
for (int i = 0; i < max_num_regs; ++i) {
|
||||||
|
vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
|
||||||
|
// sum anyway
|
||||||
|
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
|
||||||
|
add(reg_ptr_src_i, sizeof(float) * block);
|
||||||
|
}
|
||||||
|
inc(reg_h_i);
|
||||||
|
add(reg_tmp, w_ * sizeof(float));
|
||||||
|
cmp(reg_h_i, reg32_int_h);
|
||||||
|
jl(l_next_h, T_NEAR);
|
||||||
|
}
|
||||||
|
L(l_h_done);
|
||||||
|
// save right now
|
||||||
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
||||||
|
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
|
||||||
|
vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
|
||||||
|
}
|
||||||
|
offset = w_offset;
|
||||||
|
for (int i = 0; i < max_num_regs; ++i) {
|
||||||
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
||||||
|
vmulps(JMM(i), JMM(i), JMM(max_num_regs));
|
||||||
|
}
|
||||||
|
vmovups(ptr[param_dst + offset], JMM(i));
|
||||||
|
offset += sizeof(float) * block;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) {
|
||||||
|
const int rest_used_num_regs = load_rest(rest, w_offset, 0);
|
||||||
|
const bool has_block4 = rest / 4 > 0;
|
||||||
|
const bool has_block2 = (rest % 4) / 2 > 0;
|
||||||
|
const bool has_block1 = (rest % 2) == 1;
|
||||||
|
cmp(reg32_int_h, 1);
|
||||||
|
Label l_next_h, l_h_done;
|
||||||
|
jle(l_h_done, T_NEAR);
|
||||||
|
mov(reg_h_i, 1);
|
||||||
|
mov(reg_tmp, param_src);
|
||||||
|
add(reg_tmp, w_ * sizeof(float) + w_offset);
|
||||||
|
L(l_next_h);
|
||||||
|
{
|
||||||
|
int reg_idx = 0;
|
||||||
|
mov(reg_ptr_src_i, reg_tmp);
|
||||||
|
if (has_block4) {
|
||||||
|
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
|
||||||
|
add(reg_ptr_src_i, sizeof(float) * 4);
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
if (has_block2) {
|
||||||
|
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
|
||||||
|
add(reg_ptr_src_i, sizeof(float) * 2);
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
if (has_block1) {
|
||||||
|
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
|
||||||
|
"All heights should use same regs");
|
||||||
|
for (int i = 0; i < reg_idx; ++i) {
|
||||||
|
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
|
||||||
|
}
|
||||||
|
inc(reg_h_i);
|
||||||
|
add(reg_tmp, w_ * sizeof(float));
|
||||||
|
cmp(reg_h_i, reg32_int_h);
|
||||||
|
jl(l_next_h, T_NEAR);
|
||||||
|
}
|
||||||
|
L(l_h_done);
|
||||||
|
// save right now
|
||||||
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
||||||
|
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
|
||||||
|
vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
|
||||||
|
for (int i = 0; i < rest_used_num_regs; ++i) {
|
||||||
|
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
save_rest(rest, w_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// return the number of used regs, use start from reg 0
|
||||||
|
int load_rest(int rest, int w_offset, const int num_shift_regs,
|
||||||
|
const int reg_start = 0) {
|
||||||
|
const bool has_block4 = rest / 4 > 0;
|
||||||
|
const bool has_block2 = (rest % 4) / 2 > 0;
|
||||||
|
const bool has_block1 = (rest % 2) == 1;
|
||||||
|
int reg_idx = reg_start;
|
||||||
|
if (has_block4) {
|
||||||
|
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
|
||||||
|
w_offset += sizeof(float) * 4;
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
if (has_block2) {
|
||||||
|
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
|
||||||
|
w_offset += sizeof(float) * 2;
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
if (has_block1) {
|
||||||
|
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
return reg_idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// use reg start from 0
|
||||||
|
void save_rest(int rest, int w_offset, int reg_start = 0) {
|
||||||
|
const bool has_block4 = rest / 4 > 0;
|
||||||
|
const bool has_block2 = (rest % 4) / 2 > 0;
|
||||||
|
const bool has_block1 = (rest % 2) == 1;
|
||||||
|
int reg_idx = reg_start;
|
||||||
|
if (has_block4) {
|
||||||
|
vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx));
|
||||||
|
w_offset += sizeof(float) * 4;
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
if (has_block2) {
|
||||||
|
vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx));
|
||||||
|
w_offset += sizeof(float) * 2;
|
||||||
|
reg_idx++;
|
||||||
|
}
|
||||||
|
if (has_block1) {
|
||||||
|
vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
float ALIGN32_BEG fp_h_[1] ALIGN32_END;
|
||||||
|
int w_;
|
||||||
|
SeqPoolType type_;
|
||||||
|
reg64_t param_src{abi_param1};
|
||||||
|
reg64_t param_dst{abi_param2};
|
||||||
|
reg64_t param_attr{abi_param3};
|
||||||
|
reg64_t reg_tmp{rax};
|
||||||
|
|
||||||
|
reg32_t reg32_int_h{r8d};
|
||||||
|
reg32_t reg32_fp_h{r9d};
|
||||||
|
|
||||||
|
reg64_t reg_h_i{r10};
|
||||||
|
reg64_t reg_ptr_src_i{r11};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gen
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue