Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-sgd

revert-3824-remove_grad_op_type
qiaolongfei 8 years ago
commit 230e613c6f

@ -137,9 +137,9 @@ set(EXTERNAL_LIBS
) )
if(WITH_GPU) if(WITH_GPU)
list(APPEND EXTERNAL_LIB ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO) if(NOT WITH_DSO)
list(APPEND EXTERNAL_LIB ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
endif(NOT WITH_DSO) endif(NOT WITH_DSO)
endif(WITH_GPU) endif(WITH_GPU)

@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").AsNoGradient(); AddInput("X", "Input X of Add").NotInGradient();
AddInput("b", "Bias of Add").AsNoGradient(); AddInput("b", "Bias of Add").NotInGradient();
AddOutput("Out", "Out of Add").AsNoGradient(); AddOutput("Out", "Out of Add").NotInGradient();
AddComment("Add Op"); AddComment("Add Op");
} }
}; };

@ -60,7 +60,7 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool no_gradient = 5 [ default = false ]; optional bool not_in_gradient = 5 [ default = false ];
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.

@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue; if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name(); const std::string src_name = arg.name();
std::string dst_name = is_grad ? GradVarName(src_name) : src_name; std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
dst_inout[dst_name].reserve(src_inout.at(src_name).size()); dst_inout[dst_name].reserve(src_inout.at(src_name).size());

@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input"); AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient(); AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient();
AddInput("In3_mult", "another multiple input").AsDuplicable(); AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").AsDuplicable(); AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").AsNoGradient(); AddOutput("Out2", "a single output").NotInGradient();
AddComment("op with inputs and outputs ignored in gradient calculating"); AddComment("op with inputs and outputs ignored in gradient calculating");
} }
}; };

@ -184,11 +184,8 @@ class OpProtoAndCheckerMaker {
return *this; return *this;
} }
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it VariableBuilder& NotInGradient() {
// means that input/output is not needed when calculate gradient. It does var_->set_not_in_gradient(true);
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
return *this; return *this;
} }
}; };

@ -57,11 +57,14 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
} }
void MKLDNNFcLayer::convertWeightsFromPaddle() { void MKLDNNFcLayer::convertWeightsFromPaddle() {
if (FLAGS_use_mkldnn_wgt) { if (hasInitedWgt_) {
return; return;
} }
if (hasInitedWgt_) { // TODO(TJ): dst format should get from wgtVal_
int dstFmt = PARAM_FORMAT_MKLDNN_OI;
int srcFmt = weight_->getParameterPtr()->getHeaderFormat();
if (srcFmt == dstFmt) {
return; return;
} }
@ -78,6 +81,7 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() {
MatrixPtr paddleWgtT; MatrixPtr paddleWgtT;
paddleWgt->transpose(paddleWgtT, true); paddleWgt->transpose(paddleWgtT, true);
weight_->getW()->copyFrom(*paddleWgtT); weight_->getW()->copyFrom(*paddleWgtT);
weight_->getParameterPtr()->setHeaderFormat(dstFmt);
hasInitedWgt_ = true; hasInitedWgt_ = true;
} }

@ -330,9 +330,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
log_ = log; log_ = log;
lvl_ = level; lvl_ = level;
// Firstly test FLAGS_use_mkldnn_wgt = false // Firstly test mkldnn init from PARAM_FORMAT_ORIGINAL weight
FLAGS_use_mkldnn_wgt = false;
// reset and run once
reset(dnn, ref, batchSize); reset(dnn, ref, batchSize);
randomWgtDatas(); randomWgtDatas();
clearWgtDiffs(); clearWgtDiffs();
@ -342,17 +340,32 @@ void MKLDNNTester::run(const TestConfig& dnn,
runOnce(); runOnce();
} }
// Then test FLAGS_use_mkldnn_wgt = true if (parameters_[DNN].empty()) {
FLAGS_use_mkldnn_wgt = true; // has no paramters
// after run once the mkldnn weight has been stored in dnnlayer return;
}
// After run some iterations, the mkldnn weight has been stored in dnnLayer
// and we can also get the mkldnn weight parameter header format.
// Weight parameter should always be index 0 (and bias index 1).
// TODO(TJ): should also consider mean and var format when batchnorm ready
int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat();
int refWgtFmt = parameters_[REF][0]->getHeaderFormat();
if (dnnWgtFmt == refWgtFmt) {
// weight format are equal, so no need check more
return;
}
// then save the weights and restart again // then save the weights and restart again
vector<VectorPtr> dnnWgts, refWgts; vector<VectorPtr> dnnWgts, refWgts;
CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size()); CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size());
saveWgt(parameters_[DNN], dnnWgts); saveWgt(parameters_[DNN], dnnWgts);
saveWgt(parameters_[REF], refWgts); saveWgt(parameters_[REF], refWgts);
// restart again with flag true // restart again with dnn weight format
reset(dnn, ref, batchSize); reset(dnn, ref, batchSize);
// TODO(TJ): should also considerate mean and var format when batchnorm ready
parameters_[DNN][0]->setHeaderFormat(dnnWgtFmt);
// restore wgt // restore wgt
restoreWgt(dnnWgts, parameters_[DNN]); restoreWgt(dnnWgts, parameters_[DNN]);

