|
|
|
@ -31,7 +31,7 @@ class ConditionalVAE(Cell):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor
|
|
|
|
|
should be :math:`(N, hidden_size)`.
|
|
|
|
|
should be :math:`(N, hidden\_size)`.
|
|
|
|
|
The latent_size should be less than or equal to the hidden_size.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -42,8 +42,8 @@ class ConditionalVAE(Cell):
|
|
|
|
|
num_classes(int): The number of classes.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - the same shape as the input of encoder.
|
|
|
|
|
- **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N, 1)`.
|
|
|
|
|
- **input_x** (Tensor) - the same shape as the input of encoder, the shape is :math:`(N, C, H, W)`.
|
|
|
|
|
- **input_y** (Tensor) - the tensor of the target data, the shape is :math:`(N,)`.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
- **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
|
|
|
|
@ -99,7 +99,7 @@ class ConditionalVAE(Cell):
|
|
|
|
|
Randomly sample from latent space to generate sample.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
sample_y (Tensor): Define the label of sample, int tensor.
|
|
|
|
|
sample_y (Tensor): Define the label of sample, int tensor, the shape is (generate_nums, ).
|
|
|
|
|
generate_nums (int): The number of samples to generate.
|
|
|
|
|
shape(tuple): The shape of sample, it should be (generate_nums, C, H, W) or (-1, C, H, W).
|
|
|
|
|
|
|
|
|
@ -121,8 +121,8 @@ class ConditionalVAE(Cell):
|
|
|
|
|
Reconstruct sample from original data.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (Tensor): The input tensor to be reconstructed.
|
|
|
|
|
y (Tensor): The label of the input tensor.
|
|
|
|
|
x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
|
|
|
|
|
y (Tensor): The label of the input tensor, the shape is (N,).
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor, the reconstructed sample.
|
|
|
|
|