You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
98 lines
2.5 KiB
98 lines
2.5 KiB
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
#pragma once
|
|
#include <memory>
|
|
#include "PaddleAPI.h"
|
|
#include "paddle/gserver/evaluators/Evaluator.h"
|
|
#include "paddle/gserver/gradientmachines/GradientMachine.h"
|
|
#include "paddle/parameter/ParameterUpdaterBase.h"
|
|
#include "paddle/trainer/TrainerConfigHelper.h"
|
|
|
|
struct GradientMachinePrivate {
|
|
std::shared_ptr<paddle::GradientMachine> machine;
|
|
|
|
template <typename T>
|
|
inline T& cast(void* ptr) {
|
|
return *(T*)(ptr);
|
|
}
|
|
};
|
|
|
|
struct OptimizationConfigPrivate {
|
|
std::shared_ptr<paddle::TrainerConfigHelper> trainer_config;
|
|
paddle::OptimizationConfig config;
|
|
|
|
const paddle::OptimizationConfig& getConfig() {
|
|
if (trainer_config != nullptr) {
|
|
return trainer_config->getOptConfig();
|
|
} else {
|
|
return config;
|
|
}
|
|
}
|
|
};
|
|
|
|
struct TrainerConfigPrivate {
|
|
std::shared_ptr<paddle::TrainerConfigHelper> conf;
|
|
TrainerConfigPrivate() {}
|
|
};
|
|
|
|
struct ModelConfigPrivate {
|
|
std::shared_ptr<paddle::TrainerConfigHelper> conf;
|
|
};
|
|
|
|
struct ArgumentsPrivate {
|
|
std::vector<paddle::Argument> outputs;
|
|
|
|
inline paddle::Argument& getArg(size_t idx) throw(RangeError) {
|
|
if (idx < outputs.size()) {
|
|
return outputs[idx];
|
|
} else {
|
|
RangeError e;
|
|
throw e;
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
std::shared_ptr<T>& cast(void* rawPtr) const {
|
|
return *(std::shared_ptr<T>*)(rawPtr);
|
|
}
|
|
};
|
|
|
|
struct ParameterUpdaterPrivate {
|
|
std::unique_ptr<paddle::ParameterUpdater> updater;
|
|
};
|
|
|
|
struct ParameterPrivate {
|
|
std::shared_ptr<paddle::Parameter> sharedPtr;
|
|
paddle::Parameter* rawPtr; // rawPtr only used in ParameterUpdater,
|
|
// in other situation sharedPtr should
|
|
// contains value.
|
|
|
|
ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}
|
|
|
|
paddle::Parameter* getPtr() {
|
|
if (sharedPtr) {
|
|
return sharedPtr.get();
|
|
} else {
|
|
return rawPtr;
|
|
}
|
|
}
|
|
};
|
|
|
|
struct EvaluatorPrivate {
|
|
paddle::Evaluator* rawPtr;
|
|
|
|
EvaluatorPrivate() : rawPtr(nullptr) {}
|
|
~EvaluatorPrivate() { delete rawPtr; }
|
|
};
|