qiskit.aqua.components.neural_networks.PyTorchDiscriminator
class PyTorchDiscriminator(n_features=1, n_out=1)
Discriminator based on PyTorch
Parameters
- n_features (int) – Dimension of the input data vector.
- n_out (int) – Dimension of the discriminator’s output vector.
Raises
MissingOptionalLibraryError – PyTorch is not installed.
__init__
__init__(n_features=1, n_out=1)
Parameters
- n_features (int) – Dimension of the input data vector.
- n_out (int) – Dimension of the discriminator’s output vector.
Raises
MissingOptionalLibraryError – PyTorch is not installed.
Methods
__init__ ([n_features, n_out]) | Initialize the discriminator. |
get_label (x[, detach]) | Get data sample labels, i.e. true or fake. |
gradient_penalty (x[, lambda_, k, c]) | Compute gradient penalty for discriminator optimization |
load_model (load_dir) | Load discriminator model |
loss (x, y[, weights]) | Loss function |
save_model (snapshot_dir) | Save discriminator model |
set_seed (seed) | Set seed. |
train (data, weights[, penalty, …]) | Perform one training step with respect to the discriminator’s parameters. |
Attributes
discriminator_net | Get discriminator |
discriminator_net
Get discriminator
Returns
discriminator object
Return type
object
get_label
get_label(x, detach=False)
Get data sample labels, i.e. true or fake.
Parameters
- x (Union[numpy.ndarray, torch.Tensor]) – Discriminator input, i.e. a data sample.
- detach (bool) – If True, detach the output from the torch computation graph (optional).
Returns
Discriminator output, i.e. data label
Return type
torch.Tensor
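get_label is the discriminator's forward pass. It maps a data sample to a value read as the probability that the sample is real rather than generated. The sketch below illustrates that input/output contract without PyTorch. ToyDiscriminator is hypothetical: it is a single linear layer followed by a sigmoid, mapping n_features inputs to n_out probabilities. The network actually used by PyTorchDiscriminator is not described in this reference. The sketch also assumes the usual GAN convention that values near 1 mean "real" and values near 0 mean "fake".

```python
import math
import random
from typing import List, Sequence


def sigmoid(t: float) -> float:
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-t))


class ToyDiscriminator:
    """Hypothetical stand-in: one linear layer plus a sigmoid (not the real network)."""

    def __init__(self, n_features: int = 1, n_out: int = 1, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.n_features = n_features
        # One weight row and one bias per output unit.
        self.weights = [[rng.uniform(-1.0, 1.0) for _ in range(n_features)]
                        for _ in range(n_out)]
        self.biases = [0.0] * n_out

    def get_label(self, x: Sequence[float]) -> List[float]:
        """Return one probability per output unit; near 1 reads as 'real' (assumed convention)."""
        if len(x) != self.n_features:
            raise ValueError(f"expected {self.n_features} features, got {len(x)}")
        return [sigmoid(sum(w * xi for w, xi in zip(row, x)) + b)
                for row, b in zip(self.weights, self.biases)]


disc = ToyDiscriminator(n_features=2, n_out=1)
print(disc.get_label([0.3, -0.7]))  # a single probability in (0, 1)
```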
gradient_penalty
gradient_penalty(x, lambda_=5.0, k=0.01, c=1.0)
Compute gradient penalty for discriminator optimization
Parameters
- x (numpy.ndarray) – Generated data sample.
- lambda_ (float) – Gradient penalty coefficient 1.
- k (float) – Gradient penalty coefficient 2.
- c (float) – Gradient penalty coefficient 3.
Returns
Gradient penalty.
Return type
torch.Tensor
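The parameter list does not say how the three coefficients combine. The sketch below assumes a common DRAGAN-style form. Each sample is perturbed by uniform noise in [0, c). The L2 norm of the discriminator's gradient at the perturbed point is compared with a target k. The squared deviations are averaged and then scaled by lambda_. To get a closed-form gradient, the sketch uses a hypothetical logistic discriminator D(z) = sigmoid(w · z + b). The extra arguments w, b and seed are illustrative and are not part of the documented signature.

```python
import math
import random
from typing import List, Sequence


def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))


def gradient_penalty(x: Sequence[Sequence[float]], w: Sequence[float], b: float,
                     lambda_: float = 5.0, k: float = 0.01, c: float = 1.0,
                     seed: int = 0) -> float:
    """Assumed DRAGAN-style penalty for a hypothetical logistic discriminator."""
    rng = random.Random(seed)
    w_norm = math.sqrt(sum(wi * wi for wi in w))
    terms: List[float] = []
    for sample in x:
        # Perturb each feature by noise drawn uniformly from [0, c).
        z = [xi + rng.random() * c for xi in sample]
        d = sigmoid(sum(wi * zi for wi, zi in zip(w, z)) + b)
        # dD/dz = D * (1 - D) * w, so its L2 norm is D * (1 - D) * ||w||.
        grad_norm = d * (1.0 - d) * w_norm
        terms.append((grad_norm - k) ** 2)
    # Mean squared deviation from the target norm k, scaled by lambda_.
    return lambda_ * sum(terms) / len(terms)
```

With w = 0 the gradient vanishes everywhere, so the penalty reduces to lambda_ · k².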
load_model
load_model(load_dir)
Load discriminator model
Parameters
load_dir (str) – File containing the stored PyTorch discriminator model to be loaded.
loss
loss(x, y, weights=None)
Loss function
Parameters
- x (torch.Tensor) – Discriminator output.
- y (torch.Tensor) – Label of the data point
- weights (torch.Tensor) – Data weights.
Returns
Loss with respect to the generated data points.
Return type
torch.Tensor
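The reference does not name the loss. For a sigmoid-output discriminator, binary cross-entropy is the usual choice, and this sketch assumes it. It also assumes an optional per-sample weight, loosely mirroring torch.nn.BCELoss: the mean over samples when weights is None, and a weighted sum otherwise. These reduction choices are assumptions. bce_loss is a hypothetical name.

```python
import math
from typing import Optional, Sequence


def bce_loss(x: Sequence[float], y: Sequence[float],
             weights: Optional[Sequence[float]] = None, eps: float = 1e-12) -> float:
    """Binary cross-entropy between discriminator outputs x and labels y (assumed loss)."""
    terms = []
    for xi, yi in zip(x, y):
        xi = min(max(xi, eps), 1.0 - eps)  # clamp so log() stays finite
        terms.append(-(yi * math.log(xi) + (1.0 - yi) * math.log(1.0 - xi)))
    if weights is None:
        return sum(terms) / len(terms)  # unweighted: mean over samples
    return sum(wi * t for wi, t in zip(weights, terms))  # weighted: sum of w_i * term_i


print(bce_loss([0.9, 0.2], [1.0, 0.0]))  # small loss: both outputs agree with their labels
```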
save_model
save_model(snapshot_dir)
Save discriminator model
Parameters
snapshot_dir (str) – Directory path for saving the model.
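save_model and load_model form a round trip: save_model takes a directory, while load_model takes a path to the stored file. PyTorch models are normally persisted with torch.save and torch.load. The stdlib sketch below uses JSON instead, so it runs without PyTorch. The file name discriminator.json and the helpers save_params and load_params are hypothetical.

```python
import json
import os
import tempfile
from typing import Dict, List

Params = Dict[str, List[float]]


def save_params(params: Params, snapshot_dir: str) -> str:
    """Write parameters to <snapshot_dir>/discriminator.json (hypothetical name); return the path."""
    os.makedirs(snapshot_dir, exist_ok=True)
    path = os.path.join(snapshot_dir, "discriminator.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(params, fh)
    return path


def load_params(load_path: str) -> Params:
    """Read parameters previously written by save_params."""
    with open(load_path, encoding="utf-8") as fh:
        return json.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    saved = {"weights": [0.25, -0.5], "biases": [0.0]}
    path = save_params(saved, tmp)   # directory in, file path out
    assert load_params(path) == saved
```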
set_seed
set_seed(seed)
Set seed.
Parameters
seed (int) – Seed for the random number generator.
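Seeding matters because randomness enters at several points: weight initialization, and noise such as the perturbation used in gradient_penalty. In PyTorch the usual call is torch.manual_seed(seed). This reference does not say what set_seed calls internally. The stdlib sketch below, using the hypothetical helpers make_rng and init_weights, shows the property being relied on: the same seed reproduces the same initialization.

```python
import random
from typing import List


def make_rng(seed: int) -> random.Random:
    """Return a generator seeded deterministically from seed (hypothetical helper)."""
    return random.Random(seed)


def init_weights(n: int, rng: random.Random) -> List[float]:
    """Draw n initial weights uniformly from [-1, 1]."""
    return [rng.uniform(-1.0, 1.0) for _ in range(n)]


first = init_weights(4, make_rng(7))
second = init_weights(4, make_rng(7))
assert first == second  # same seed, same initialization
```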
train
train(data, weights, penalty=False, quantum_instance=None, shots=None)
Perform one training step with respect to the discriminator’s parameters.
Parameters
- data (Iterable) – Data batch.
- weights (Iterable) – Data sample weights.
- penalty (bool) – Whether the penalty function is applied to the loss function. Ignored if no penalty function is defined.
- quantum_instance (QuantumInstance) – Used to run a quantum network. Ignored for a classical network.
- shots (Optional[int]) – Number of shots for hardware or QASM simulator execution. Ignored for a classical network.
Returns
Dictionary with the discriminator loss and the updated parameters.
Return type
dict
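The parameter list leaves the layout of data and weights implicit. The sketch below assumes data is a pair (real batch, generated batch) and weights is a matching pair of per-sample weights. Real samples are labelled 1 and generated samples 0. Under that assumption, one step works as follows:

- Score both batches.
- Accumulate weighted binary cross-entropy and its gradient.
- Take a plain gradient-descent step.
- Return a dict whose keys "loss" and "params" are hypothetical names, chosen to match the Returns description.

The logistic discriminator, the learning rate lr, and plain gradient descent (rather than whatever optimizer the class uses) are all illustrative choices.

```python
import math
from typing import Any, Dict, List, Sequence, Tuple

Batch = Sequence[Sequence[float]]


def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))


def train_step(w: List[float], b: float,
               data: Tuple[Batch, Batch],
               weights: Tuple[Sequence[float], Sequence[float]],
               lr: float = 0.1) -> Dict[str, Any]:
    """One gradient-descent step for a hypothetical D(x) = sigmoid(w . x + b)."""
    grad_w = [0.0] * len(w)
    grad_b = 0.0
    errors = []
    # Assumed layout: data = (real, generated); real -> label 1, generated -> label 0.
    for batch, batch_weights, label in ((data[0], weights[0], 1.0),
                                        (data[1], weights[1], 0.0)):
        err = 0.0
        for x, wt in zip(batch, batch_weights):
            d = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            p = min(max(d, 1e-12), 1.0 - 1e-12)  # clamp so log() stays finite
            err -= wt * (label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
            # For a sigmoid output, d(BCE)/d(logit) = d - label.
            g = wt * (d - label)
            grad_w = [gw + g * xi for gw, xi in zip(grad_w, x)]
            grad_b += g
        errors.append(err)
    new_w = [wi - lr * gw for wi, gw in zip(w, grad_w)]
    new_b = b - lr * grad_b
    # Loss is evaluated at the parameters before the update; params are after it.
    return {"loss": 0.5 * (errors[0] + errors[1]), "params": new_w + [new_b]}


data = ([[1.0], [2.0]], [[-1.0], [-2.0]])
wts = ([0.5, 0.5], [0.5, 0.5])
step = train_step([0.0], 0.0, data, wts)
print(step["loss"], step["params"])
```

Feeding the returned parameters into a second call should give a lower loss on this separable toy batch.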