Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into update_simple_distranspiler

wangkuiyi-patch-1
minqiyang 8 years ago
commit 0abf173ed5

@@ -58,6 +58,8 @@ PaddlePaddle uses this [Git branching model](http://nvie.com/posts/a-successful-
  create mode 100644 233
 ```
+NOTE: The `yapf` installed by `pip install pre-commit` and the one installed by `conda install -c conda-forge pre-commit` are slightly different. Paddle developers use `pip install pre-commit`.
 1. Build and test
 Users can build PaddlePaddle natively on Linux and Mac OS X. But to unify the building environment and to make it easy for debugging, the recommended way is [using Docker](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/build_en.md).

@@ -98,6 +98,8 @@ def parse_args():
         '--use_fake_data',
         action='store_true',
         help='If set ommit the actual read data operators.')
+    parser.add_argument(
+        '--profile', action='store_true', help='If set, profile a few steps.')
     parser.add_argument(
         '--update_method',
         type=str,
@@ -108,8 +110,8 @@ def parse_args():
     return args
-def append_nccl2_prepare():
-    if os.getenv("PADDLE_TRAINER_ID", None) != None:
+def append_nccl2_prepare(trainer_id):
+    if trainer_id >= 0:
         # append gen_nccl_id at the end of startup program
         trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
         port = os.getenv("PADDLE_PSERVER_PORT")
@@ -136,12 +138,12 @@ def append_nccl2_prepare():
             })
         return nccl_id_var, num_trainers, trainer_id
     else:
-        raise Exception(
-            "must set PADDLE_TRAINER_ID env variables for dist train.")
+        raise Exception("must set positive PADDLE_TRAINER_ID env variables for "
+                        "nccl-based dist train.")
-def dist_transpile():
-    if "PADDLE_TRAINING_ROLE" not in os.environ:
+def dist_transpile(trainer_id):
+    if trainer_id < 0:
         return None, None
     # the port of all pservers, needed by both trainer and pserver
@@ -158,9 +160,6 @@ def dist_transpile():
     trainers = int(os.getenv("PADDLE_TRAINERS"))
     # the IP of the local machine, needed by pserver only
     current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
-    # the unique trainer id, starting from 0, needed by trainer
-    # only
-    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
     # the role, should be either PSERVER or TRAINER
     training_role = os.getenv("PADDLE_TRAINING_ROLE")
@@ -295,6 +294,11 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
     iters = 0
     start_time = time.time()
     for batch_id, data in enumerate(train_reader()):
+        if args.profile and pass_id == 0 and batch_id == 5:
+            profiler.start_profiler("All")
+        elif args.profile and pass_id == 0 and batch_id == 10:
+            profiler.stop_profiler("total", "/tmp/profile_%d" % trainer_id)
         if iters == args.skip_batch_num:
             start_time = time.time()
             num_samples = 0
@@ -334,7 +338,11 @@ def print_arguments(args):
 def main():
     args = parse_args()
     print_arguments(args)
-    nccl_id_var, num_trainers, trainer_id = None, 1, 0
+    # the unique trainer id, starting from 0, needed by trainer
+    # only
+    nccl_id_var, num_trainers, trainer_id = (
+        None, 1, int(os.getenv("PADDLE_TRAINER_ID", "-1")))
     if args.use_cprof:
         pr = cProfile.Profile()
@@ -348,7 +356,7 @@ def main():
     fluid.memory_optimize(fluid.default_main_program())
     if args.update_method == "pserver":
-        train_prog, startup_prog = dist_transpile()
+        train_prog, startup_prog = dist_transpile(trainer_id)
         if not train_prog:
             raise Exception(
                 "Must configure correct environments to run dist train.")
@@ -364,7 +372,7 @@ def main():
         train_args.append(fluid.default_startup_program())
     if args.update_method == "nccl2":
-        nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare()
+        nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare(trainer_id)
     if args.gpus == 1:
         # NOTE: parallel executor use profiler interanlly
         if args.use_nvprof and args.device == 'GPU':

@@ -86,7 +86,7 @@
 <br>
 <p align="center">
-<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid_compiler.png" width=100%>
+<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/fluid-compiler.png" width=100%>
 </p>
 ---

@@ -17,3 +17,4 @@
    :maxdepth: 1
    concepts/use_concepts_cn.rst
+   developer's_guide_to_paddle_fluid.md

@@ -16,3 +16,4 @@ Here is an example of linear regression. It introduces workflow of PaddlePaddle,
    :maxdepth: 1
    concepts/index_en.rst
+   developer's_guide_to_paddle_fluid.md

@@ -11,7 +11,7 @@ PaddlePaddle supports quick installation via pip, currently supporting CentOS 6 and above, Ubuntu 14.
   pip install paddlepaddle
-If you need to install the GPU-enabled version (cuda7.5_cudnn5_avx_openblas), run:
+If you need to install the GPU-enabled version (cuda8.0_cudnn5_avx_openblas), run:
 .. code-block:: bash

@@ -12,7 +12,7 @@ Simply run the following command to install, the version is cpu_avx_openblas:
   pip install paddlepaddle
-If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run:
+If you need to install GPU version (cuda8.0_cudnn5_avx_openblas), run:
 .. code-block:: bash

@@ -51,6 +51,8 @@ Paddle developers use the [pre-commit](http://pre-commit.com/) tool to manage G
 Paddle uses `clang-format` to format C/C++ source code; please make sure your `clang-format` version is 3.8 or above.
+Note: The `yapf` installed by `pip install pre-commit` and the one installed by `conda install -c conda-forge pre-commit` are slightly different; Paddle developers use `pip install pre-commit`.
 ## Start Developing
 In this example, I deleted one line from README.md and created a new file.

@@ -13,7 +13,11 @@
 # limitations under the License.
 #
-function(inference_api_test TARGET_NAME TEST_SRC DEP_TEST)
+if(APPLE)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
+endif(APPLE)
+
+function(inference_api_test TARGET_NAME TEST_SRC)
   set(options "")
   set(oneValueArgs "")
   set(multiValueArgs ARGS)
@@ -34,6 +38,8 @@ function(inference_api_test TARGET_NAME TEST_SRC DEP_TEST)
              SRCS ${TEST_SRC}
              DEPS paddle_fluid_api paddle_inference_api paddle_inference_api_impl
              ARGS --dirname=${PYTHON_TESTS_DIR}/book/)
+    # TODO(panyx0178): Figure out how to add word2vec and image_classification
+    # as deps.
     # set_tests_properties(${TARGET_NAME}
     #   PROPERTIES DEPENDS ${DEP_TEST})
   endforeach()
@@ -53,5 +59,4 @@ cc_test(test_paddle_inference_api
         DEPS paddle_inference_api)
 inference_api_test(test_paddle_inference_api_impl
-                   test_paddle_inference_api_impl.cc
-                   test_word2vec)
+                   test_paddle_inference_api_impl.cc)

@@ -102,8 +102,8 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
   Timer timer;
   timer.tic();
   // set feed variable
-  std::map<std::string, const paddle::framework::LoDTensor *> feed_targets;
-  std::vector<paddle::framework::LoDTensor> feeds;
+  std::map<std::string, const framework::LoDTensor *> feed_targets;
+  std::vector<framework::LoDTensor> feeds;
   if (!SetFeed(inputs, &feeds)) {
     LOG(ERROR) << "fail to set feed";
     return false;
@@ -112,8 +112,8 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
     feed_targets[feed_target_names_[i]] = &feeds[i];
   }
   // get fetch variable
-  std::map<std::string, paddle::framework::LoDTensor *> fetch_targets;
-  std::vector<paddle::framework::LoDTensor> fetchs;
+  std::map<std::string, framework::LoDTensor *> fetch_targets;
+  std::vector<framework::LoDTensor> fetchs;
   fetchs.resize(fetch_target_names_.size());
   for (size_t i = 0; i < fetch_target_names_.size(); ++i) {
     fetch_targets[fetch_target_names_[i]] = &fetchs[i];
@@ -149,28 +149,27 @@ bool PaddlePredictorImpl::InitShared() {
   VLOG(3) << "Predictor::init_shared";
   // 1. Define place, executor, scope
   if (this->config_.device >= 0) {
-    place_ = paddle::platform::CUDAPlace();
+    place_ = platform::CUDAPlace();
   } else {
-    place_ = paddle::platform::CPUPlace();
+    place_ = platform::CPUPlace();
   }
-  this->executor_.reset(new paddle::framework::Executor(this->place_));
-  this->scope_.reset(new paddle::framework::Scope());
+  this->executor_.reset(new framework::Executor(this->place_));
+  this->scope_.reset(new framework::Scope());
   // Initialize the inference program
   if (!this->config_.model_dir.empty()) {
     // Parameters are saved in separate files sited in
     // the specified `dirname`.
-    this->inference_program_ = paddle::inference::Load(
+    this->inference_program_ = inference::Load(
         this->executor_.get(), this->scope_.get(), this->config_.model_dir);
   } else if (!this->config_.prog_file.empty() &&
              !this->config_.param_file.empty()) {
     // All parameters are saved in a single file.
     // The file names should be consistent with that used
     // in Python API `fluid.io.save_inference_model`.
-    this->inference_program_ =
-        paddle::inference::Load(this->executor_.get(),
-                                this->scope_.get(),
-                                this->config_.prog_file,
-                                this->config_.param_file);
+    this->inference_program_ = inference::Load(this->executor_.get(),
+                                               this->scope_.get(),
+                                               this->config_.prog_file,
+                                               this->config_.param_file);
   }
   this->ctx_ = this->executor_->Prepare(*this->inference_program_, 0);
   // 3. create variables
@@ -185,24 +184,21 @@ bool PaddlePredictorImpl::InitShared() {
   return true;
 }
-bool PaddlePredictorImpl::SetFeed(
-    const std::vector<PaddleTensor> &inputs,
-    std::vector<paddle::framework::LoDTensor> *feeds) {
+bool PaddlePredictorImpl::SetFeed(const std::vector<PaddleTensor> &inputs,
+                                  std::vector<framework::LoDTensor> *feeds) {
   VLOG(3) << "Predictor::set_feed";
   if (inputs.size() != feed_target_names_.size()) {
     LOG(ERROR) << "wrong feed input size.";
     return false;
   }
   for (size_t i = 0; i < feed_target_names_.size(); ++i) {
-    paddle::framework::LoDTensor input;
-    paddle::framework::DDim ddim =
-        paddle::framework::make_ddim(inputs[i].shape);
+    framework::LoDTensor input;
+    framework::DDim ddim = framework::make_ddim(inputs[i].shape);
     void *input_ptr;
     if (inputs[i].dtype == PaddleDType::INT64) {
-      input_ptr =
-          input.mutable_data<int64_t>(ddim, paddle::platform::CPUPlace());
+      input_ptr = input.mutable_data<int64_t>(ddim, platform::CPUPlace());
     } else if (inputs[i].dtype == PaddleDType::FLOAT32) {
-      input_ptr = input.mutable_data<float>(ddim, paddle::platform::CPUPlace());
+      input_ptr = input.mutable_data<float>(ddim, platform::CPUPlace());
     } else {
       LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
       return false;
@@ -213,13 +209,12 @@ bool PaddlePredictorImpl::SetFeed(
                inputs[i].data.data,
                inputs[i].data.length);
     feeds->push_back(input);
-    LOG(ERROR) << "Actual feed type " << feeds->back().type().name();
   }
   return true;
 }
 bool PaddlePredictorImpl::GetFetch(
-    const std::vector<paddle::framework::LoDTensor> &fetchs,
+    const std::vector<framework::LoDTensor> &fetchs,
     std::vector<PaddleTensor> *outputs) {
   VLOG(3) << "Predictor::get_fetch";
   outputs->resize(fetchs.size());
@@ -284,8 +279,9 @@ bool PaddlePredictorImpl::GetFetch(
   return true;
 }
-std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
-    const VisConfig &config) {
+template <>
+std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
+    const ConfigImpl &config) {
   VLOG(3) << "create PaddlePredictorImpl";
   // 1. GPU memeroy
   std::vector<std::string> flags;
@@ -299,12 +295,11 @@ std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
     framework::InitGflags(flags);
   }
-  std::unique_ptr<PaddlePredictorImpl> predictor(
-      new PaddlePredictorImpl(config));
-  if (!predictor->Init()) {
+  std::unique_ptr<PaddlePredictor> predictor(new PaddlePredictorImpl(config));
+  if (!dynamic_cast<PaddlePredictorImpl *>(predictor.get())->Init()) {
     return nullptr;
   }
-  return predictor;
+  return std::move(predictor);
 }
 }  // namespace paddle

@@ -29,7 +29,7 @@
 namespace paddle {
-struct VisConfig : public PaddlePredictor::Config {
+struct ConfigImpl : public PaddlePredictor::Config {
   int device;
   float fraction_of_gpu_memory;
   std::string prog_file;
@@ -37,12 +37,9 @@ struct VisConfig : public PaddlePredictor::Config {
   bool share_variables;
 };
-/*
- * Do not use this, just a demo indicating how to customize a Predictor.
- */
 class PaddlePredictorImpl : public PaddlePredictor {
  public:
-  explicit PaddlePredictorImpl(const VisConfig &config) : config_(config) {}
+  explicit PaddlePredictorImpl(const ConfigImpl &config) : config_(config) {}
   bool Init();
@@ -56,21 +53,18 @@ class PaddlePredictorImpl : public PaddlePredictor {
  private:
   bool InitShared() override;
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
-               std::vector<paddle::framework::LoDTensor> *feeds);
-  bool GetFetch(const std::vector<paddle::framework::LoDTensor> &fetchs,
+               std::vector<framework::LoDTensor> *feeds);
+  bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
                 std::vector<PaddleTensor> *output_data);
-  VisConfig config_;
-  paddle::platform::Place place_;
-  std::unique_ptr<paddle::framework::Executor> executor_;
-  std::unique_ptr<paddle::framework::Scope> scope_;
-  std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx_;
-  std::unique_ptr<paddle::framework::ProgramDesc> inference_program_;
+  ConfigImpl config_;
+  platform::Place place_;
+  std::unique_ptr<framework::Executor> executor_;
+  std::unique_ptr<framework::Scope> scope_;
+  std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
+  std::unique_ptr<framework::ProgramDesc> inference_program_;
   std::vector<std::string> feed_target_names_;
   std::vector<std::string> fetch_target_names_;
 };
-std::unique_ptr<PaddlePredictorImpl> CreatePaddlePredictorImpl(
-    const VisConfig &config);
 }  // namespace paddle
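The renamed `ConfigImpl` struct and the `CreatePaddlePredictor` specialization above imply a small calling pattern for the C++ inference API. The sketch below is assembled only from names visible in this diff (`ConfigImpl` fields, `CreatePaddlePredictor`, `PaddlePredictor::Run`, and the `PaddleTensor` members `data.data`/`data.length` used in the test file further down); the include path and error handling are assumptions, so treat it as illustrative rather than the library's documented usage.

```cpp
// Hedged sketch: the header name is assumed; calls and struct members mirror this diff.
#include <cstdlib>
#include <memory>
#include <vector>
#include "paddle_inference_api.h"  // assumed header name

int main() {
  paddle::ConfigImpl config;
  config.model_dir = "./word2vec.inference.model";
  config.fraction_of_gpu_memory = 0.15;
  config.device = 0;  // >= 0 picks CUDAPlace in InitShared(), < 0 picks CPUPlace
  config.share_variables = true;

  // The template specialization in this commit returns the interface type,
  // keeping PaddlePredictorImpl hidden behind PaddlePredictor.
  std::unique_ptr<paddle::PaddlePredictor> predictor =
      paddle::CreatePaddlePredictor(config);
  if (!predictor) return 1;

  std::vector<paddle::PaddleTensor> feeds;  // filled e.g. via LodTensorToPaddleTensor
  std::vector<paddle::PaddleTensor> outputs;
  if (!predictor->Run(feeds, &outputs)) return 1;

  // The test below frees outputs[0].data.data, so output buffers are owned by the caller.
  for (auto& t : outputs) std::free(t.data.data);
  return 0;
}
```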

@@ -40,16 +40,19 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
   return pt;
 }
-TEST(paddle_inference_api_impl, word2vec) {
-  VisConfig config;
+ConfigImpl GetConfig() {
+  ConfigImpl config;
   config.model_dir = FLAGS_dirname + "word2vec.inference.model";
   LOG(INFO) << "dirname " << config.model_dir;
   config.fraction_of_gpu_memory = 0.15;
   config.device = 0;
   config.share_variables = true;
+  return config;
+}
-  std::unique_ptr<PaddlePredictorImpl> predictor =
-      CreatePaddlePredictorImpl(config);
+TEST(paddle_inference_api_impl, word2vec) {
+  ConfigImpl config = GetConfig();
+  std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config);
   framework::LoDTensor first_word, second_word, third_word, fourth_word;
   framework::LoD lod{{0, 1}};
@@ -60,24 +63,91 @@ TEST(paddle_inference_api_impl, word2vec) {
   SetupLoDTensor(&third_word, lod, static_cast<int64_t>(0), dict_size - 1);
   SetupLoDTensor(&fourth_word, lod, static_cast<int64_t>(0), dict_size - 1);
-  std::vector<PaddleTensor> cpu_feeds;
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&first_word));
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&second_word));
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&third_word));
-  cpu_feeds.push_back(LodTensorToPaddleTensor(&fourth_word));
+  std::vector<PaddleTensor> paddle_tensor_feeds;
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&first_word));
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&second_word));
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&third_word));
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&fourth_word));
+  std::vector<PaddleTensor> outputs;
+  ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
+  ASSERT_EQ(outputs.size(), 1UL);
+  size_t len = outputs[0].data.length;
+  float* data = static_cast<float*>(outputs[0].data.data);
+  for (int j = 0; j < len / sizeof(float); ++j) {
+    ASSERT_LT(data[j], 1.0);
+    ASSERT_GT(data[j], -1.0);
+  }
+  std::vector<paddle::framework::LoDTensor*> cpu_feeds;
+  cpu_feeds.push_back(&first_word);
+  cpu_feeds.push_back(&second_word);
+  cpu_feeds.push_back(&third_word);
+  cpu_feeds.push_back(&fourth_word);
+  framework::LoDTensor output1;
+  std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
+  cpu_fetchs1.push_back(&output1);
+  TestInference<platform::CPUPlace>(config.model_dir, cpu_feeds, cpu_fetchs1);
+  float* lod_data = output1.data<float>();
+  for (size_t i = 0; i < output1.numel(); ++i) {
+    EXPECT_LT(lod_data[i] - data[i], 1e-3);
+    EXPECT_GT(lod_data[i] - data[i], -1e-3);
+  }
+  free(outputs[0].data.data);
+}
+TEST(paddle_inference_api_impl, image_classification) {
+  int batch_size = 2;
+  bool use_mkldnn = false;
+  bool repeat = false;
+  ConfigImpl config = GetConfig();
+  config.model_dir =
+      FLAGS_dirname + "image_classification_resnet.inference.model";
+  const bool is_combined = false;
+  std::vector<std::vector<int64_t>> feed_target_shapes =
+      GetFeedTargetShapes(config.model_dir, is_combined);
+  framework::LoDTensor input;
+  // Use normilized image pixels as input data,
+  // which should be in the range [0.0, 1.0].
+  feed_target_shapes[0][0] = batch_size;
+  framework::DDim input_dims = framework::make_ddim(feed_target_shapes[0]);
+  SetupTensor<float>(
+      &input, input_dims, static_cast<float>(0), static_cast<float>(1));
+  std::vector<framework::LoDTensor*> cpu_feeds;
+  cpu_feeds.push_back(&input);
+  framework::LoDTensor output1;
+  std::vector<framework::LoDTensor*> cpu_fetchs1;
+  cpu_fetchs1.push_back(&output1);
+  TestInference<platform::CPUPlace, false, true>(config.model_dir,
+                                                 cpu_feeds,
+                                                 cpu_fetchs1,
+                                                 repeat,
+                                                 is_combined,
+                                                 use_mkldnn);
+  std::unique_ptr<PaddlePredictor> predictor = CreatePaddlePredictor(config);
+  std::vector<PaddleTensor> paddle_tensor_feeds;
+  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&input));
   std::vector<PaddleTensor> outputs;
-  ASSERT_TRUE(predictor->Run(cpu_feeds, &outputs));
+  ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
   ASSERT_EQ(outputs.size(), 1UL);
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    size_t len = outputs[i].data.length;
-    float* data = static_cast<float*>(outputs[i].data.data);
-    for (size_t j = 0; j < len / sizeof(float); ++j) {
-      ASSERT_LT(data[j], 1.0);
-      ASSERT_GT(data[j], -1.0);
-    }
-    free(outputs[i].data.data);
-  }
+  size_t len = outputs[0].data.length;
+  float* data = static_cast<float*>(outputs[0].data.data);
+  float* lod_data = output1.data<float>();
+  for (size_t j = 0; j < len / sizeof(float); ++j) {
+    EXPECT_LT(lod_data[j] - data[j], 1e-10);
+    EXPECT_GT(lod_data[j] - data[j], -1e-10);
+  }
+  free(data);
 }
 }  // namespace paddle

@@ -469,6 +469,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
  protected:
   DDim GetDim(const std::string& name) const override {
     Variable* var = scope_.FindVar(name);
+    PADDLE_ENFORCE_NOT_NULL(var);
     if (var->IsType<LoDTensor>()) {
       return var->Get<LoDTensor>().dims();
     } else if (var->IsType<SelectedRows>()) {

@@ -18,8 +18,8 @@ namespace paddle {
 namespace framework {
 struct ReAllocateVisitor {
-  ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
-      : tensor_(tensor), dims_(dims) {}
+  ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor)
+      : dims_(dims), tensor_(tensor) {}
   template <typename T>
   void operator()() const {
@@ -34,8 +34,8 @@ struct ReAllocateVisitor {
     tensor_->ShareDataWith(cpu_tensor);
   }
-  framework::Tensor* tensor_;
   framework::DDim dims_;
+  framework::Tensor* tensor_;
 };
 struct TensorCopyVisitor {
@@ -158,6 +158,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
   }
   PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
                     "The first dim of value should be 1.");
+  std::lock_guard<std::mutex> lock(*auto_grown_mutex_.get());
   auto index = Index(key);
   bool is_new_key = false;
   if (index == -1) {
@@ -169,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
     auto dims = value_->dims();
     dims[0] = (dims[0] + 1) << 1;
     framework::VisitDataType(framework::ToDataType(value.type()),
-                             ReAllocateVisitor(value_.get(), dims));
+                             ReAllocateVisitor(dims, value_.get()));
   }
 }

@@ -15,6 +15,8 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
+#include <memory>
+#include <mutex>  // NOLINT
 #include <utility>
 #include <vector>
@@ -46,11 +48,13 @@ class SelectedRows {
   SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
       : rows_(rows), height_(height) {
     value_.reset(new Tensor());
+    auto_grown_mutex_.reset(new std::mutex);
   }
   SelectedRows() {
     height_ = 0;
     value_.reset(new Tensor());
+    auto_grown_mutex_.reset(new std::mutex);
   }
   platform::Place place() const { return value_->place(); }
@@ -125,6 +129,7 @@ class SelectedRows {
   Vector<int64_t> rows_;
   std::unique_ptr<Tensor> value_{nullptr};
   int64_t height_;
+  std::unique_ptr<std::mutex> auto_grown_mutex_{nullptr};
 };
 /*
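The new `auto_grown_mutex_` exists because `SelectedRows::Set` can reallocate `value_` when a key misses and the table auto-grows; two threads taking that path at once would race on the reallocation. Below is a minimal standalone sketch of the same guard-the-grow pattern (an illustration only, not Paddle code):

```cpp
#include <mutex>
#include <vector>

// Toy stand-in for the auto-growing storage in SelectedRows::Set.
class GrowableRows {
 public:
  // Returns the slot for `key`, growing the backing store under a lock so that
  // concurrent writers cannot both trigger (or observe a half-done) reallocation.
  int SlotFor(int64_t key) {
    std::lock_guard<std::mutex> lock(mutex_);  // same role as auto_grown_mutex_
    for (size_t i = 0; i < keys_.size(); ++i) {
      if (keys_[i] == key) return static_cast<int>(i);
    }
    keys_.push_back(key);  // the "auto-grow" path the lock protects
    return static_cast<int>(keys_.size()) - 1;
  }

 private:
  std::mutex mutex_;
  std::vector<int64_t> keys_;
};
```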

@@ -131,6 +131,20 @@ void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
   return buffer(name).buffer;
 }
+void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst,
+                                    size_t max_size) {
+  // determine data size
+  auto it = buffer_sizes_.find(name);
+  PADDLE_ENFORCE(it != buffer_sizes_.end());
+  PADDLE_ENFORCE_GT(it->second, 0);
+  PADDLE_ENFORCE_GE(max_size, it->second);
+  auto& buf = buffer(name);
+  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
+  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
+                                    cudaMemcpyDeviceToDevice, *stream_),
+                    0);
+}
 void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
                                     size_t max_size) {
   // determine data size
@@ -152,7 +166,7 @@ Buffer& TensorRTEngine::buffer(const std::string& name) {
   return buffers_[slot_offset];
 }
-void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
+void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data,
                                      size_t size) {
   auto& buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer);
@@ -162,6 +176,16 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
                                   cudaMemcpyHostToDevice, *stream_));
 }
+void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data,
+                                     size_t size) {
+  auto& buf = buffer(name);
+  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
+  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
+  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
+  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
+                                       cudaMemcpyDeviceToDevice, *stream_));
+}
 void TensorRTEngine::SetITensor(const std::string& name,
                                 nvinfer1::ITensor* tensor) {
   PADDLE_ENFORCE(tensor != nullptr);

@@ -92,13 +92,15 @@ class TensorRTEngine : public EngineBase {
   cudaStream_t* stream() { return stream_; }
   // Fill an input from CPU memory with name and size.
-  void SetInputFromCPU(const std::string& name, void* data, size_t size);
+  void SetInputFromCPU(const std::string& name, const void* data, size_t size);
   // TODO(Superjomn) is this method necessary given that buffer(xxx) can be
   // accessed directly. Fill an input from GPU memory with name and size.
-  void SetInputFromGPU(const std::string& name, void* data, size_t size);
+  void SetInputFromGPU(const std::string& name, const void* data, size_t size);
   // Get an output called name, the output of tensorrt is in GPU, so this method
-  // will just return the output's GPU memory address.
+  // Return the output's GPU memory address without copy.
   void* GetOutputInGPU(const std::string& name);
+  // Copy data into dst inside the GPU device.
+  void GetOutputInGPU(const std::string& name, void* dst, size_t max_size);
   // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
   // to CPU.
   void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
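Both new engine methods boil down to a size check followed by an asynchronous device-to-device `cudaMemcpyAsync` on the engine's stream. Here is a standalone sketch of that pattern using the plain CUDA runtime API; the buffer names and the capacity check are illustrative stand-ins for the engine's internal `buffer_sizes_` bookkeeping:

```cpp
#include <cassert>
#include <cuda_runtime.h>

// Copy `size` bytes between two device buffers on `stream`, the way
// SetInputFromGPU / GetOutputInGPU(name, dst, max_size) do after their checks.
void DeviceToDeviceCopy(void* dst, const void* src, size_t size,
                        size_t dst_capacity, cudaStream_t stream) {
  assert(size <= dst_capacity);  // mirrors PADDLE_ENFORCE_GE(max_size, it->second)
  cudaError_t err =
      cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
  assert(err == cudaSuccess);    // the engine asserts the return code is 0
  // Note: the copy is asynchronous; synchronize the stream before reading dst
  // from the host or before reusing src.
}
```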

@@ -168,6 +168,8 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(relu);\n")
   elseif(${TARGET} STREQUAL "reduce")
     file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
+  elseif(${TARGET} STREQUAL "fake_dequantize")
+    file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
   else()
     file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
   endif()
@@ -223,6 +225,11 @@ op_library(cross_entropy_op DEPS cross_entropy)
 op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(softmax_op DEPS softmax)
 op_library(sequence_softmax_op DEPS softmax)
+if (WITH_GPU AND TENSORRT_FOUND)
+  op_library(tensorrt_engine_op DEPS tensorrt_engine)
+else()
+  set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
+endif()
 op_library(sum_op DEPS selected_rows_functor)
 op_library(sgd_op DEPS selected_rows_functor)
 op_library(print_op DEPS lod_tensor)

@@ -89,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
                        ops::CastOpKernel<CPU, int>,
                        ops::CastOpKernel<CPU, int64_t>,
                        ops::CastOpKernel<CPU, bool>,
+                       ops::CastOpKernel<CPU, uint8_t>,
                        ops::CastOpKernel<CPU, paddle::platform::float16>);

@@ -21,5 +21,5 @@ using CastOpKernel =
 REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
                         CastOpKernel<int>, CastOpKernel<int64_t>,
-                        CastOpKernel<bool>,
+                        CastOpKernel<bool>, CastOpKernel<uint8_t>,
                         CastOpKernel<paddle::platform::float16>);

@@ -0,0 +1,76 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include <string>
namespace paddle {
namespace operators {
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
FakeDequantizeMaxAbsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input with float-32/64 type is the "
"low precision tensor.");
AddOutput("Out",
"(Tensor) The output is the dequantized high "
"precision tensor.");
AddAttr<int>("num_bits",
"(int) `num_bits` is the quantization level bits, "
"such as 2, 5, 8.");
AddAttr<float>("scale",
"(float) The maximum absolute value of low precision tensor."
"It is usually calculated by the fake_quantize_max_abs_op.");
AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
ops::FakeDequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CPU, float>,
ops::FakeDequantizeMaxAbsKernel<CPU, double>);
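For a concrete feel of the dequantization formula above, here is a small worked example. It assumes the quantized input came from a companion max-abs quantization step (the quantize-side rounding is an assumption; only the dequantize formula is defined by this operator):

```latex
% num_bits = 8, so 2^{num_bits} - 1 = 255.
% Suppose the original float tensor had maximum absolute value scale = 0.5
% and one quantized (low precision) entry is X = 204. Then
\[
Out \;=\; \frac{scale \cdot X}{2^{num\_bits} - 1}
    \;=\; \frac{0.5 \times 204}{255}
    \;=\; 0.4 .
\]
```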

@@ -0,0 +1,21 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);

@@ -0,0 +1,42 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(in->place());
int num_bits = ctx.Attr<int>("num_bits");
T scale = static_cast<T>(ctx.Attr<float>("scale"));
int range = std::pow(2, num_bits) - 1;
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
eigen_out.device(dev) = (scale / range) * eigen_in;
}
};
} // namespace operators
} // namespace paddle
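The kernel above is just an element-wise multiply by `scale / (2^num_bits - 1)`, written with Eigen so the same code runs under CPU and CUDA device contexts. A plain C++ reference version of that arithmetic, handy for sanity-checking the Eigen path (a sketch, not part of the operator):

```cpp
#include <cmath>
#include <vector>

// Reference implementation of fake_dequantize_max_abs:
// out[i] = scale * in[i] / (2^num_bits - 1), matching the kernel's
// eigen_out = (scale / range) * eigen_in above.
template <typename T>
std::vector<T> FakeDequantizeMaxAbsRef(const std::vector<T>& in, int num_bits,
                                       T scale) {
  const T range = static_cast<T>(std::pow(2, num_bits) - 1);
  std::vector<T> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    out[i] = scale / range * in[i];
  }
  return out;
}
```

With `num_bits = 8` and `scale = 0.5`, an input value of 204 maps to 0.4, matching the worked example after the .cc file above.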

@@ -1,197 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
template <typename Format = mkldnn::memory::format>
mkldnn::memory::desc type(const std::vector<int>& dims, Format&& f) {
return platform::MKLDNNMemDesc(dims, mkldnn::memory::data_type::f32, f);
}
template <typename T>
class MulMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
auto input = ctx.Input<Tensor>("X");
auto weight = ctx.Input<Tensor>("Y");
PADDLE_ENFORCE(input->dims().size() & (2 | 4),
"Input must be with 2 or 4 dimensions, i.e. NC or NCHW");
PADDLE_ENFORCE(weight->dims().size() & (2 | 4),
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
std::vector<int> w_tz = paddle::framework::vectorize2int(weight->dims());
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
auto src_md =
src_tz.size() != 2
? type(src_tz, mkldnn::memory::format::nchw)
: type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);
auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);
auto weights_md =
src_tz.size() != 2
? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
mkldnn::memory::format::oihw)
: type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);
auto output = ctx.Output<Tensor>("Out");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const std::string key = ctx.op().Output("Out");
const std::string key_fc_pd = key + "@mul_pd";
const T* input_data = input->data<T>();
const T* w_data = weight->data<T>();
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
platform::to_void_cast(input_data));
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
platform::to_void_cast(w_data));
auto pd = platform::MKLDNNFwdPrimitiveDesc<mkldnn::inner_product_forward>(
mkldnn_engine, src_md, weights_md, dst_md);
dev_ctx.SetBlob(key_fc_pd, pd);
auto forward = mkldnn::inner_product_forward(*pd, src_memory,
weights_memory, dst_memory);
std::vector<mkldnn::primitive> pipeline = {forward};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
};
template <typename T>
class MulMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
const Tensor* w = ctx.Input<Tensor>("Y");
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
const std::string key = ctx.op().Input("Out");
const std::string key_fc_pd = key + "@mul_pd";
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
const T* out_grad_data = out_grad->data<T>();
T* input_grad_data = nullptr;
T* w_grad_data = nullptr;
if (input_grad) {
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
}
if (w_grad) {
w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
}
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> w_tz = paddle::framework::vectorize2int(w->dims());
auto src_md =
src_tz.size() != 2
? type(src_tz, mkldnn::memory::format::nchw)
: type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);
auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);
auto weights_md =
src_tz.size() != 2
? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
mkldnn::memory::format::oihw)
: type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
platform::to_void_cast(input_data));
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine},
platform::to_void_cast(out_grad_data));
auto weight_memory = mkldnn::memory({weights_md, mkldnn_engine},
platform::to_void_cast(w_data));
auto pd =
std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
dev_ctx.GetBlob(key_fc_pd));
PADDLE_ENFORCE(pd != nullptr, "Fail to find pd in device context");
if (w_grad) {
auto weights_grad_memory = mkldnn::memory(
{weights_md, mkldnn_engine}, platform::to_void_cast(w_grad_data));
auto bwd_weight_pd = platform::MKLDNNBwdPrimitiveDesc<
mkldnn::inner_product_backward_weights>(mkldnn_engine, *pd, src_md,
weights_md, dst_md);
auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
bwd_weight_pd, src_memory, dst_memory, weights_grad_memory);
std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
if (input_grad) {
auto src_grad_memory = mkldnn::memory(
{src_md, mkldnn_engine}, platform::to_void_cast(input_grad_data));
auto bwd_data_pd =
platform::MKLDNNBwdPrimitiveDesc<mkldnn::inner_product_backward_data>(
mkldnn_engine, *pd, src_md, weights_md, dst_md);
auto bwd_data_prim = mkldnn::inner_product_backward_data(
bwd_data_pd, dst_memory, weight_memory, src_grad_memory);
std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::MulMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::MulMKLDNNGradOpKernel<float>);

@@ -16,10 +16,6 @@ limitations under the License. */
 #include <string>
 #include <vector>
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
 namespace paddle {
 namespace operators {
@@ -76,22 +72,6 @@ class MulOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
-
- private:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        platform::CanMKLDNNBeUsed(ctx)) {
-      library = framework::LibraryType::kMKLDNN;
-    }
-#endif
-    framework::DataLayout layout{framework::DataLayout::kAnyLayout};
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        layout, library);
-  }
 };
 class MulOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -120,9 +100,6 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
 )DOC")
         .SetDefault(1)
         .EqualGreaterThan(1);
-    AddAttr<bool>("use_mkldnn",
-                  "(bool, default false) Only used in mkldnn kernel")
-        .SetDefault(false);
     AddAttr<int>(
         "y_num_col_dims",
         R"DOC((int, default 1), The mul_op can take tensors with more than two,
@@ -177,22 +154,6 @@ class MulGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(y_grad_name, y_dims);
     }
   }
-
- private:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        platform::CanMKLDNNBeUsed(ctx)) {
-      library = framework::LibraryType::kMKLDNN;
-    }
-#endif
-    framework::DataLayout layout{framework::DataLayout::kAnyLayout};
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        layout, library);
-  }
 };
 }  // namespace operators

Some files were not shown because too many files have changed in this diff.