|
|
|
@ -20,6 +20,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
|
#include "paddle/platform/place.h"
|
|
|
|
|
#include "paddle/pybind/tensor_bind.h"
|
|
|
|
|
#include "pybind11/numpy.h"
|
|
|
|
|
#include "pybind11/pybind11.h"
|
|
|
|
@ -62,12 +63,12 @@ PYBIND11_PLUGIN(core) {
|
|
|
|
|
self.Resize(pd::make_ddim(dim));
|
|
|
|
|
})
|
|
|
|
|
.def("alloc_float",
|
|
|
|
|
[](pd::Tensor& self) {
|
|
|
|
|
self.mutable_data<float>(paddle::platform::CPUPlace());
|
|
|
|
|
[](pd::Tensor& self, paddle::platform::Place& place) {
|
|
|
|
|
self.mutable_data<float>(place);
|
|
|
|
|
})
|
|
|
|
|
.def("alloc_int",
|
|
|
|
|
[](pd::Tensor& self) {
|
|
|
|
|
self.mutable_data<int>(paddle::platform::CPUPlace());
|
|
|
|
|
[](pd::Tensor& self, paddle::platform::Place& place) {
|
|
|
|
|
self.mutable_data<int>(place);
|
|
|
|
|
})
|
|
|
|
|
.def("set", paddle::pybind::PyTensorSetFromArray<float>)
|
|
|
|
|
.def("set", paddle::pybind::PyTensorSetFromArray<int>)
|
|
|
|
@ -122,9 +123,20 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
|
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
|
|
|
|
|
|
|
|
|
|
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
|
|
|
|
|
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
|
|
|
|
|
return new paddle::platform::CPUDeviceContext();
|
|
|
|
|
});
|
|
|
|
|
.def_static(
|
|
|
|
|
"create",
|
|
|
|
|
[](paddle::platform::Place) -> paddle::platform::DeviceContext* {
|
|
|
|
|
if (paddle::platform::is_gpu_place(place)) {
|
|
|
|
|
return new paddle::platform::GPUDeviceContext(place);
|
|
|
|
|
} else if (paddle::platform::is_cpu_place(place)) {
|
|
|
|
|
return new paddle::platform::CPUDeviceContext();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
py::class_<paddle::platform::Place>(m, "GPUPlace").def(py::init<int>());
|
|
|
|
|
.def(py::init<>());
|
|
|
|
|
|
|
|
|
|
py::class_<paddle::platform::Place>(m, "CPUPlace").def(py::init<>());
|
|
|
|
|
|
|
|
|
|
py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
|
|
|
|
|
m, "Operator");
|
|
|
|
|