Merge pull request #16741 from colourful-tree/dev

add continuous value model op

commit 434caab21b

@@ -0,0 +1,154 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cvm_op.h"
#include <memory>
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class CVMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto cvm_dims = ctx->GetInputDim("CVM");
    PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(cvm_dims.size(), 2UL, "Input(CVM)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL,
                      "The 2nd dimension of Input(CVM) should be 2.");

    if (ctx->Attrs().Get<bool>("use_cvm")) {
      ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});
    } else {
      ctx->SetOutputDim("Y", {x_dims[0], x_dims[1] - 2});
    }
    ctx->ShareLoD("X", /*->*/ "Y");
  }

 protected:
  // Explicitly set that the data type of the computation kernel of cvm
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   platform::CPUPlace());
  }
};

class CVMGradientOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@GRAD) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto cvm_dims = ctx->GetInputDim("CVM");
    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
    PADDLE_ENFORCE_EQ(cvm_dims.size(), 2, "Input(CVM)'s rank should be 2.");

    PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
                      "be equal.");

    PADDLE_ENFORCE_EQ(cvm_dims[1], 2,
                      "The 2nd dimension of Input(CVM) should be 2.");
    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
    ctx->ShareLoD("X", framework::GradVarName("X"));
  }

 protected:
  // Explicitly set that the data type of the computation kernel of
  // cvm_grad is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   platform::CPUPlace());
  }
};

class CVMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(LoDTensor, default LoDTensor<float>), a 2-D tensor with shape "
             "[N x D], where N is the batch size and D is the embedding dim.");
    AddInput("CVM",
             "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch "
             "size and the 2 columns are show and click.");
    AddOutput("Y",
              "(LoDTensor, default LoDTensor<float>), a 2-D tensor with shape "
              "[N x K].");
    AddAttr<bool>("use_cvm", "bool, use cvm or not").SetDefault(true);
    AddComment(R"DOC(
CVM Operator.

We assume that input X is an embedding vector with cvm features (show and
click), whose shape is [N x D] (N is the batch size, D is 2 (the cvm
features) plus the embedding dim).
If use_cvm is True, we log-transform the cvm features and the output shape
is [N x D].
If use_cvm is False, we remove the cvm features from the input and the
output shape is [N x (D - 2)].

)DOC");
  }
};

class CVMGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("cvm_grad");
    op->SetInput("X", Input("X"));
    op->SetInput("CVM", Input("CVM"));
    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker, ops::CVMGradOpDescMaker);

REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp);

REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel<float>, ops::CVMOpKernel<double>);

REGISTER_OP_CPU_KERNEL(cvm_grad, ops::CVMGradOpKernel<float>,
                       ops::CVMGradOpKernel<double>);
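
For reference, a minimal NumPy sketch of the forward semantics documented in CVMOpMaker above (an illustration added here, not code from the PR; the function name cvm_forward is made up):

import numpy as np

def cvm_forward(x, use_cvm=True):
    # x: [N, D] array whose first two columns are the show/click features.
    y = x.copy()
    if use_cvm:
        # Log-transform the cvm features; the click column also has the
        # transformed show column subtracted, mirroring the CPU kernel.
        y[:, 0] = np.log(y[:, 0] + 1)
        y[:, 1] = np.log(y[:, 1] + 1) - y[:, 0]
        return y          # shape [N, D]
    return y[:, 2:]       # shape [N, D - 2]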
@@ -0,0 +1,105 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cmath>
#include <cstring>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename T>
class CVMOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const LoDTensor* x = context.Input<LoDTensor>("X");
    const T* x_data = x->data<T>();
    auto lod = x->lod()[0];
    int64_t item_size = x->numel() / x->dims()[0];
    int offset = 2;
    if (!context.Attr<bool>("use_cvm")) {
      item_size -= offset;
    }
    LoDTensor* y = context.Output<LoDTensor>("Y");
    T* y_data = y->mutable_data<T>(context.GetPlace());

    int seq_num = static_cast<int>(lod.size()) - 1;
    for (int i = 0; i < seq_num; ++i) {
      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);

      for (int j = 0; j < seq_len; ++j) {
        if (context.Attr<bool>("use_cvm")) {
          // Copy the whole row, then log-transform the two leading
          // show/click features in place.
          std::memcpy(y_data, x_data, item_size * sizeof(T));
          y_data[0] = std::log(y_data[0] + 1);
          y_data[1] = std::log(y_data[1] + 1) - y_data[0];
          x_data += item_size;
          y_data += item_size;
        } else {
          // Skip the two leading cvm features and copy only the embedding.
          std::memcpy(y_data, x_data + offset, item_size * sizeof(T));
          x_data += item_size + offset;
          y_data += item_size;
        }
      }
    }
  }
};

template <typename T>
class CVMGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    LoDTensor* dx = context.Output<LoDTensor>(framework::GradVarName("X"));
    T* dx_data = dx->mutable_data<T>(context.GetPlace());

    const Tensor* cvm = context.Input<Tensor>("CVM");
    const T* cvm_data = cvm->data<T>();
    int offset = 2;
    const framework::LoDTensor* dOut =
        context.Input<framework::LoDTensor>(framework::GradVarName("Y"));
    const T* dout_data = dOut->data<T>();

    auto lod = dx->lod()[0];
    int64_t item_size = dx->numel() / dx->dims()[0];
    if (!context.Attr<bool>("use_cvm")) {
      item_size -= offset;
    }

    int seq_num = static_cast<int>(lod.size()) - 1;
    for (int i = 0; i < seq_num; ++i) {
      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);

      for (int j = 0; j < seq_len; ++j) {
        if (context.Attr<bool>("use_cvm")) {
          std::memcpy(dx_data, dout_data, item_size * sizeof(T));
          // The two cvm slots of X@GRAD are filled with the CVM input
          // values rather than a derivative of the log transform.
          dx_data[0] = cvm_data[0];
          dx_data[1] = cvm_data[1];
          dx_data += item_size;
          dout_data += item_size;
        } else {
          std::memcpy(dx_data + offset, dout_data, item_size * sizeof(T));
          dx_data[0] = cvm_data[0];
          dx_data[1] = cvm_data[1];
          dx_data += item_size + offset;
          dout_data += item_size;
        }
      }
      cvm_data += offset;
    }
  }
};
}  // namespace operators
}  // namespace paddle
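
A matching sketch of what the gradient kernel above computes for a single row (again only an illustration, with the hypothetical name cvm_backward): the upstream gradient passes through to the embedding slots, while the two cvm slots of X@GRAD are overwritten with the sequence's CVM input values instead of a true derivative of the log transform:

import numpy as np

def cvm_backward(dy, cvm_row, use_cvm=True, offset=2):
    # dy: upstream gradient for one row of Y; cvm_row: the [show, click]
    # pair of the sequence this row belongs to.
    if use_cvm:
        dx = dy.copy()                      # dy already has full width D
    else:
        dx = np.empty(dy.shape[0] + offset, dtype=dy.dtype)
        dx[offset:] = dy                    # embedding gradient passes through
    dx[0], dx[1] = cvm_row[0], cvm_row[1]   # cvm slots get the CVM input
    return dx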
@@ -0,0 +1,47 @@
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from op_test import OpTest
import unittest


class TestCVMOp(OpTest):
    """
    Test the cvm op with use_cvm=False, where the two leading
    show/click columns are stripped from the input.
    """

    def setUp(self):
        self.op_type = "cvm"
        dims = 11
        lod = [[1]]
        self.inputs = {
            'X': (np.random.uniform(0, 1, [1, dims]).astype("float32"), lod),
            'CVM': np.array([[0.6, 0.4]]).astype("float32"),
        }
        self.attrs = {'use_cvm': False}
        out = []
        for index, emb in enumerate(self.inputs["X"][0]):
            out.append(emb[2:])
        self.outputs = {'Y': (np.array(out), lod)}

    def test_check_output(self):
        self.check_output()
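

# A companion sketch (added for illustration, not part of the PR): a test of
# the use_cvm=True branch, where all columns are kept and the two leading
# show/click features are log-transformed as in CVMOpKernel:
#   y[0] = log(x[0] + 1), y[1] = log(x[1] + 1) - y[0].
class TestCVMOpWithCVM(OpTest):
    def setUp(self):
        self.op_type = "cvm"
        dims = 11
        lod = [[1]]
        x = np.random.uniform(0, 1, [1, dims]).astype("float32")
        self.inputs = {
            'X': (x, lod),
            'CVM': np.array([[0.6, 0.4]]).astype("float32"),
        }
        self.attrs = {'use_cvm': True}
        out = x.copy()
        out[:, 0] = np.log(out[:, 0] + 1)
        out[:, 1] = np.log(x[:, 1] + 1) - out[:, 0]
        self.outputs = {'Y': (out, lod)}

    def test_check_output(self):
        self.check_output()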


if __name__ == '__main__':
    unittest.main()