implment pause in MapOp, added more to callback add ds_callback - Initial drop of Python DSCallback - Pybind DSCallback - Pybind DSCallback added callback to mapOp - de_pipeline DSCallback - de_pipeline DSCallback add test case, segfault for now fix seg fault - de_pipeline DSCallback remove 1 line update callback test case, now works use builder class for mapOp callback - de_pipeline DSCallback - de_pipeline DSCallback - de_pipeline DSCallback better test case minor fix add comments and minor clean ups get rid of nullptr in MapOp, use other flag instead fix a bug ParseMapOp only takes 1 callback - Added WaitedDSCalabck refactor callback param fix text case incorrect number - added testing fix cpp test case - added testing - revert back lenet changes - cleanup test_callbacks.py - cleanup test_callbacks.py fix CI stage I fix CI stage II fix CI and update epoch counter - add validation - add more testing test_callbacks.py use random data op to do tests adjust when to call EpochBegin/End - add repeat with callback - addressing reviewers' comments - docstring and CI fixes - docstring and CI fixes - docstring and CI fixes - rebase with upstream/master fix cpp test case fix review comments addr review cmts, add test casepull/4632/head
parent
89cd465268
commit
78c1aa1d96
@ -0,0 +1,45 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/stl_bind.h"
|
||||||
|
|
||||||
|
#include "minddata/dataset/api/python/pybind_register.h"
|
||||||
|
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||||
|
#include "minddata/dataset/callback/ds_callback.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) {
|
||||||
|
(void)py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>(*m, "PyDSCallback")
|
||||||
|
.def(py::init<int32_t>())
|
||||||
|
.def("set_begin", &PyDSCallback::setBegin)
|
||||||
|
.def("set_end", &PyDSCallback::setEnd)
|
||||||
|
.def("set_epoch_begin", &PyDSCallback::setEpochBegin)
|
||||||
|
.def("set_epoch_end", &PyDSCallback::setEpochEnd)
|
||||||
|
.def("set_step_begin", &PyDSCallback::setStepBegin)
|
||||||
|
.def("set_step_end", &PyDSCallback::setStepEnd);
|
||||||
|
}));
|
||||||
|
|
||||||
|
PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) {
|
||||||
|
(void)py::class_<CallbackParam, std::shared_ptr<CallbackParam>>(*m, "CallbackParam")
|
||||||
|
.def(py::init<int64_t, int64_t, int64_t>())
|
||||||
|
.def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_)
|
||||||
|
.def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_)
|
||||||
|
.def_readonly("cur_step_num", &CallbackParam::cur_step_num_);
|
||||||
|
}));
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,14 @@
|
|||||||
|
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
|
|
||||||
|
|
||||||
|
if (ENABLE_PYTHON)
|
||||||
|
add_library(callback OBJECT
|
||||||
|
callback_manager.cc
|
||||||
|
py_ds_callback.cc
|
||||||
|
)
|
||||||
|
else ()
|
||||||
|
add_library(callback OBJECT
|
||||||
|
callback_manager.cc
|
||||||
|
)
|
||||||
|
endif ()
|
@ -0,0 +1,160 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/callback/callback_manager.h"
|
||||||
|
#include "minddata/dataset/callback/ds_callback.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
|
||||||
|
callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(op);
|
||||||
|
op_ = op;
|
||||||
|
// turn the flag on if callback is set
|
||||||
|
enabled_ = !callbacks_.empty();
|
||||||
|
|
||||||
|
// error check for each of the callbacks
|
||||||
|
for (auto &cb : callbacks_) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CallbackManager::Begin(const CallbackParam &cb_param) {
|
||||||
|
RETURN_OK_IF_TRUE(!enabled_);
|
||||||
|
std::vector<size_t> callback_inds;
|
||||||
|
// go through all callback functions to see if each function is needed
|
||||||
|
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||||
|
if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
|
||||||
|
}
|
||||||
|
// return Status::OK() if no begin is needed
|
||||||
|
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||||
|
|
||||||
|
// Now do the actual callback
|
||||||
|
for (size_t ind : callback_inds) {
|
||||||
|
RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
|
||||||
|
RETURN_OK_IF_TRUE(!enabled_);
|
||||||
|
std::vector<size_t> callback_inds;
|
||||||
|
// go through all callback functions to see if each function is needed
|
||||||
|
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||||
|
if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
|
||||||
|
}
|
||||||
|
// return Status::OK() if no epoch_begin is needed
|
||||||
|
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||||
|
|
||||||
|
// Now do the actual callback
|
||||||
|
for (size_t ind : callback_inds) {
|
||||||
|
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
|
||||||
|
RETURN_OK_IF_TRUE(!enabled_);
|
||||||
|
std::vector<size_t> callback_inds;
|
||||||
|
// go through all callback functions to see if each function is needed
|
||||||
|
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||||
|
if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
|
||||||
|
callback_inds.push_back(ind);
|
||||||
|
}
|
||||||
|
// return Status::OK() if no step_begin is needed
|
||||||
|
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||||
|
|
||||||
|
// Now do the actual callback
|
||||||
|
for (size_t ind : callback_inds) {
|
||||||
|
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CallbackManager::End(const CallbackParam &cb_param) {
|
||||||
|
RETURN_OK_IF_TRUE(!enabled_);
|
||||||
|
std::vector<size_t> callback_inds;
|
||||||
|
// go through all callback functions to see if each function is needed
|
||||||
|
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||||
|
if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
|
||||||
|
}
|
||||||
|
// return Status::OK() if no end is needed
|
||||||
|
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||||
|
|
||||||
|
// Now do the actual callback
|
||||||
|
for (size_t ind : callback_inds) {
|
||||||
|
RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
|
||||||
|
RETURN_OK_IF_TRUE(!enabled_);
|
||||||
|
std::vector<size_t> callback_inds;
|
||||||
|
// go through all callback functions to see if each function is needed
|
||||||
|
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||||
|
if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
|
||||||
|
}
|
||||||
|
// return Status::OK() if no epoch_end is needed
|
||||||
|
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||||
|
|
||||||
|
// Now do the actual callback
|
||||||
|
for (size_t ind : callback_inds) {
|
||||||
|
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
|
||||||
|
RETURN_OK_IF_TRUE(!enabled_);
|
||||||
|
std::vector<size_t> callback_inds;
|
||||||
|
// go through all callback functions to see if each function is needed
|
||||||
|
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
|
||||||
|
if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
|
||||||
|
callback_inds.push_back(ind);
|
||||||
|
}
|
||||||
|
// return Status::OK() if no step_end is needed
|
||||||
|
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||||
|
|
||||||
|
// Now do the actual callback
|
||||||
|
for (size_t ind : callback_inds) {
|
||||||
|
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,79 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/callback/ds_callback.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
// forward declare to avoid cyclic include of dataset_op.h
|
||||||
|
class DatasetOp;
|
||||||
|
|
||||||
|
/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
|
||||||
|
class CallbackManager {
|
||||||
|
public:
|
||||||
|
/// CallbackManager default constructor. Init needs to be called before using the created instance.
|
||||||
|
CallbackManager() : enabled_(false) {}
|
||||||
|
|
||||||
|
/// \brief
|
||||||
|
/// \param [in] callbacks list of callbacks to perform
|
||||||
|
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);
|
||||||
|
|
||||||
|
/// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true
|
||||||
|
/// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads
|
||||||
|
/// \return Status
|
||||||
|
Status Init(std::shared_ptr<DatasetOp> op);
|
||||||
|
|
||||||
|
/// \brief callback function called at the start of the first row
|
||||||
|
/// \return Status
|
||||||
|
Status Begin(const CallbackParam &);
|
||||||
|
|
||||||
|
/// \brief callback function called at the start of each epoch
|
||||||
|
/// \return Status
|
||||||
|
Status EpochBegin(const CallbackParam &);
|
||||||
|
|
||||||
|
/// \brief callback function called at the start of each row
|
||||||
|
/// \return Status
|
||||||
|
Status StepBegin(const CallbackParam &);
|
||||||
|
|
||||||
|
/// \brief callback function called after the last row is processed
|
||||||
|
/// \return Status
|
||||||
|
Status End(const CallbackParam &);
|
||||||
|
|
||||||
|
/// \brief callback function called at the end of each epoch
|
||||||
|
/// \return Status
|
||||||
|
Status EpochEnd(const CallbackParam &);
|
||||||
|
|
||||||
|
/// \brief callback function called at the the end of each row
|
||||||
|
/// \return Status
|
||||||
|
Status StepEnd(const CallbackParam &);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool enabled_; // flag to enable callback, if false, all functions would return immediately
|
||||||
|
std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager
|
||||||
|
std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
|
@ -0,0 +1,40 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
/// Callback Param is the object a DatasetOp uses to pass run-time information to user defined function.
|
||||||
|
/// This is a prototype for now, more fields will be added
|
||||||
|
class CallbackParam {
|
||||||
|
public:
|
||||||
|
CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num)
|
||||||
|
: cur_epoch_num_(epoch_num), cur_epoch_step_num_(cur_epoch_step), cur_step_num_(total_step_num) {}
|
||||||
|
|
||||||
|
// these are constant public fields for easy access and consistency with python cb_param
|
||||||
|
// the names and orders are consistent with batchInfo
|
||||||
|
const int64_t cur_epoch_num_; // current epoch
|
||||||
|
const int64_t cur_epoch_step_num_; // step number of the current epoch
|
||||||
|
const int64_t cur_step_num_; // step number since the first row
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
|
@ -0,0 +1,100 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/callback/callback_param.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class DSCallback {
|
||||||
|
public:
|
||||||
|
/// \brief constructor of DSCallback, this is the base class for all front end specific callbacks
|
||||||
|
/// \param step_size number of steps to call DSNStepBegin()
|
||||||
|
explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}
|
||||||
|
|
||||||
|
/// \brief actual callback function for begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
virtual Status DSBegin(const CallbackParam &cb_param) = 0;
|
||||||
|
|
||||||
|
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0;
|
||||||
|
|
||||||
|
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0;
|
||||||
|
|
||||||
|
/// \brief actual callback function for end, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
virtual Status DSEnd(const CallbackParam &cb_param) = 0;
|
||||||
|
|
||||||
|
/// \brief actual callback function epoch_end begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0;
|
||||||
|
|
||||||
|
/// \brief actual callback function for step_end, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether begin callback is needed
|
||||||
|
/// \return bool
|
||||||
|
virtual bool IsBeginNeeded() = 0;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether epoch_begin callback is needed
|
||||||
|
/// \return bool
|
||||||
|
virtual bool IsEpochBeginNeeded() = 0;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether step_begin callback is needed
|
||||||
|
/// \return bool
|
||||||
|
virtual bool IsNStepBeginNeeded() = 0;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether end callback is needed
|
||||||
|
/// \return bool
|
||||||
|
virtual bool IsEndNeeded() = 0;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether epoch_end callback is needed
|
||||||
|
/// \return bool
|
||||||
|
virtual bool IsEpochEndNeeded() = 0;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether step_end callback is needed
|
||||||
|
/// \return bool
|
||||||
|
virtual bool IsNStepEndNeeded() = 0;
|
||||||
|
|
||||||
|
/// \brief getter
|
||||||
|
/// \return step_size
|
||||||
|
int32_t step_size() const { return step_size_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int32_t step_size_; // step begin/end will be called every step_size_
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
|
@ -0,0 +1,86 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/callback/callback_manager.h"
|
||||||
|
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
Status PyDSCallback::DSBegin(const CallbackParam &cb_param) {
|
||||||
|
return PyDSCallback::ExecutePyfunc(begin_func_, cb_param);
|
||||||
|
}
|
||||||
|
Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) {
|
||||||
|
return PyDSCallback::ExecutePyfunc(epoch_begin_func_, cb_param);
|
||||||
|
}
|
||||||
|
Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) {
|
||||||
|
return PyDSCallback::ExecutePyfunc(step_begin_func_, cb_param);
|
||||||
|
}
|
||||||
|
Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return PyDSCallback::ExecutePyfunc(end_func_, cb_param); }
|
||||||
|
|
||||||
|
Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) {
|
||||||
|
return PyDSCallback::ExecutePyfunc(epoch_end_func_, cb_param);
|
||||||
|
}
|
||||||
|
Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) {
|
||||||
|
return PyDSCallback::ExecutePyfunc(step_end_func_, cb_param);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PyDSCallback::IsBeginNeeded() { return begin_needed_; }
|
||||||
|
bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; }
|
||||||
|
bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; }
|
||||||
|
bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; }
|
||||||
|
bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; }
|
||||||
|
bool PyDSCallback::IsEndNeeded() { return end_needed_; }
|
||||||
|
|
||||||
|
Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) {
|
||||||
|
{
|
||||||
|
// Acquire Python GIL
|
||||||
|
py::gil_scoped_acquire gil_acquire;
|
||||||
|
if (Py_IsInitialized() == 0) {
|
||||||
|
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
|
||||||
|
}
|
||||||
|
f(cb_param);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
void PyDSCallback::setBegin(py::function f) {
|
||||||
|
begin_func_ = f;
|
||||||
|
begin_needed_ = true;
|
||||||
|
}
|
||||||
|
void PyDSCallback::setEnd(py::function f) {
|
||||||
|
end_func_ = f;
|
||||||
|
end_needed_ = true;
|
||||||
|
}
|
||||||
|
void PyDSCallback::setEpochBegin(py::function f) {
|
||||||
|
epoch_begin_func_ = f;
|
||||||
|
epoch_begin_needed_ = true;
|
||||||
|
}
|
||||||
|
void PyDSCallback::setEpochEnd(py::function f) {
|
||||||
|
epoch_end_func_ = f;
|
||||||
|
epoch_end_needed_ = true;
|
||||||
|
}
|
||||||
|
void PyDSCallback::setStepBegin(py::function f) {
|
||||||
|
step_begin_func_ = f;
|
||||||
|
step_begin_needed_ = true;
|
||||||
|
}
|
||||||
|
void PyDSCallback::setStepEnd(py::function f) {
|
||||||
|
step_end_func_ = f;
|
||||||
|
step_end_needed_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,130 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/callback/ds_callback.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
class PyDSCallback : public DSCallback {
|
||||||
|
public:
|
||||||
|
/// \brief constructor for PyDSCallback. This callback is for python front end
|
||||||
|
explicit PyDSCallback(int32_t step_size = 1)
|
||||||
|
: DSCallback(step_size),
|
||||||
|
begin_needed_(false),
|
||||||
|
epoch_begin_needed_(false),
|
||||||
|
step_begin_needed_(false),
|
||||||
|
end_needed_(false),
|
||||||
|
epoch_end_needed_(false),
|
||||||
|
step_end_needed_(false) {}
|
||||||
|
|
||||||
|
void setBegin(py::function f);
|
||||||
|
void setEnd(py::function f);
|
||||||
|
void setEpochBegin(py::function f);
|
||||||
|
void setEpochEnd(py::function f);
|
||||||
|
void setStepBegin(py::function f);
|
||||||
|
void setStepEnd(py::function f);
|
||||||
|
|
||||||
|
/// \brief actual callback function for begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
Status DSBegin(const CallbackParam &cb_param) override;
|
||||||
|
|
||||||
|
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
Status DSEpochBegin(const CallbackParam &cb_param) override;
|
||||||
|
|
||||||
|
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
Status DSNStepBegin(const CallbackParam &cb_param) override;
|
||||||
|
|
||||||
|
/// \brief actual callback function for end, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
Status DSEnd(const CallbackParam &cb_param) override;
|
||||||
|
|
||||||
|
/// \brief actual callback function epoch_end begin, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
Status DSEpochEnd(const CallbackParam &cb_param) override;
|
||||||
|
|
||||||
|
/// \brief actual callback function for step_end, needs to be overridden in the derived class
|
||||||
|
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||||
|
/// \return Status
|
||||||
|
Status DSNStepEnd(const CallbackParam &cb_param) override;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether begin callback is needed
|
||||||
|
/// \return bool
|
||||||
|
bool IsBeginNeeded() override;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether epoch_begin callback is needed
|
||||||
|
/// \return bool
|
||||||
|
bool IsEpochBeginNeeded() override;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether step_begin callback is needed
|
||||||
|
/// \return bool
|
||||||
|
bool IsNStepBeginNeeded() override;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether end callback is needed
|
||||||
|
/// \return bool
|
||||||
|
bool IsEndNeeded() override;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether epoch_end callback is needed
|
||||||
|
/// \return bool
|
||||||
|
bool IsEpochEndNeeded() override;
|
||||||
|
|
||||||
|
/// \brief predicate function, whether step_end callback is needed
|
||||||
|
/// \return bool
|
||||||
|
bool IsNStepEndNeeded() override;
|
||||||
|
|
||||||
|
/// \brief helper function to acquire GIL then execute a pyfunc
|
||||||
|
/// \param f the python function
|
||||||
|
/// \param cb_param
|
||||||
|
/// \return Status
|
||||||
|
static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param);
|
||||||
|
|
||||||
|
private:
|
||||||
|
py::function begin_func_;
|
||||||
|
py::function epoch_begin_func_;
|
||||||
|
py::function step_begin_func_;
|
||||||
|
py::function end_func_;
|
||||||
|
py::function epoch_end_func_;
|
||||||
|
py::function step_end_func_;
|
||||||
|
|
||||||
|
bool begin_needed_;
|
||||||
|
bool epoch_begin_needed_;
|
||||||
|
bool step_begin_needed_;
|
||||||
|
bool end_needed_;
|
||||||
|
bool epoch_end_needed_;
|
||||||
|
bool step_end_needed_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
|
@ -0,0 +1,18 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""init file for python callback"""
|
||||||
|
from .ds_callback import DSCallback, WaitedDSCallback
|
||||||
|
|
||||||
|
__all__ = ["DSCallback", "WaitedDSCallback"]
|
@ -0,0 +1,232 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Python callback class
|
||||||
|
"""
|
||||||
|
import threading
|
||||||
|
from mindspore._c_dataengine import PyDSCallback
|
||||||
|
from mindspore.train.callback import Callback
|
||||||
|
from .validators import check_callback
|
||||||
|
|
||||||
|
|
||||||
|
class DSCallback:
|
||||||
|
"""
|
||||||
|
Abstract base class used to build a dataset callback class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step_size (int, optional): The number of steps before the step_begin and step_end are called (Default=1).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class PrintInfo(DSCallback):
|
||||||
|
>>> def ds_epoch_end(self, ds_run_context):
|
||||||
|
>>> print(cb_params.cur_epoch_num)
|
||||||
|
>>> print(cb_params.cur_step_num)
|
||||||
|
>>>
|
||||||
|
>>> data = data.map(operations=op, callbacks=PrintInfo())
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_callback
|
||||||
|
def __init__(self, step_size=1):
|
||||||
|
self.step_size = step_size
|
||||||
|
|
||||||
|
def ds_begin(self, ds_run_context):
|
||||||
|
"""
|
||||||
|
Called before the data pipeline is started.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_run_context (RunContext): Include some information of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def ds_epoch_begin(self, ds_run_context):
|
||||||
|
"""
|
||||||
|
Called before a new epoch is started.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_run_context (RunContext): Include some information of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def ds_epoch_end(self, ds_run_context):
|
||||||
|
"""
|
||||||
|
Called after an epoch is finished.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_run_context (RunContext): Include some information of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def ds_step_begin(self, ds_run_context):
|
||||||
|
"""
|
||||||
|
Called before n steps are started.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_run_context (RunContext): Include some information of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def ds_step_end(self, ds_run_context):
|
||||||
|
"""
|
||||||
|
Called after n steps are finished.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_run_context (RunContext): Include some information of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def create_runtime_obj(self):
|
||||||
|
"""
|
||||||
|
Creates a runtime (C++) object from the callback methods defined by the user.
|
||||||
|
|
||||||
|
Returns: _c_dataengine.PyDSCallback
|
||||||
|
"""
|
||||||
|
c_cb = PyDSCallback(self.step_size)
|
||||||
|
at_least_one = False
|
||||||
|
|
||||||
|
if self.__class__.ds_begin != DSCallback.ds_begin:
|
||||||
|
c_cb.set_begin(self.ds_begin)
|
||||||
|
at_least_one = True
|
||||||
|
|
||||||
|
if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
|
||||||
|
c_cb.set_epoch_begin(self.ds_epoch_begin)
|
||||||
|
at_least_one = True
|
||||||
|
if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
|
||||||
|
c_cb.set_epoch_end(self.ds_epoch_end)
|
||||||
|
at_least_one = True
|
||||||
|
|
||||||
|
if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
|
||||||
|
c_cb.set_step_begin(self.ds_step_begin)
|
||||||
|
at_least_one = True
|
||||||
|
if self.__class__.ds_step_end != DSCallback.ds_step_end:
|
||||||
|
c_cb.set_step_end(self.ds_step_end)
|
||||||
|
at_least_one = True
|
||||||
|
|
||||||
|
if not at_least_one:
|
||||||
|
raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")
|
||||||
|
|
||||||
|
return c_cb
|
||||||
|
|
||||||
|
|
||||||
|
class WaitedDSCallback(Callback, DSCallback):
|
||||||
|
"""
|
||||||
|
Abstract base class used to build a dataset callback class that are synchronized with the training callback.
|
||||||
|
|
||||||
|
This class can be used to execute a user defined logic right after the previous step or epoch.
|
||||||
|
For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> my_cb = MyWaitedCallback(32)
|
||||||
|
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
|
||||||
|
>>> data = data.batch(32)
|
||||||
|
>>> # define the model
|
||||||
|
>>> model.train(epochs, data, callbacks=[my_cb])
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step_size: the number of rows in each step.
|
||||||
|
Usually the step size will be equal to the batch size (Default=1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, step_size=1):
|
||||||
|
super().__init__()
|
||||||
|
self.step_size = step_size
|
||||||
|
self.step_event = threading.Event()
|
||||||
|
self.step_run_context = None
|
||||||
|
|
||||||
|
self.epoch_event = threading.Event()
|
||||||
|
self.epoch_run_context = None
|
||||||
|
|
||||||
|
def sync_epoch_begin(self, train_run_context, ds_run_context):
|
||||||
|
"""
|
||||||
|
Called before a new dataset epoch is started and after the previous training epoch is ended.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_run_context: Include some information of the model with feedback from the previous epoch.
|
||||||
|
ds_run_context: Include some information of the dataset pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sync_step_begin(self, train_run_context, ds_run_context):
|
||||||
|
"""
|
||||||
|
Called before a new dataset step is started and after the previous training step is ended.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_run_context: Include some information of the model with feedback from the previous step.
|
||||||
|
ds_run_context: Include some information of the dataset pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def epoch_end(self, run_context):
|
||||||
|
"""
|
||||||
|
Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_context: Include some information of the model.
|
||||||
|
"""
|
||||||
|
self.epoch_run_context = run_context
|
||||||
|
self.epoch_event.set()
|
||||||
|
self.epoch_event.clear()
|
||||||
|
|
||||||
|
def ds_epoch_begin(self, ds_run_context):
|
||||||
|
"""
|
||||||
|
Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_run_context: Include some information of the pipeline.
|
||||||
|
"""
|
||||||
|
if ds_run_context.cur_epoch_num > 1:
|
||||||
|
if self.epoch_run_context is None:
|
||||||
|
self.epoch_event.wait()
|
||||||
|
self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
|
||||||
|
self.epoch_run_context = None
|
||||||
|
|
||||||
|
def step_end(self, run_context):
|
||||||
|
"""
|
||||||
|
Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_context: Include some information of the model.
|
||||||
|
"""
|
||||||
|
self.step_run_context = run_context
|
||||||
|
self.step_event.set()
|
||||||
|
self.step_event.clear()
|
||||||
|
|
||||||
|
def ds_step_begin(self, ds_run_context):
|
||||||
|
"""
|
||||||
|
Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_run_context: Include some information of the pipeline.
|
||||||
|
"""
|
||||||
|
if ds_run_context.cur_step_num > self.step_size:
|
||||||
|
if self.step_run_context is None:
|
||||||
|
self.step_event.wait()
|
||||||
|
self.sync_step_begin(self.step_run_context, ds_run_context)
|
||||||
|
self.step_run_context = None
|
||||||
|
|
||||||
|
def create_runtime_obj(self):
|
||||||
|
"""
|
||||||
|
Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.
|
||||||
|
|
||||||
|
Returns: _c_dataengine.PyDSCallback
|
||||||
|
"""
|
||||||
|
c_cb = PyDSCallback(self.step_size)
|
||||||
|
at_least_one = False
|
||||||
|
|
||||||
|
if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
|
||||||
|
c_cb.set_step_begin(self.ds_step_begin)
|
||||||
|
at_least_one = True
|
||||||
|
|
||||||
|
if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
|
||||||
|
c_cb.set_epoch_begin(self.ds_epoch_begin)
|
||||||
|
at_least_one = True
|
||||||
|
|
||||||
|
if not at_least_one:
|
||||||
|
raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
|
||||||
|
|
||||||
|
return c_cb
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License foNtest_resr the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
Built-in validators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from ..core.validator_helpers import parse_user_args, check_pos_int32
|
||||||
|
|
||||||
|
|
||||||
|
def check_callback(method):
|
||||||
|
"""check the input arguments of DSCallback."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
[step_size], _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
check_pos_int32(step_size, "step_size")
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue