You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							100 lines
						
					
					
						
							4.0 KiB
						
					
					
				
			
		
		
	
	
							100 lines
						
					
					
						
							4.0 KiB
						
					
					
				/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License. */
 | 
						|
#include <fcntl.h>
 | 
						|
 | 
						|
#ifdef _POSIX_C_SOURCE
 | 
						|
#undef _POSIX_C_SOURCE
 | 
						|
#endif
 | 
						|
 | 
						|
#ifdef _XOPEN_SOURCE
 | 
						|
#undef _XOPEN_SOURCE
 | 
						|
#endif
 | 
						|
 | 
						|
#include <memory>
 | 
						|
#include <string>
 | 
						|
#include <vector>
 | 
						|
 | 
						|
#include "paddle/fluid/framework/data_feed.h"
 | 
						|
#include "paddle/fluid/framework/data_feed.pb.h"
 | 
						|
#include "paddle/fluid/framework/fleet/box_wrapper.h"
 | 
						|
#include "paddle/fluid/pybind/box_helper_py.h"
 | 
						|
#ifdef PADDLE_WITH_BOX_PS
 | 
						|
#include <boxps_public.h>
 | 
						|
#endif
 | 
						|
 | 
						|
namespace py = pybind11;
 | 
						|
 | 
						|
namespace paddle {
 | 
						|
namespace pybind {
 | 
						|
// Registers framework::BoxHelper with the given Python module under the
// Python class name "BoxPS". Instances are held by std::shared_ptr and are
// constructed from a Dataset pointer. Every bound method runs with the GIL
// released (py::gil_scoped_release) for the duration of the C++ call.
void BindBoxHelper(py::module* m) {
  using Helper = framework::BoxHelper;
  // Guard type attached to each method binding: drops the GIL while the
  // underlying C++ member function executes.
  using ReleaseGil = py::call_guard<py::gil_scoped_release>;

  py::class_<Helper, std::shared_ptr<Helper>> box_ps(*m, "BoxPS");

  // Python-side constructor: BoxPS(dataset).
  // NOTE(review): the raw Dataset pointer is forwarded unchanged; whether
  // BoxHelper keeps it after construction is not visible here, so the Python
  // dataset object may need to outlive this instance — confirm.
  box_ps.def(py::init([](paddle::framework::Dataset* dataset) {
    return std::make_shared<paddle::framework::BoxHelper>(dataset);
  }));

  box_ps.def("set_date", &Helper::SetDate, ReleaseGil());
  box_ps.def("begin_pass", &Helper::BeginPass, ReleaseGil());
  box_ps.def("end_pass", &Helper::EndPass, ReleaseGil());
  box_ps.def("wait_feed_pass_done", &Helper::WaitFeedPassDone, ReleaseGil());
  box_ps.def("preload_into_memory", &Helper::PreLoadIntoMemory, ReleaseGil());
  box_ps.def("load_into_memory", &Helper::LoadIntoMemory, ReleaseGil());
  box_ps.def("slots_shuffle", &Helper::SlotsShuffle, ReleaseGil());
}  // end BoxHelper
 | 
						|
 | 
						|
#ifdef PADDLE_WITH_BOX_PS
 | 
						|
// Registers framework::BoxWrapper with the given Python module under the
// Python class name "BoxWrapper". Instances are held by std::shared_ptr.
// Constructing the Python object calls
// BoxWrapper::SetInstance(embedx_dim, expand_embed_dim) and uses its return
// value as the instance.
// NOTE(review): presumably SetInstance creates/returns a process-wide
// singleton as a std::shared_ptr<BoxWrapper> — confirm in box_wrapper.h.
// Every bound method runs with the GIL released for the duration of the
// C++ call.
void BindBoxWrapper(py::module* m) {
  using Wrapper = framework::BoxWrapper;
  // Guard type attached to each method binding: drops the GIL while the
  // underlying C++ member function executes.
  using ReleaseGil = py::call_guard<py::gil_scoped_release>;

  py::class_<Wrapper, std::shared_ptr<Wrapper>> box_wrapper(*m, "BoxWrapper");

  // Python-side constructor: BoxWrapper(embedx_dim, expand_embed_dim).
  box_wrapper.def(py::init([](int embedx_dim, int expand_embed_dim) {
    return framework::BoxWrapper::SetInstance(embedx_dim, expand_embed_dim);
  }));

  box_wrapper.def("save_base", &Wrapper::SaveBase, ReleaseGil());
  box_wrapper.def("feed_pass", &Wrapper::FeedPass, ReleaseGil());
  box_wrapper.def("set_test_mode", &Wrapper::SetTestMode, ReleaseGil());
  box_wrapper.def("save_delta", &Wrapper::SaveDelta, ReleaseGil());
  box_wrapper.def("initialize_gpu_and_load_model",
                  &Wrapper::InitializeGPUAndLoadModel, ReleaseGil());
  box_wrapper.def("initialize_auc_runner", &Wrapper::InitializeAucRunner,
                  ReleaseGil());
  box_wrapper.def("init_metric", &Wrapper::InitMetric, ReleaseGil());
  box_wrapper.def("get_metric_msg", &Wrapper::GetMetricMsg, ReleaseGil());
  box_wrapper.def("get_metric_name_list", &Wrapper::GetMetricNameList,
                  ReleaseGil());
  box_wrapper.def("flip_phase", &Wrapper::FlipPhase, ReleaseGil());
  box_wrapper.def("init_afs_api", &Wrapper::InitAfsAPI, ReleaseGil());
  box_wrapper.def("finalize", &Wrapper::Finalize, ReleaseGil());
}  // end BoxWrapper
 | 
						|
#endif
 | 
						|
 | 
						|
}  // end namespace pybind
 | 
						|
}  // end namespace paddle
 |