|
|
|
@ -32,12 +32,12 @@ class TestTensorLayout : public UT::Common {
|
|
|
|
|
virtual void TearDown() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void ReshapeExpandDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape,
|
|
|
|
|
const TensorMap& in_tensor_map_shape,
|
|
|
|
|
const TensorShape& in_tensor_shape_shape,
|
|
|
|
|
const DeviceArrangement& out_device_arrangement_shape,
|
|
|
|
|
const TensorMap& out_tensor_map_shape,
|
|
|
|
|
const TensorShape& out_tensor_shape_shape) {
|
|
|
|
|
void ReshapeExpandDeviceArrangementTestFunction(const DeviceArrangement &in_device_arrangement_shape,
|
|
|
|
|
const TensorMap &in_tensor_map_shape,
|
|
|
|
|
const TensorShape &in_tensor_shape_shape,
|
|
|
|
|
const DeviceArrangement &out_device_arrangement_shape,
|
|
|
|
|
const TensorMap &out_tensor_map_shape,
|
|
|
|
|
const TensorShape &out_tensor_shape_shape) {
|
|
|
|
|
Arrangement device_arrangement;
|
|
|
|
|
Status status = device_arrangement.Init(in_device_arrangement_shape);
|
|
|
|
|
ASSERT_EQ(Status::SUCCESS, status);
|
|
|
|
@ -154,12 +154,10 @@ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement5) {
|
|
|
|
|
tensor_map_expect, tensor_shape_expect);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ExpandTensorShapeTestFunction(const DeviceArrangement& in_device_arrangement_shape,
|
|
|
|
|
const TensorMap& in_tensor_map_shape,
|
|
|
|
|
const TensorShape& in_tensor_shape_shape,
|
|
|
|
|
const DeviceArrangement& out_device_arrangement_shape,
|
|
|
|
|
const TensorMap& out_tensor_map_shape,
|
|
|
|
|
const TensorShape& out_tensor_shape_shape) {
|
|
|
|
|
void ExpandTensorShapeTestFunction(const DeviceArrangement &in_device_arrangement_shape,
|
|
|
|
|
const TensorMap &in_tensor_map_shape, const TensorShape &in_tensor_shape_shape,
|
|
|
|
|
const DeviceArrangement &out_device_arrangement_shape,
|
|
|
|
|
const TensorMap &out_tensor_map_shape, const TensorShape &out_tensor_shape_shape) {
|
|
|
|
|
Arrangement device_arrangement;
|
|
|
|
|
Status status = device_arrangement.Init(in_device_arrangement_shape);
|
|
|
|
|
ASSERT_EQ(Status::SUCCESS, status);
|
|
|
|
@ -251,12 +249,12 @@ TEST_F(TestTensorLayout, UpdateTensorMap) {
|
|
|
|
|
ASSERT_EQ(in_tensor_map, new_tensor_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RemoveElementEqualToOneInDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape,
|
|
|
|
|
const TensorMap& in_tensor_map_shape,
|
|
|
|
|
const TensorShape& in_tensor_shape_shape,
|
|
|
|
|
const DeviceArrangement& out_device_arrangement_shape,
|
|
|
|
|
const TensorMap& out_tensor_map_shape,
|
|
|
|
|
const TensorShape& out_tensor_shape_shape) {
|
|
|
|
|
void RemoveElementEqualToOneInDeviceArrangementTestFunction(const DeviceArrangement &in_device_arrangement_shape,
|
|
|
|
|
const TensorMap &in_tensor_map_shape,
|
|
|
|
|
const TensorShape &in_tensor_shape_shape,
|
|
|
|
|
const DeviceArrangement &out_device_arrangement_shape,
|
|
|
|
|
const TensorMap &out_tensor_map_shape,
|
|
|
|
|
const TensorShape &out_tensor_shape_shape) {
|
|
|
|
|
Arrangement device_arrangement;
|
|
|
|
|
Status status = device_arrangement.Init(in_device_arrangement_shape);
|
|
|
|
|
ASSERT_EQ(Status::SUCCESS, status);
|
|
|
|
@ -310,15 +308,82 @@ TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement3) {
|
|
|
|
|
device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement4) {
|
|
|
|
|
DeviceArrangement device_arrangement = {1, 1, 1};
|
|
|
|
|
TensorMap tensor_map = {2, 1};
|
|
|
|
|
TensorShape tensor_shape = {128, 4096};
|
|
|
|
|
DeviceArrangement device_arrangement_expect = {};
|
|
|
|
|
TensorMap tensor_map_expect = {-1, -1};
|
|
|
|
|
TensorShape tensor_shape_new = {128, 4096};
|
|
|
|
|
RemoveElementEqualToOneInDeviceArrangementTestFunction(
|
|
|
|
|
device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new);
|
|
|
|
|
/*
|
|
|
|
|
* example:
|
|
|
|
|
* device_arrangement = [8, 4],
|
|
|
|
|
* tensor_map = [1, 0],
|
|
|
|
|
* tensor_shape = [512, 1024],
|
|
|
|
|
*/
|
|
|
|
|
TEST_F(TestTensorLayout, GenerateOptShardSliceShape1) {
|
|
|
|
|
Arrangement device_arrangement;
|
|
|
|
|
device_arrangement.Init({8, 4});
|
|
|
|
|
Map tensor_map;
|
|
|
|
|
tensor_map.Init({1, 0});
|
|
|
|
|
Arrangement tensor_shape;
|
|
|
|
|
tensor_shape.Init({512, 1024});
|
|
|
|
|
TensorLayout tensor_layout;
|
|
|
|
|
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
|
|
|
|
ASSERT_EQ(Status::FAILED, tensor_layout.GenerateOptShardSliceShape());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* example:
|
|
|
|
|
* device_arrangement = [8, 4],
|
|
|
|
|
* tensor_map = [-1, 0],
|
|
|
|
|
* tensor_shape = [512, 1024],
|
|
|
|
|
*/
|
|
|
|
|
TEST_F(TestTensorLayout, GenerateOptShardSliceShape2) {
|
|
|
|
|
Arrangement device_arrangement;
|
|
|
|
|
device_arrangement.Init({8, 4});
|
|
|
|
|
Map tensor_map;
|
|
|
|
|
tensor_map.Init({-1, 0});
|
|
|
|
|
Arrangement tensor_shape;
|
|
|
|
|
tensor_shape.Init({512, 1024});
|
|
|
|
|
TensorLayout tensor_layout;
|
|
|
|
|
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
|
|
|
|
ASSERT_EQ(Status::SUCCESS, tensor_layout.GenerateOptShardSliceShape());
|
|
|
|
|
|
|
|
|
|
Shape slice_shape_expect = {64, 256};
|
|
|
|
|
ASSERT_EQ(tensor_layout.opt_shard_slice_shape(), slice_shape_expect);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* example:
|
|
|
|
|
* device_arrangement = [4, 4, 2],
|
|
|
|
|
* tensor_map = [1, 0],
|
|
|
|
|
* tensor_shape = [512, 1024],
|
|
|
|
|
*/
|
|
|
|
|
TEST_F(TestTensorLayout, GenerateOptShardSliceShape3) {
|
|
|
|
|
Arrangement device_arrangement;
|
|
|
|
|
device_arrangement.Init({4, 4, 2});
|
|
|
|
|
Map tensor_map;
|
|
|
|
|
tensor_map.Init({1, 0});
|
|
|
|
|
Arrangement tensor_shape;
|
|
|
|
|
tensor_shape.Init({512, 1024});
|
|
|
|
|
TensorLayout tensor_layout;
|
|
|
|
|
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
|
|
|
|
ASSERT_EQ(Status::SUCCESS, tensor_layout.GenerateOptShardSliceShape());
|
|
|
|
|
|
|
|
|
|
Shape slice_shape_expect = {32, 512};
|
|
|
|
|
ASSERT_EQ(tensor_layout.opt_shard_slice_shape(), slice_shape_expect);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* example:
|
|
|
|
|
* device_arrangement = [4, 4, 2],
|
|
|
|
|
* tensor_map = [1, 0],
|
|
|
|
|
* tensor_shape = [20, 1024],
|
|
|
|
|
*/
|
|
|
|
|
TEST_F(TestTensorLayout, GenerateOptShardSliceShape4) {
|
|
|
|
|
Arrangement device_arrangement;
|
|
|
|
|
device_arrangement.Init({4, 4, 2});
|
|
|
|
|
Map tensor_map;
|
|
|
|
|
tensor_map.Init({1, 0});
|
|
|
|
|
Arrangement tensor_shape;
|
|
|
|
|
tensor_shape.Init({20, 1024});
|
|
|
|
|
TensorLayout tensor_layout;
|
|
|
|
|
tensor_layout.Init(device_arrangement, tensor_map, tensor_shape);
|
|
|
|
|
ASSERT_EQ(Status::FAILED, tensor_layout.GenerateOptShardSliceShape());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace parallel
|
|
|
|
|