|
|
|
@ -162,14 +162,19 @@ template <DeviceType Device>
|
|
|
|
|
class CrossMapNormalFunc : public FunctionBase {
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {
|
|
|
|
|
// function arguments
|
|
|
|
|
size_ = config.get<size_t>("size");
|
|
|
|
|
scale_ = config.get<real>("scale");
|
|
|
|
|
pow_ = config.get<real>("pow");
|
|
|
|
|
|
|
|
|
|
// number of inputs and outputs
|
|
|
|
|
numInputs_ = 1;
|
|
|
|
|
numOutputs_ = 2;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ((size_t)1, inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)2, outputs.size());
|
|
|
|
|
CHECK_EQ((size_t)numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)numOutputs_, outputs.size());
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
|
|
|
|
|
CHECK(inputs[0].shape() == outputs[0].shape());
|
|
|
|
@ -236,14 +241,19 @@ template <DeviceType Device>
|
|
|
|
|
class CrossMapNormalGradFunc : public FunctionBase {
|
|
|
|
|
public:
|
|
|
|
|
void init(const FuncConfig& config) override {
|
|
|
|
|
// function arguments
|
|
|
|
|
size_ = config.get<size_t>("size");
|
|
|
|
|
scale_ = config.get<real>("scale");
|
|
|
|
|
pow_ = config.get<real>("pow");
|
|
|
|
|
|
|
|
|
|
// number of inputs and outputs
|
|
|
|
|
numInputs_ = 4;
|
|
|
|
|
numOutputs_ = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
|
|
|
|
|
CHECK_EQ((size_t)4, inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)1, outputs.size());
|
|
|
|
|
CHECK_EQ((size_t)numInputs_, inputs.size());
|
|
|
|
|
CHECK_EQ((size_t)numOutputs_, outputs.size());
|
|
|
|
|
|
|
|
|
|
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
|
|
|
|
|
CHECK(inputs[0].shape() == inputs[1].shape());
|
|
|
|
|