@ -108,7 +108,7 @@ private:
* if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the * if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the
* max(diff/ref) * max(diff/ref)
* else return sum(abs(a-b)) / sum(abs(b)) * else return sum(abs(a-b)) / sum(abs(b))
* The return value should smaller than eps when passing. * The return value should be smaller than eps when passing.
*/ */
double getDelta(const real* d1, double getDelta(const real* d1,
const real* d2, const real* d2,

@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").AsNoGradient(); AddOutput("Out", "The output of mean op").NotInGradient();
AddComment("Mean Operator"); AddComment("Mean Operator");
} }
}; };

@ -55,9 +55,10 @@ class MeanGradKernel : public framework::OpKernel {
IG->mutable_data<T>(context.GetPlace()); IG->mutable_data<T>(context.GetPlace());
T ig_size = (T)framework::product(IG->dims()); T ig_size = (T)framework::product(IG->dims());
Eigen::DSizes<int, 1> bcast(ig_size);
EigenVector<T>::Flatten(*IG).device(context.GetEigenDevice<Place>()) = EigenVector<T>::Flatten(*IG).device(context.GetEigenDevice<Place>()) =
EigenScalar<T>::From(*OG) / ig_size; (EigenVector<T>::From(*OG) / ig_size).broadcast(bcast);
} }
}; };

@ -44,7 +44,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
} }
}; };

@ -37,7 +37,7 @@ class SigmoidKernel : public framework::OpKernel {
auto Y = EigenVector<T>::Flatten(*output); auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); Y.device(place) = 1. / (1. + (-X).exp());
} }
}; };

