You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							100 lines
						
					
					
						
							4.0 KiB
						
					
					
				
			
		
		
	
	
							100 lines
						
					
					
						
							4.0 KiB
						
					
					
				| /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
| http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License. */
 | |
| #include <fcntl.h>
 | |
| 
 | |
| #ifdef _POSIX_C_SOURCE
 | |
| #undef _POSIX_C_SOURCE
 | |
| #endif
 | |
| 
 | |
| #ifdef _XOPEN_SOURCE
 | |
| #undef _XOPEN_SOURCE
 | |
| #endif
 | |
| 
 | |
| #include <memory>
 | |
| #include <string>
 | |
| #include <vector>
 | |
| 
 | |
| #include "paddle/fluid/framework/data_feed.h"
 | |
| #include "paddle/fluid/framework/data_feed.pb.h"
 | |
| #include "paddle/fluid/framework/fleet/box_wrapper.h"
 | |
| #include "paddle/fluid/pybind/box_helper_py.h"
 | |
| #ifdef PADDLE_WITH_BOX_PS
 | |
| #include <boxps_public.h>
 | |
| #endif
 | |
| 
 | |
| namespace py = pybind11;
 | |
| 
 | |
| namespace paddle {
 | |
| namespace pybind {
 | |
| void BindBoxHelper(py::module* m) {
 | |
|   py::class_<framework::BoxHelper, std::shared_ptr<framework::BoxHelper>>(
 | |
|       *m, "BoxPS")
 | |
|       .def(py::init([](paddle::framework::Dataset* dataset) {
 | |
|         return std::make_shared<paddle::framework::BoxHelper>(dataset);
 | |
|       }))
 | |
|       .def("set_date", &framework::BoxHelper::SetDate,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("begin_pass", &framework::BoxHelper::BeginPass,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("end_pass", &framework::BoxHelper::EndPass,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("wait_feed_pass_done", &framework::BoxHelper::WaitFeedPassDone,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("load_into_memory", &framework::BoxHelper::LoadIntoMemory,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("slots_shuffle", &framework::BoxHelper::SlotsShuffle,
 | |
|            py::call_guard<py::gil_scoped_release>());
 | |
| }  // end BoxHelper
 | |
| 
 | |
| #ifdef PADDLE_WITH_BOX_PS
 | |
| void BindBoxWrapper(py::module* m) {
 | |
|   py::class_<framework::BoxWrapper, std::shared_ptr<framework::BoxWrapper>>(
 | |
|       *m, "BoxWrapper")
 | |
|       .def(py::init([](int embedx_dim, int expand_embed_dim) {
 | |
|         // return std::make_shared<paddle::framework::BoxHelper>(dataset);
 | |
|         return framework::BoxWrapper::SetInstance(embedx_dim, expand_embed_dim);
 | |
|       }))
 | |
|       .def("save_base", &framework::BoxWrapper::SaveBase,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("feed_pass", &framework::BoxWrapper::FeedPass,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("set_test_mode", &framework::BoxWrapper::SetTestMode,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("save_delta", &framework::BoxWrapper::SaveDelta,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("initialize_gpu_and_load_model",
 | |
|            &framework::BoxWrapper::InitializeGPUAndLoadModel,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("initialize_auc_runner", &framework::BoxWrapper::InitializeAucRunner,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("init_metric", &framework::BoxWrapper::InitMetric,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("flip_phase", &framework::BoxWrapper::FlipPhase,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("init_afs_api", &framework::BoxWrapper::InitAfsAPI,
 | |
|            py::call_guard<py::gil_scoped_release>())
 | |
|       .def("finalize", &framework::BoxWrapper::Finalize,
 | |
|            py::call_guard<py::gil_scoped_release>());
 | |
| }  // end BoxWrapper
 | |
| #endif
 | |
| 
 | |
| }  // end namespace pybind
 | |
| }  // end namespace paddle
 |