|
|
|
@ -101,6 +101,7 @@ static void InitTensorForVarBase(imperative::VarBase *self,
|
|
|
|
|
|
|
|
|
|
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
|
|
|
|
|
const py::kwargs &kwargs) {
|
|
|
|
|
VLOG(4) << "Init VarBase";
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
kwargs.contains("value"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
@ -126,6 +127,7 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
|
|
|
|
|
bool persistable = false,
|
|
|
|
|
bool zero_copy = false,
|
|
|
|
|
std::string name = "") {
|
|
|
|
|
VLOG(4) << "Init VarBase";
|
|
|
|
|
// 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name
|
|
|
|
|
if (name == "") {
|
|
|
|
|
name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var");
|
|
|
|
@ -140,10 +142,31 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
|
|
|
|
|
|
|
|
|
|
static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
|
|
|
|
|
const py::array &array) {
|
|
|
|
|
VLOG(4) << "Init VarBase";
|
|
|
|
|
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
|
|
|
|
|
InitTensorForVarBase(self, array, place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void InitVarBaseFromTensorWithArgDefault(
|
|
|
|
|
imperative::VarBase *self, const framework::LoDTensor &tensor) {
|
|
|
|
|
VLOG(4) << "Init VarBase";
|
|
|
|
|
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
|
|
|
|
|
new (self) imperative::VarBase(
|
|
|
|
|
imperative::GetCurrentTracer()->GenerateUniqueName("generated_var"));
|
|
|
|
|
self->SetPersistable(false);
|
|
|
|
|
self->SetType(framework::proto::VarType::LOD_TENSOR);
|
|
|
|
|
self->SetDataType(tensor.type());
|
|
|
|
|
auto *new_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
|
|
|
|
|
// Same place,share data directly
|
|
|
|
|
if (place == tensor.place()) {
|
|
|
|
|
new_tensor->ShareDataWith(tensor);
|
|
|
|
|
VLOG(4) << "Same place, do ShareDataWith";
|
|
|
|
|
} else {
|
|
|
|
|
framework::TensorCopy(tensor, place, new_tensor);
|
|
|
|
|
VLOG(4) << "Different place, do TensorCopy";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string GetTypeName(const imperative::VarBase &var) {
|
|
|
|
|
if (var.Type() == framework::proto::VarType::RAW) {
|
|
|
|
|
return "RAW";
|
|
|
|
@ -520,6 +543,7 @@ void BindImperative(py::module *m_ptr) {
|
|
|
|
|
[](imperative::VarBase &self, framework::proto::VarType::Type dtype,
|
|
|
|
|
const std::vector<int> &dims, const py::handle &name,
|
|
|
|
|
framework::proto::VarType::Type type, bool persistable) {
|
|
|
|
|
VLOG(4) << "Init VarBase";
|
|
|
|
|
std::string act_name = "";
|
|
|
|
|
if (!name.ptr() || name.ptr() == Py_None) {
|
|
|
|
|
act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
|
|
|
|
@ -547,6 +571,7 @@ void BindImperative(py::module *m_ptr) {
|
|
|
|
|
py::arg("value"), py::arg("place"), py::arg("persistable") = false,
|
|
|
|
|
py::arg("zero_copy") = false, py::arg("name") = "")
|
|
|
|
|
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
|
|
|
|
|
.def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
|
|
|
|
|
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
|
|
|
|
|
.def("__getitem__",
|
|
|
|
|
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
|
|
|
|
|