@ -48,7 +48,8 @@ Parameter::Parameter(const ParameterConfig& config, bool useGpu, bool doInit)
deviceId_(-1), deviceId_(-1),
sharedCount_(0), sharedCount_(0),
updateCounter_(0), updateCounter_(0),
updated_(false) { updated_(false),
headerFormat_(PARAM_FORMAT_ORIGINAL) {
setID(-1); /* capture uninitialized id */ setID(-1); /* capture uninitialized id */
if (useGpu_ && FLAGS_parallel_nn) { if (useGpu_ && FLAGS_parallel_nn) {
/* gpu environment is specified by device property */ /* gpu environment is specified by device property */
@ -285,7 +286,7 @@ bool Parameter::save(const std::string& filename) const {
bool Parameter::save(std::ostream& s) const { bool Parameter::save(std::ostream& s) const {
CpuVector vec(*bufs_[PARAMETER_VALUE].get()); CpuVector vec(*bufs_[PARAMETER_VALUE].get());
Header header; Header header;
header.version = kFormatVersion; header.format = headerFormat_;
header.valueSize = sizeof(real); header.valueSize = sizeof(real);
header.size = getSize(); header.size = getSize();
@ -344,8 +345,9 @@ bool Parameter::load(std::istream& s) {
Header header; Header header;
CHECK(s.read(reinterpret_cast<char*>(&header), sizeof(header))) CHECK(s.read(reinterpret_cast<char*>(&header), sizeof(header)))
<< "Fail to read parameter " << getName(); << "Fail to read parameter " << getName();
CHECK_EQ(header.version, kFormatVersion) << "Incorrect format version: " CHECK(isHeaderFormatSupported(header.format)) << "Incorrect format version: "
<< header.version; << header.format;
headerFormat_ = header.format;
CHECK_EQ(header.size, getSize()) CHECK_EQ(header.size, getSize())
<< "The size (" << header.size << ") in the file does not match the size " << "The size (" << header.size << ") in the file does not match the size "
<< "(" << getSize() << ") of the parameter: " << getName(); << "(" << getSize() << ") of the parameter: " << getName();

@ -34,6 +34,20 @@ limitations under the License. */
namespace paddle { namespace paddle {
typedef enum {
/// The paddle original basic format
PARAM_FORMAT_ORIGINAL = 0,
/// See mkldnn_memory_format_t in
/// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h
/// for a detailed description.
/// 2D weights tensor in the format (output channels, input channels).
PARAM_FORMAT_MKLDNN_OI,
/// The total format items numbers
PARAM_FORMAT_ITEMS,
} PARAM_FORMAT;
class SparsePrefetchRowCpuMatrix; class SparsePrefetchRowCpuMatrix;
class Parameter; class Parameter;
@ -242,14 +256,30 @@ public:
/// Initialize the value to 0 /// Initialize the value to 0
void zeroMem(); void zeroMem();
static const int kFormatVersion = 0;
/// file header structure /// file header structure
struct Header { struct Header {
int32_t version; // = 0, file format version int32_t format; // = PARAM_FORMAT
uint32_t valueSize; // = sizeof(real) uint32_t valueSize; // = sizeof(real)
uint64_t size; // = getSize() uint64_t size; // = getSize()
}; };
/**
* @brief Is the header format supported.
*/
static bool isHeaderFormatSupported(int32_t fmt) {
return fmt < PARAM_FORMAT_ITEMS;
}
/**
* @brief Get the format in header.
*/
int getHeaderFormat() { return headerFormat_; }
/**
* @brief Set the format in header.
*/
void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; }
/** /**
* @brief Parameter Update Hook. * @brief Parameter Update Hook.
* *
@ -321,6 +351,9 @@ protected:
bool updated_; bool updated_;
SparseFormat format_; SparseFormat format_;
/// The header format for saving or loading param
int32_t headerFormat_;
std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_; std::vector<std::shared_ptr<IParameterUpdaterHook>> updaterHooks_;
public: public:

@ -1032,8 +1032,8 @@ void ParameterServer2::loadValueVector(const LoadValueRequest& request,
Parameter::Header header; Parameter::Header header;
CHECK(fs.read(reinterpret_cast<char*>(&header), sizeof(header))) CHECK(fs.read(reinterpret_cast<char*>(&header), sizeof(header)))
<< "Fail to read parameters in pserver"; << "Fail to read parameters in pserver";
CHECK_EQ(header.version, Parameter::kFormatVersion) CHECK(Parameter::isHeaderFormatSupported(header.format))
<< "Incorrect format version: " << header.version; << "Incorrect format version: " << header.format;
CHECK_EQ(header.size, (size_t)size_) CHECK_EQ(header.size, (size_t)size_)
<< "The size (" << header.size << ") in the file does not match the size " << "The size (" << header.size << ") in the file does not match the size "
<< "(" << size_ << ") of the pserver: " << serverId_; << "(" << size_ << ") of the pserver: " << serverId_;
@ -1063,7 +1063,8 @@ void ParameterServer2::saveValueVector(const SaveValueRequest& request,
CpuVector& vec = vectors_[PARAMETER_APPLY] ? *vectors_[PARAMETER_APPLY] CpuVector& vec = vectors_[PARAMETER_APPLY] ? *vectors_[PARAMETER_APPLY]
: *vectors_[PARAMETER_VALUE]; : *vectors_[PARAMETER_VALUE];
Parameter::Header header; Parameter::Header header;
header.version = Parameter::kFormatVersion; // TODO(TJ): save param headerFormat_
header.format = PARAM_FORMAT_ORIGINAL;
header.valueSize = sizeof(real); header.valueSize = sizeof(real);
header.size = size_; header.size = size_;

@ -29,7 +29,6 @@ DECLARE_bool(with_gpu);
DECLARE_bool(parallel_nn); DECLARE_bool(parallel_nn);
DECLARE_string(config_args); DECLARE_string(config_args);
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
DECLARE_bool(use_mkldnn_wgt);
const char *kConfigParserModuleName = "paddle.trainer.config_parser"; const char *kConfigParserModuleName = "paddle.trainer.config_parser";
const char *kConfigParserFuncName = "parse_config_and_serialize"; const char *kConfigParserFuncName = "parse_config_and_serialize";
@ -47,7 +46,6 @@ TrainerConfigHelper::TrainerConfigHelper(const std::string &configFilePath)
<< ",with_cost=" << FLAGS_with_cost << ",use_gpu=" << FLAGS_use_gpu << ",with_cost=" << FLAGS_with_cost << ",use_gpu=" << FLAGS_use_gpu
<< ",parallel_nn=" << FLAGS_parallel_nn << ",parallel_nn=" << FLAGS_parallel_nn
<< ",use_mkldnn=" << FLAGS_use_mkldnn << ",use_mkldnn=" << FLAGS_use_mkldnn
<< ",use_mkldnn_wgt=" << FLAGS_use_mkldnn_wgt
<< ",cudnn_version=" << hl_get_cudnn_lib_version(); << ",cudnn_version=" << hl_get_cudnn_lib_version();
if (!FLAGS_config_args.empty()) { if (!FLAGS_config_args.empty()) {
configArgs << "," << FLAGS_config_args; configArgs << "," << FLAGS_config_args;

@ -27,7 +27,6 @@ DEFINE_bool(use_mkldnn, false, "Default still keep use CPU training");
DEFINE_bool(use_mkldnn, false, "Only support CPU training"); DEFINE_bool(use_mkldnn, false, "Only support CPU training");
#endif #endif
DEFINE_bool(use_mkldnn_wgt, false, "Init weight from CPU weight");
DEFINE_bool(parallel_nn, DEFINE_bool(parallel_nn,
false, false,
"Whether to use multi-threads to calculate one neural network." "Whether to use multi-threads to calculate one neural network."

@ -41,4 +41,3 @@ DECLARE_string(predict_file);
DECLARE_bool(prev_batch_state); DECLARE_bool(prev_batch_state);
DECLARE_string(init_model_path); DECLARE_string(init_model_path);
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
DECLARE_bool(use_mkldnn_wgt);

@ -26,3 +26,4 @@ py_test(test_operator SRCS test_operator.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,43 @@
import unittest
import numpy
from paddle.v2.framework.op import Operator
from gradient_checker import GradientChecker
from gradient_checker import get_numeric_gradient
class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self):
add_op = Operator('add_two', X="X", Y="Y", Out="Z")
x = numpy.random.random((10, 1)).astype("float32")
y = numpy.random.random((10, 1)).astype("float32")
arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)
def test_softmax_op(self):
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - numpy.max(x)
exps = numpy.exp(shiftx)
return exps / numpy.sum(exps)
def label_softmax_grad(Y, dY):
dX = Y * 0.0
for i in range(Y.shape[0]):
d = numpy.dot(Y[i, :], dY[i, :])
dX[i, :] = Y[i, :] * (dY[i, :] - d)
return dX
softmax_op = Operator("softmax", X="X", Y="Y")
X = numpy.random.random((2, 2)).astype("float32")
Y = numpy.apply_along_axis(stable_softmax, 1, X)
dY = numpy.ones(Y.shape)
dX = label_softmax_grad(Y, dY)
arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)
if __name__ == '__main__':
unittest.main()

@ -1,5 +1,6 @@
import unittest import unittest
from op_test_util import OpTestMeta from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy as np import numpy as np
@ -12,5 +13,12 @@ class TestMeanOp(unittest.TestCase):
self.outputs = {'Out': np.mean(self.inputs['X'])} self.outputs = {'Out': np.mean(self.inputs['X'])}
class MeanGradOpTest(GradientChecker):
def test_normal(self):
op = create_op("mean")
inputs = {"X": np.random.random((10, 10)).astype("float32")}
self.check_grad(op, inputs, set("X"), "Out")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

@ -1,6 +1,7 @@
import unittest import unittest
from op_test_util import OpTestMeta
import numpy as np import numpy as np
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
class TestSigmoidOp(unittest.TestCase): class TestSigmoidOp(unittest.TestCase):
@ -8,12 +9,20 @@ class TestSigmoidOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "sigmoid" self.type = "sigmoid"
self.inputs = {'X': np.random.random((32, 100)).astype("float32")} self.inputs = {'X': np.random.random((15, 31)).astype("float32")}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
#class TestSigmoidGradOp(unittest.TestCase): class TestSigmoidGradOp(GradientChecker):
#TODO(qingqing) add unit test def test_grad(self):
op = create_op("sigmoid")
inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")}
# compare gpu and cpu results for backward op.
# this test will be skiped if only compiling CPU version.
self.compare_grad(op, inputs)
# check gradients
self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save