Merge branch 'develop' of github.com:PaddlePaddle/Paddle into overlap_memcpy_with_dist

Yancey1989 7 years ago
commit 15913d92c5

@@ -29,9 +29,11 @@ Currently supported `--model` argument include:
 You can choose to use GPU/CPU training. With GPU training, you can specify
 `--gpus <gpu_num>` to run multi GPU training.
 * Run distributed training with parameter servers:
+  * see [run_fluid_benchmark.sh](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/fluid/run_fluid_benchmark.sh) as an example.
   * start parameter servers:
     ```bash
     PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method pserver
+    sleep 15
     ```
   * start trainers:
     ```bash

@@ -0,0 +1,9 @@
+#!/bin/bash
+PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model resnet --device CPU --update_method pserver --iterations=10000 &
+sleep 15
+CUDA_VISIBLE_DEVICES=0,1 PADDLE_TRAINING_ROLE=TRAINER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model resnet --device GPU --update_method pserver --iterations=10000 --gpus 2 &
+CUDA_VISIBLE_DEVICES=2,3 PADDLE_TRAINING_ROLE=TRAINER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=2 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=1 python fluid_benchmark.py --model resnet --device GPU --update_method pserver --iterations=10000 --gpus 2 &

@@ -34,13 +34,7 @@ DEFINE_bool(
 namespace paddle {
 namespace framework {

-Scope::~Scope() {
-  DropKids();
-  for (auto& kv : vars_) {
-    VLOG(3) << "Destroy variable " << kv.first;
-    delete kv.second;
-  }
-}
+Scope::~Scope() { DropKids(); }

 Scope& Scope::NewScope() const {
   std::unique_lock<std::mutex> lock(mutex_);
@@ -49,10 +43,13 @@ Scope& Scope::NewScope() const {
 }

 Variable* Scope::Var(const std::string& name) {
+  // acquire the lock when new var under this scope
+  std::unique_lock<std::mutex> lock(mutex_);
   auto* v = FindVarLocally(name);
   if (v != nullptr) return v;
   v = new Variable();
-  vars_[name] = v;
+  vars_[name].reset(v);
   VLOG(3) << "Create variable " << name;
   v->name_ = &(vars_.find(name)->first);
   return v;
@@ -67,22 +64,29 @@ Variable* Scope::Var(std::string* name) {
 }

 Variable* Scope::FindVar(const std::string& name) const {
+  // acquire the lock when find var
+  std::unique_lock<std::mutex> lock(mutex_);
+  return FindVarInternal(name);
+}
+
+Variable* Scope::FindVarInternal(const std::string& name) const {
   auto var = FindVarLocally(name);
   if (var != nullptr) {
     return var;
   }
-  return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
+  return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name);
 }

 const Scope* Scope::FindScope(const Variable* var) const {
   for (auto& kv : vars_) {
-    if (kv.second == var) {
+    if (kv.second.get() == var) {
       return this;
     }
   }
   return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
 }

 void Scope::DropKids() {
+  std::unique_lock<std::mutex> lock(mutex_);
   for (Scope* s : kids_) delete s;
   kids_.clear();
 }
@@ -110,10 +114,10 @@ void Scope::DeleteScope(Scope* scope) const {
 }

 void Scope::EraseVars(const std::vector<std::string>& var_names) {
+  std::unique_lock<std::mutex> lock(mutex_);
   std::set<std::string> var_set(var_names.begin(), var_names.end());
   for (auto it = vars_.begin(); it != vars_.end();) {
     if (var_set.find(it->first) != var_set.end()) {
-      delete it->second;
       it = vars_.erase(it);
     } else {
       ++it;
@@ -129,7 +133,7 @@ void Scope::Rename(const std::string& origin_name,
   auto new_it = vars_.find(new_name);
   PADDLE_ENFORCE(new_it == vars_.end(),
                  "The variable with name %s is already in the scope", new_name);
-  vars_[new_name] = origin_it->second;
+  vars_[new_name].reset(origin_it->second.release());
   vars_.erase(origin_it);
 }
@@ -141,7 +145,7 @@ std::string Scope::Rename(const std::string& origin_name) const {
 Variable* Scope::FindVarLocally(const std::string& name) const {
   auto it = vars_.find(name);
-  if (it != vars_.end()) return it->second;
+  if (it != vars_.end()) return it->second.get();
   return nullptr;
 }

@@ -47,15 +47,18 @@ class Scope {
   Scope& NewScope() const;

   /// Create a variable with given name if it doesn't exist.
+  /// Caller doesn't own the returned Variable.
   Variable* Var(const std::string& name);

   /// Create a variable with a scope-unique name.
+  /// Caller doesn't own the returned Variable.
   Variable* Var(std::string* name = nullptr);

   void EraseVars(const std::vector<std::string>& var_names);

   /// Find a variable in the scope or any of its ancestors. Returns
   /// nullptr if cannot find.
+  /// Caller doesn't own the returned Variable.
   Variable* FindVar(const std::string& name) const;

   const Scope* parent() const { return parent_; }
@@ -78,13 +81,21 @@ class Scope {
   // Rename variable to a new name and return the new name
   std::string Rename(const std::string& origin_name) const;

-  Variable* FindVarLocally(const std::string& name) const;
-
  private:
   // Call Scope::NewScope for a sub-scope.
   explicit Scope(Scope const* parent) : parent_(parent) {}

-  mutable std::unordered_map<std::string, Variable*> vars_;
+  // Called by FindVar recursively.
+  // Caller doesn't own the returned Variable.
+  Variable* FindVarInternal(const std::string& name) const;
+
+  // Called by FindVarInternal and Var.
+  // Caller doesn't own the returned Variable.
+  Variable* FindVarLocally(const std::string& name) const;
+
+  mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
+
+  // Scope in `kids_` are owned by this class.
   mutable std::list<Scope*> kids_;
   Scope const* parent_{nullptr};
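
The upshot of the two scope changes above: `vars_` now owns its `Variable` objects through `std::unique_ptr`, lookups and mutations take the scope mutex, and callers only ever borrow the returned pointers. A minimal usage sketch of that contract (illustrative only; it assumes the `paddle::framework::Scope` API exactly as shown in this diff and a hypothetical `ScopeOwnershipDemo` call site):

```cpp
#include <cassert>

#include "paddle/fluid/framework/scope.h"

void ScopeOwnershipDemo() {
  paddle::framework::Scope root;

  // The scope owns the Variable; the caller only borrows the pointer.
  paddle::framework::Variable* v = root.Var("w");
  assert(v != nullptr);

  // FindVar walks up through parent scopes, now under the scope mutex.
  paddle::framework::Scope& child = root.NewScope();
  assert(child.FindVar("w") == v);

  // Child scopes are owned by `kids_`; variables die with their owning scope
  // because vars_ stores std::unique_ptr<Variable>.
  root.DropKids();
}
```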

@@ -18,6 +18,8 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>

+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/platform/enforce.h"

 namespace paddle {
@@ -107,6 +109,13 @@ class OrderedRegistry {
   std::vector<std::unique_ptr<T>> data_;
 };

+template <typename T>
+T &GetFromScope(const framework::Scope &scope, const std::string &name) {
+  framework::Variable *var = scope.FindVar(name);
+  PADDLE_ENFORCE(var != nullptr);
+  return *var->GetMutable<T>();
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
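
A possible call site for the new `GetFromScope` helper (a sketch; the variable name "fc_out" is hypothetical):

```cpp
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/analysis/helper.h"

int64_t OutputSize(const paddle::framework::Scope& scope) {
  // PADDLE_ENFORCE fires if the variable is missing; otherwise we get a
  // reference to the LoDTensor held by the variable.
  auto& tensor =
      paddle::inference::analysis::GetFromScope<paddle::framework::LoDTensor>(
          scope, "fc_out");
  return tensor.numel();
}
```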

@@ -1,10 +1,16 @@
 # Add TRT tests
-nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine)
 # This test is not stable
 # See https://paddleci.ngrok.io/viewLog.html?tab=buildLog&buildTypeId=Paddle_PrCi2&buildId=36834&_focus=8828
 #nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
 #        DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine
 #        SERIAL)
+nv_library(tensorrt_converter
+  SRCS mul_op.cc conv2d_op.cc fc_op.cc
+  DEPS tensorrt_engine mul_op)
+
+nv_test(test_op_converter SRCS test_op_converter.cc DEPS
+  ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
+
 nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
 nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

 namespace paddle {
@@ -36,8 +37,8 @@ class ReluOpConverter : public OpConverter {
   }
 };

-REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);

@@ -22,14 +22,14 @@ class Conv2dOpConverter : public OpConverter {
  public:
   Conv2dOpConverter() {}
   void operator()(const framework::proto::OpDesc& op,
-                  const framework::Scope& scope) override {
+                  const framework::Scope& scope, bool test_mode) override {
     LOG(INFO)
         << "convert a fluid conv2d op to tensorrt conv layer without bias";
   }
 };

-REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);

@@ -56,7 +56,7 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
 class FcOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
-                  const framework::Scope& scope) override {
+                  const framework::Scope& scope, bool test_mode) override {
     VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias";

     framework::OpDesc op_desc(op, nullptr);
@@ -106,14 +106,16 @@ class FcOpConverter {
                                 n_output, weight.get(), bias.get());

     auto output_name = op_desc.Output("Out").front();
-    engine_->DeclareOutput(layer, 0, output_name);
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    if (test_mode) {
+      engine_->DeclareOutput(output_name);
+    }
   }
 };

-REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);

 USE_OP(mul);

@@ -23,9 +23,8 @@ namespace tensorrt {
  */
 class MulOpConverter : public OpConverter {
  public:
-  MulOpConverter() {}
   void operator()(const framework::proto::OpDesc& op,
-                  const framework::Scope& scope) override {
+                  const framework::Scope& scope, bool test_mode) override {
     VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";

     framework::OpDesc op_desc(op, nullptr);
@@ -37,12 +36,18 @@ class MulOpConverter : public OpConverter {
         engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
         *const_cast<nvinfer1::ITensor*>(input2), false);

-    engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
+    auto output_name = op_desc.Output("Out")[0];
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    if (test_mode) {  // the test framework can not determine which is the
+                      // output, so place the declaration inside.
+      engine_->DeclareOutput(output_name);
+    }
   }
 };

-REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+USE_OP(mul);
+REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);

@@ -17,6 +17,7 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/utils/singleton.h"
@@ -34,12 +35,15 @@ class OpConverter {
   // Converter logic for an op.
   virtual void operator()(const framework::proto::OpDesc& op,
-                          const framework::Scope& scope) {}
+                          const framework::Scope& scope,
+                          bool test_mode = false) {}

-  // Convert a single fluid operaotr and add the corresponding layer to TRT.
+  // Convert a single fluid operator and add the corresponding layer to TRT.
+  // test_mode: whether the instance executes in an unit test.
   void ConvertOp(const framework::proto::OpDesc& op,
                  const std::unordered_set<std::string>& parameters,
-                 const framework::Scope& scope, TensorRTEngine* engine) {
+                 const framework::Scope& scope, TensorRTEngine* engine,
+                 bool test_mode = false) {
     framework::OpDesc op_desc(op, nullptr);
     OpConverter* it{nullptr};
@@ -57,7 +61,7 @@ class OpConverter {
     PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
                             op_desc.Type());
     it->SetEngine(engine);
-    (*it)(op, scope);
+    (*it)(op, scope, test_mode);
   }

   // convert fluid block to tensorrt network
@@ -77,6 +81,9 @@ class OpConverter {
   // TensorRT engine
   TensorRTEngine* engine_{nullptr};

+ protected:
+  bool test_mode_;
+
  private:
   // registered op converter map, whose key is the fluid op type, and value is
   // the pointer position of corresponding OpConverter class.
@@ -85,13 +92,24 @@ class OpConverter {
   framework::Scope* scope_{nullptr};
 };

 #define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__)                      \
-  struct trt_##op_type__##_converter {                                         \
+  struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \
     trt_##op_type__##_converter() {                                            \
-      Registry<OpConverter>::Register<Converter__>(#op_type__);                \
-    }                                                                          \
-  };                                                                           \
-  trt_##op_type__##_converter trt_##op_type__##_converter__;
+      ::paddle::inference::                                                    \
+          Registry<paddle::inference::tensorrt::OpConverter>::Register<        \
+              ::paddle::inference::tensorrt::Converter__>(#op_type__);         \
+    }                                                                          \
+  };                                                                           \
+  trt_##op_type__##_converter trt_##op_type__##_converter__;                   \
+  int TouchConverterRegister_##op_type__() {                                   \
+    trt_##op_type__##_converter__.Touch();                                     \
+    return 0;                                                                  \
+  }
+
+#define USE_TRT_CONVERTER(op_type__)                                     \
+  extern int TouchConverterRegister_##op_type__();                       \
+  static int use_op_converter_trt_##op_type__ __attribute__((unused)) =  \
+      TouchConverterRegister_##op_type__();

 }  // namespace tensorrt
 }  // namespace inference
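
The `Touch`/`USE_TRT_CONVERTER` pair addresses a linker problem: a converter that lives in a static library is registered by a file-scope registrar object, and if nothing references that object the linker may drop it. Calling the generated `TouchConverterRegister_<op>()` from the using file forces the object file in. A self-contained sketch of the same idiom, written with hypothetical names rather than the Paddle macros:

```cpp
#include <iostream>
#include <map>
#include <string>

// Stand-in for Registry<OpConverter>: op type -> some registration payload.
std::map<std::string, int>& ConverterRegistry() {
  static std::map<std::string, int> registry;
  return registry;
}

// What REGISTER_TRT_OP_CONVERTER(relu, ...) roughly expands to:
// a registrar whose constructor runs at static-initialization time ...
struct ReluRegistrar {
  ReluRegistrar() { ConverterRegistry()["relu"] = 1; }
  void Touch() {}  // no-op; exists only to create a link-time dependency
} relu_registrar;

// ... plus a touch function other translation units can reference.
int TouchReluRegistrar() {
  relu_registrar.Touch();
  return 0;
}

// What USE_TRT_CONVERTER(relu) roughly expands to in the using file.
extern int TouchReluRegistrar();
static int use_relu_converter __attribute__((unused)) = TouchReluRegistrar();

int main() {
  std::cout << "relu registered: " << ConverterRegistry().count("relu") << "\n";
  return 0;
}
```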

@@ -36,3 +36,5 @@ TEST(OpConverter, ConvertBlock) {
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+USE_TRT_CONVERTER(conv2d)

@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/utils/singleton.h"

 namespace paddle {
 namespace inference {
@@ -104,8 +105,8 @@ class TRTConvertValidation {
   void SetOp(const framework::proto::OpDesc& desc) {
     op_ = framework::OpRegistry::CreateOp(desc);

-    OpConverter op_converter;
-    op_converter.ConvertOp(desc, parameters_, scope_, engine_.get());
+    Singleton<OpConverter>::Global().ConvertOp(
+        desc, parameters_, scope_, engine_.get(), true /*test_mode*/);

     engine_->FreezeNetwork();
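
The test helper now routes conversion through `Singleton<OpConverter>::Global()` so that every caller shares one converter (and therefore one registry view). A rough sketch of such a singleton holder (assumption: Paddle's `paddle/fluid/inference/utils/singleton.h` provides something equivalent; this is not that header's exact code):

```cpp
#include <string>

// Minimal Meyers-style singleton holder.
template <typename T>
struct Singleton {
  static T& Global() {
    static T instance;  // constructed once on first use; thread-safe in C++11
    return instance;
  }
  Singleton() = delete;
};

struct FakeConverter {
  void ConvertOp(const std::string& op_type) { /* ... */ }
};

void Demo() {
  // Every call site observes the same FakeConverter instance.
  Singleton<FakeConverter>::Global().ConvertOp("mul");
}
```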

@@ -43,9 +43,10 @@ void TensorRTEngine::Execute(int batch_size) {
 }

 TensorRTEngine::~TensorRTEngine() {
+  cudaStreamSynchronize(*stream_);
   // clean buffer
   for (auto& buf : buffers_) {
-    if (buf.buffer != nullptr) {
+    if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {
       PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer));
       buf.buffer = nullptr;
       buf.max_size = 0;
@@ -80,6 +81,8 @@ void TensorRTEngine::FreezeNetwork() {
     auto& buf = buffer(item.first);
     CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
     PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
+    VLOG(4) << "buffer malloc " << item.first << " " << item.second << " "
+            << buf.buffer;
     buf.size = buf.max_size = item.second;
     buf.device = DeviceType::GPU;
   }
@@ -96,6 +99,7 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
   PADDLE_ENFORCE(input, "infer network add input %s failed", name);
   buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
                         analysis::AccuDims(dims.d, dims.nbDims);
+  PADDLE_ENFORCE(input->isNetworkInput());
   TensorRTEngine::SetITensor(name, input);
   return input;
 }
@@ -109,7 +113,9 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
   SetITensor(name, output);
   PADDLE_ENFORCE(output != nullptr);
   output->setName(name.c_str());
+  PADDLE_ENFORCE(!output->isNetworkInput());
   infer_network_->markOutput(*output);
+  PADDLE_ENFORCE(output->isNetworkOutput());
   // output buffers' size can only be decided latter, set zero here to mark this
   // and will reset latter.
   buffer_sizes_[name] = 0;
@@ -122,6 +128,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
   auto* output = TensorRTEngine::GetITensor(name);
   PADDLE_ENFORCE(output != nullptr);
   output->setName(name.c_str());
+  PADDLE_ENFORCE(!output->isNetworkInput());
   infer_network_->markOutput(*output);
   // output buffers' size can only be decided latter, set zero here to mark this
   // and will reset latter.
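
The destructor change above synchronizes the CUDA stream before releasing buffers and frees only allocations that are actually on the GPU. A reduced sketch of that cleanup pattern (hypothetical `Buffer` struct; error handling simplified):

```cpp
#include <cstddef>
#include <vector>

#include <cuda_runtime.h>

enum class DeviceType { UNKNOWN, CPU, GPU };

struct Buffer {
  void* buffer{nullptr};
  size_t max_size{0};
  DeviceType device{DeviceType::UNKNOWN};
};

void ReleaseEngineBuffers(cudaStream_t stream, std::vector<Buffer>* buffers) {
  // Wait until no kernel or async copy still touches the buffers.
  cudaStreamSynchronize(stream);
  for (auto& buf : *buffers) {
    // Only device allocations go through cudaFree; host buffers are owned
    // elsewhere and must not be freed here.
    if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {
      cudaError_t err = cudaFree(buf.buffer);
      (void)err;  // a real implementation would PADDLE_ENFORCE_EQ on this
      buf.buffer = nullptr;
      buf.max_size = 0;
    }
  }
}
```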

@@ -21,6 +21,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/inference/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/fluid/inference/utils/singleton.h"

 namespace paddle {
 namespace inference {
@@ -131,7 +132,11 @@ class TensorRTEngine : public EngineBase {
   // TensorRT related internal members
   template <typename T>
   struct Destroyer {
-    void operator()(T* x) { x->destroy(); }
+    void operator()(T* x) {
+      if (x) {
+        x->destroy();
+      }
+    }
   };
   template <typename T>
   using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
@@ -155,6 +160,27 @@ class TensorRTEngine : public EngineBase {
 #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
   engine__->network()->add##layer__(ARGS);

+/*
+ * Helper to control the TensorRT engine's creation and deletion.
+ */
+class TRT_EngineManager {
+ public:
+  TensorRTEngine* Create(int max_batch, int max_workspace,
+                         cudaStream_t* stream) {
+    engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream));
+    return engines_.back().get();
+  }
+
+  void DeleteALl() {
+    for (auto& ptr : engines_) {
+      ptr.reset(nullptr);
+    }
+  }
+
+ private:
+  std::vector<std::unique_ptr<TensorRTEngine>> engines_;
+};
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
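
How `TRT_EngineManager` is intended to be driven can only be inferred from its interface; a plausible sketch (hypothetical call site, batch/workspace numbers invented for illustration):

```cpp
#include <cuda_runtime.h>

#include "paddle/fluid/inference/tensorrt/engine.h"

void BuildEngines() {
  using paddle::inference::tensorrt::TensorRTEngine;
  using paddle::inference::tensorrt::TRT_EngineManager;

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  TRT_EngineManager manager;
  // max_batch = 1, max_workspace = 1 << 20: illustrative values only.
  TensorRTEngine* engine = manager.Create(1, 1 << 20, &stream);
  (void)engine;  // ...DeclareInput/DeclareOutput, FreezeNetwork, Execute...

  // Engines are owned by the manager; DeleteALl() (sic) releases them all.
  manager.DeleteALl();
  cudaStreamDestroy(stream);
}
```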

@@ -101,23 +101,22 @@ void SplitData(
 }

 void ThreadRunInfer(
-    const int tid, paddle::framework::Executor* executor,
-    paddle::framework::Scope* scope,
-    const std::unique_ptr<paddle::framework::ProgramDesc>& inference_program,
+    const int tid, paddle::framework::Scope* scope,
     const std::vector<std::vector<const paddle::framework::LoDTensor*>>& jobs) {
-  auto copy_program = std::unique_ptr<paddle::framework::ProgramDesc>(
-      new paddle::framework::ProgramDesc(*inference_program));
+  // maybe framework:ProgramDesc is not thread-safe
   auto& sub_scope = scope->NewScope();
+  auto place = paddle::platform::CPUPlace();
+  auto executor = paddle::framework::Executor(place);
+  auto inference_program =
+      paddle::inference::Load(&executor, scope, FLAGS_model_path);

-  std::string feed_holder_name = "feed_" + paddle::string::to_string(tid);
-  std::string fetch_holder_name = "fetch_" + paddle::string::to_string(tid);
-  copy_program->SetFeedHolderName(feed_holder_name);
-  copy_program->SetFetchHolderName(fetch_holder_name);
+  auto ctx = executor.Prepare(*inference_program, /*block_id*/ 0);
+  executor.CreateVariables(*inference_program, &sub_scope, /*block_id*/ 0);

   const std::vector<std::string>& feed_target_names =
-      copy_program->GetFeedTargetNames();
+      inference_program->GetFeedTargetNames();
   const std::vector<std::string>& fetch_target_names =
-      copy_program->GetFetchTargetNames();
+      inference_program->GetFetchTargetNames();

   PADDLE_ENFORCE_EQ(fetch_target_names.size(), 1UL);
   std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
@@ -131,9 +130,8 @@ void ThreadRunInfer(
   auto start_ms = GetCurrentMs();
   for (size_t i = 0; i < inputs.size(); ++i) {
     feed_targets[feed_target_names[0]] = inputs[i];
-    executor->Run(*copy_program, &sub_scope, &feed_targets, &fetch_targets,
-                  true /*create_local_scope*/, true /*create_vars*/,
-                  feed_holder_name, fetch_holder_name);
+    executor.RunPreparedContext(ctx.get(), &sub_scope, &feed_targets,
+                                &fetch_targets, false /*create_local_scope*/);
   }
   auto stop_ms = GetCurrentMs();
   scope->DeleteScope(&sub_scope);
@@ -158,22 +156,10 @@ TEST(inference, nlp) {
   LOG(INFO) << "Number of samples (seq_len<1024): " << datasets.size();
   LOG(INFO) << "Total number of words: " << num_total_words;

-  const bool model_combined = false;
   // 0. Call `paddle::framework::InitDevices()` initialize all the devices
-  // 1. Define place, executor, scope
-  auto place = paddle::platform::CPUPlace();
-  auto executor = paddle::framework::Executor(place);
   std::unique_ptr<paddle::framework::Scope> scope(
       new paddle::framework::Scope());
-  // 2. Initialize the inference_program and load parameters
-  std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
-  inference_program =
-      InitProgram(&executor, scope.get(), FLAGS_model_path, model_combined);
-  if (FLAGS_use_mkldnn) {
-    EnableMKLDNN(inference_program);
-  }
 #ifdef PADDLE_WITH_MKLML
   // only use 1 thread number per std::thread
   omp_set_dynamic(0);
@@ -189,21 +175,30 @@ TEST(inference, nlp) {
     start_ms = GetCurrentMs();
     for (int i = 0; i < FLAGS_num_threads; ++i) {
       threads.emplace_back(
-          new std::thread(ThreadRunInfer, i, &executor, scope.get(),
-                          std::ref(inference_program), std::ref(jobs)));
+          new std::thread(ThreadRunInfer, i, scope.get(), std::ref(jobs)));
     }
     for (int i = 0; i < FLAGS_num_threads; ++i) {
       threads[i]->join();
     }
     stop_ms = GetCurrentMs();
   } else {
-    if (FLAGS_prepare_vars) {
-      executor.CreateVariables(*inference_program, scope.get(), 0);
+    // 1. Define place, executor, scope
+    auto place = paddle::platform::CPUPlace();
+    auto executor = paddle::framework::Executor(place);
+
+    // 2. Initialize the inference_program and load parameters
+    std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
+    inference_program = InitProgram(&executor, scope.get(), FLAGS_model_path,
+                                    /*model combined*/ false);
+    if (FLAGS_use_mkldnn) {
+      EnableMKLDNN(inference_program);
     }
     // always prepare context
     std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
     ctx = executor.Prepare(*inference_program, 0);
+    if (FLAGS_prepare_vars) {
+      executor.CreateVariables(*inference_program, scope.get(), 0);
+    }
     // preapre fetch
     const std::vector<std::string>& fetch_target_names =
         inference_program->GetFetchTargetNames();
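
The rewritten test gives every worker thread its own `Executor`, its own freshly loaded `ProgramDesc`, and a prepared context; only the root `Scope` is shared, on the assumption that `ProgramDesc` may not be thread-safe. A stripped-down sketch of that threading pattern (hypothetical `RunOneJob`/`RunAllJobs` helpers; feed/fetch plumbing omitted):

```cpp
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"

void RunOneJob(int tid, paddle::framework::Scope* scope,
               const std::string& model_path) {
  (void)tid;  // kept to mirror ThreadRunInfer's signature
  // Per-thread executor and program: nothing mutable is shared between
  // threads except the (internally locked) Scope.
  auto& sub_scope = scope->NewScope();
  auto place = paddle::platform::CPUPlace();
  auto executor = paddle::framework::Executor(place);
  auto program = paddle::inference::Load(&executor, scope, model_path);

  auto ctx = executor.Prepare(*program, /*block_id=*/0);
  executor.CreateVariables(*program, &sub_scope, /*block_id=*/0);
  // ... feed inputs, executor.RunPreparedContext(...), read fetch targets ...
  scope->DeleteScope(&sub_scope);
}

void RunAllJobs(int num_threads, paddle::framework::Scope* scope,
                const std::string& model_path) {
  std::vector<std::thread> workers;
  for (int i = 0; i < num_threads; ++i) {
    workers.emplace_back(RunOneJob, i, scope, model_path);
  }
  for (auto& t : workers) {
    t.join();
  }
}
```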

@@ -227,6 +227,8 @@ op_library(softmax_op DEPS softmax)
 op_library(sequence_softmax_op DEPS softmax)
 if (WITH_GPU AND TENSORRT_FOUND)
   op_library(tensorrt_engine_op DEPS tensorrt_engine)
+  nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
+    DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter)
 else()
   set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
 endif()

@@ -1,6 +1,6 @@
 if(WITH_DISTRIBUTE)
   grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
-        request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
+        request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
         selected_rows memory)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

@@ -25,29 +25,15 @@ namespace paddle {
 namespace operators {
 namespace detail {

-std::once_flag RPCClient::init_flag_;
-
-std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
-
-RPCClient* RPCClient::GetInstance() {
-  std::call_once(init_flag_, &RPCClient::Init);
-  return rpc_client_.get();
-}
-
-void RPCClient::Init() {
-  if (rpc_client_.get() == nullptr) {
-    rpc_client_.reset(new RPCClient());
-  }
-  rpc_client_->InitEventLoop();
-}
-
-void RPCClient::InitEventLoop() {
+void GRPCClient::InitImpl() { InitEventLoop(); }
+
+void GRPCClient::InitEventLoop() {
   // start the client process thread
   // TODO(wuyi): can make this in a threadpool
-  client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this)));
+  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
 }

-RPCClient::~RPCClient() {
+GRPCClient::~GRPCClient() {
   Wait();
   cq_.Shutdown();
   {
@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
   client_thread_->join();
 }

-bool RPCClient::AsyncSendVariable(const std::string& ep,
-                                  const platform::DeviceContext& ctx,
-                                  const framework::Scope& scope,
-                                  const std::string& var_name,
-                                  int64_t time_out) {
+bool GRPCClient::AsyncSendVar(const std::string& ep,
+                              const platform::DeviceContext& ctx,
+                              const framework::Scope& scope,
+                              const std::string& var_name, int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
   result->Swap(&tmp);
 }

-bool RPCClient::AsyncGetVariable(const std::string& ep,
-                                 const platform::DeviceContext& ctx,
-                                 const framework::Scope& scope,
-                                 const std::string& var_name,
-                                 int64_t time_out) {
+bool GRPCClient::AsyncGetVar(const std::string& ep,
+                             const platform::DeviceContext& ctx,
+                             const framework::Scope& scope,
+                             const std::string& var_name, int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
@@ -155,12 +139,12 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
   return true;
 }

-bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
-                                      const platform::DeviceContext& ctx,
-                                      const framework::Scope& scope,
-                                      const std::string& in_var_name,
-                                      const std::string& out_var_name,
-                                      int64_t time_out) {
+bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
+                                  const platform::DeviceContext& ctx,
+                                  const framework::Scope& scope,
+                                  const std::string& in_var_name,
+                                  const std::string& out_var_name,
+                                  int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string in_var_name_val = in_var_name;
@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
   return true;
 }

-void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
+void GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
+                                       int64_t time_out) {
   const auto ch = GetChannel(ep);

   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
   req_count_++;
 }

-void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
+void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
+                                       int64_t time_out) {
   const auto ch = GetChannel(ep);
   FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
   s->Prepare(time_out);
@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
   req_count_++;
 }

-void RPCClient::Wait() {
+void GRPCClient::Wait() {
   std::unique_lock<std::mutex> lk(sync_mutex_);
   sync_cond_.wait(lk, [this] { return req_count_ == 0; });
 }

-void RPCClient::Proceed() {
+void GRPCClient::Proceed() {
   void* tag = nullptr;
   bool ok = false;
@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
   }
 }

-std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
+std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
   // TODO(Yancey1989): make grpc client completely thread-safe
   std::lock_guard<std::mutex> guard(chan_mutex_);
   auto it = channels_.find(ep);

@@ -38,6 +38,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/operators/detail/rpc_client.h"
 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
 #include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
   std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
 };

-class RPCClient {
+class GRPCClient : public RPCClient {
  public:
-  RPCClient() {}
-  ~RPCClient();
-
-  static RPCClient* GetInstance();
-
-  bool AsyncSendVariable(const std::string& ep,
-                         const platform::DeviceContext& ctx,
-                         const framework::Scope& scope,
-                         const std::string& var_name,
-                         int64_t time_out = 600 * 1000);
-
-  bool AsyncGetVariable(const std::string& ep,
-                        const platform::DeviceContext& ctx,
-                        const framework::Scope& scope,
-                        const std::string& var_name,
-                        int64_t time_out = 600 * 1000);
-
-  bool AsyncPrefetchVariable(const std::string& ep,
-                             const platform::DeviceContext& ctx,
-                             const framework::Scope& scope,
-                             const std::string& in_var_name,
-                             const std::string& out_var_name,
-                             int64_t time_out = 600 * 1000);
-
-  void AsyncSendBatchBarrier(const std::string& ep,
-                             int64_t time_out = 600 * 1000);
-
-  void AsyncSendFetchBarrier(const std::string& ep,
-                             int64_t time_out = 600 * 1000);
-
-  void Wait();
+  GRPCClient() {}
+  virtual ~GRPCClient();
+
+  bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
+                    const framework::Scope& scope, const std::string& var_name,
+                    int64_t time_out = RPCClient::rpc_time_out) override;
+
+  bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
+                   const framework::Scope& scope, const std::string& var_name,
+                   int64_t time_out = RPCClient::rpc_time_out) override;
+
+  bool AsyncPrefetchVar(const std::string& ep,
+                        const platform::DeviceContext& ctx,
+                        const framework::Scope& scope,
+                        const std::string& in_var_name,
+                        const std::string& out_var_name,
+                        int64_t time_out = RPCClient::rpc_time_out) override;
+
+  void AsyncSendBatchBarrier(
+      const std::string& ep,
+      int64_t time_out = RPCClient::rpc_time_out) override;
+
+  void AsyncSendFetchBarrier(
+      const std::string& ep,
+      int64_t time_out = RPCClient::rpc_time_out) override;
+
+  void Wait() override;
+
+ protected:
+  void InitImpl() override;
+
+ private:
   // InitEventLoop should only be called by Init()
   void InitEventLoop();

- private:
   void Proceed();

   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
-  // Init is called by GetInstance.
-  static void Init();

  private:
   grpc::CompletionQueue cq_;
@@ -218,9 +218,7 @@ class RPCClient {
   // mutex for GetChannel thread safety
   std::mutex chan_mutex_;

-  static std::unique_ptr<RPCClient> rpc_client_;
-  static std::once_flag init_flag_;
-  DISABLE_COPY_AND_ASSIGN(RPCClient);
+  DISABLE_COPY_AND_ASSIGN(GRPCClient);
 };

 }  // namespace detail

@@ -19,6 +19,7 @@ limitations under the License. */
 #include "gtest/gtest.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
 #include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/detail/rpc_client.h"

 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
   std::thread server_thread(StartServer);
   g_rpc_service->WaitServerReady();

-  detail::RPCClient* client = detail::RPCClient::GetInstance();
+  detail::RPCClient* client =
+      detail::RPCClient::GetInstance<detail::GRPCClient>();
   int port = g_rpc_service->GetSelectedPort();
   std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
     std::string in_var_name("ids");
     std::string out_var_name("out");

-    client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
+    client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
     client->Wait();
     auto var = scope.Var(out_var_name);
     auto value = var->GetMutable<framework::SelectedRows>()->value();

@@ -0,0 +1,26 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/detail/rpc_client.h"
+
+namespace paddle {
+namespace operators {
+namespace detail {
+
+std::once_flag RPCClient::init_flag_;
+std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
+
+}  // namespace detail
+}  // namespace operators
+}  // namespace paddle

@@ -0,0 +1,82 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace operators {
+namespace detail {
+
+class RPCClient {
+ public:
+  virtual bool AsyncSendVar(const std::string& ep,
+                            const platform::DeviceContext& ctx,
+                            const framework::Scope& scope,
+                            const std::string& var_name,
+                            int64_t time_out = rpc_time_out) = 0;
+
+  virtual bool AsyncGetVar(const std::string& ep,
+                           const platform::DeviceContext& ctx,
+                           const framework::Scope& scope,
+                           const std::string& var_name,
+                           int64_t time_out = rpc_time_out) = 0;
+
+  virtual bool AsyncPrefetchVar(const std::string& ep,
+                                const platform::DeviceContext& ctx,
+                                const framework::Scope& scope,
+                                const std::string& in_var_name,
+                                const std::string& out_var_name,
+                                int64_t time_out = rpc_time_out) = 0;
+
+  virtual void AsyncSendBatchBarrier(const std::string& ep,
+                                     int64_t time_out = rpc_time_out) = 0;
+
+  virtual void AsyncSendFetchBarrier(const std::string& ep,
+                                     int64_t time_out = rpc_time_out) = 0;
+
+  virtual void Wait() = 0;
+
+  static constexpr int64_t rpc_time_out = 120 * 1000;
+
+  template <typename T>
+  static RPCClient* GetInstance() {
+    std::call_once(init_flag_, &RPCClient::Init<T>);
+    return rpc_client_.get();
+  }
+
+  // Init is called by GetInstance.
+  template <typename T>
+  static void Init() {
+    if (rpc_client_.get() == nullptr) {
+      rpc_client_.reset(new T());
+      rpc_client_->InitImpl();
+    }
+  }
+
+ protected:
+  virtual void InitImpl() {}
+
+ private:
+  static std::once_flag init_flag_;
+  static std::unique_ptr<RPCClient> rpc_client_;
+};
+
+}  // namespace detail
+}  // namespace operators
+}  // namespace paddle
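
The new base class turns the old hard-wired gRPC singleton into a templated one: the first call to `RPCClient::GetInstance<T>()` constructs the concrete client `T` exactly once (via `std::call_once`) and later calls hand back that same instance. A self-contained sketch of the pattern with hypothetical stand-in types:

```cpp
#include <iostream>
#include <memory>
#include <mutex>

class Client {
 public:
  virtual ~Client() = default;
  virtual void Wait() = 0;

  template <typename T>
  static Client* GetInstance() {
    std::call_once(init_flag_, &Client::Init<T>);
    return client_.get();
  }

 protected:
  virtual void InitImpl() {}

 private:
  template <typename T>
  static void Init() {
    client_.reset(new T());  // the first caller picks the concrete type
    client_->InitImpl();     // one-time setup, e.g. starting an event loop
  }

  static std::once_flag init_flag_;
  static std::unique_ptr<Client> client_;
};

std::once_flag Client::init_flag_;
std::unique_ptr<Client> Client::client_(nullptr);

class GrpcLikeClient : public Client {
 public:
  void Wait() override { std::cout << "waiting for in-flight RPCs\n"; }
};

int main() {
  Client* c = Client::GetInstance<GrpcLikeClient>();
  c->Wait();
  return 0;
}
```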

@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/detail/rpc_client.h"
 #include "paddle/fluid/platform/profiler.h"

 namespace paddle {
@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
     // For profiling
     platform::RecordEvent record_event(Type(), &ctx);

-    auto rpc_client = detail::RPCClient::GetInstance();
+    detail::RPCClient* rpc_client =
+        detail::RPCClient::GetInstance<detail::GRPCClient>();

     rpc_client->Wait();

@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
     std::vector<std::string> endpoint_list =
         Attr<std::vector<std::string>>("endpoint_list");
-    detail::RPCClient client;
+    detail::RPCClient* client =
+        detail::RPCClient::GetInstance<detail::GRPCClient>();
+
     for (auto& ep : endpoint_list) {
       VLOG(3) << "sending nccl id to " << ep;
-      client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
+      client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
     }
-    client.Wait();
+    client->Wait();
     VLOG(3) << "sending completed...";
   }
