IBM Quantum Documentation
This page is from an old version of Qiskit SDK and does not exist in the latest version. We recommend you migrate to the latest version. See the release notes for more information.

PyTorchDiscriminator

class PyTorchDiscriminator(n_features=1, n_out=1)

Discriminator based on PyTorch, for use in a quantum generative adversarial network (qGAN).

Parameters

  • n_features (int) – Dimension of input data vector.
  • n_out (int) – Dimension of the discriminator’s output vector.

Raises

NameError – PyTorch is not installed.


Attributes

discriminator_net

Get the underlying discriminator network.

Returns

discriminator object

Return type

object


Methods

get_label

PyTorchDiscriminator.get_label(x, detach=False)

Get the labels the discriminator assigns to data samples, i.e., whether each sample is judged real or fake.

Parameters

  • x (Union(numpy.ndarray, torch.Tensor)) – Discriminator input, i.e. data sample.
  • detach (bool) – If True, detach the output from the torch computation graph (optional).

Returns

Discriminator output, i.e. data label

Return type

torch.Tensor
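To illustrate what get_label computes, here is a minimal sketch in NumPy rather than PyTorch. It assumes a hypothetical single-layer logistic discriminator (LogisticDiscriminator and sigmoid are illustrative names, not part of Qiskit; the network behind discriminator_net is presumably larger) and the usual GAN convention that outputs near 1 mean "real" and outputs near 0 mean "fake".

```python
import numpy as np


def sigmoid(z):
    """Logistic function mapping raw scores to (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


class LogisticDiscriminator:
    """Hypothetical stand-in for the discriminator: one linear layer plus a sigmoid."""

    def __init__(self, n_features=1, n_out=1, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(size=(n_features, n_out))  # weight matrix
        self.b = np.zeros(n_out)                       # bias vector

    def get_label(self, x):
        """Return D(x) in (0, 1); by assumption, near 1 = 'real', near 0 = 'fake'."""
        # Accept lists or arrays, mirroring the ndarray/Tensor flexibility of the real method.
        x = np.asarray(x, dtype=np.float32).reshape(-1, self.w.shape[0])
        return sigmoid(x @ self.w + self.b)


disc = LogisticDiscriminator(n_features=2, n_out=1)
labels = disc.get_label([[0.1, 0.2], [1.0, -1.0], [3.0, 0.5]])
print(labels.shape)  # (3, 1)
```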

gradient_penalty

PyTorchDiscriminator.gradient_penalty(x, lambda_=5.0, k=0.01, c=1.0)

Compute the gradient penalty used in discriminator optimization.

Parameters

  • x (numpy.ndarray) – Generated data sample.
  • lambda_ (float) – Gradient penalty coefficient 1.
  • k (float) – Gradient penalty coefficient 2.
  • c (float) – Gradient penalty coefficient 3.

Returns

Gradient penalty.

Return type

torch.Tensor
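The docstring does not spell out how the three coefficients enter the penalty. The sketch below assumes a form common in GAN training: perturb each sample with uniform noise in [0, c), then penalize the squared deviation of the discriminator's input-gradient norm from k, scaled by lambda_, i.e. lambda_ · mean_i (‖∇_z D(z_i)‖₂ − k)² with z_i = x_i + δ_i. It uses a hypothetical logistic discriminator D(x) = sigmoid(x·w + b), whose gradient can be written analytically as D(1 − D)w; a PyTorch implementation would instead obtain the gradient with torch.autograd.grad. The function name and the w, b, rng arguments are illustrative.

```python
import numpy as np


def gradient_penalty(x, w, b, lambda_=5.0, k=0.01, c=1.0, rng=None):
    """Gradient penalty for D(x) = sigmoid(x @ w + b) under the ASSUMED form
    lambda_ * mean((||dD/dz||_2 - k) ** 2), with z = x + Uniform[0, c) noise."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=float)                    # shape (n_samples, n_features)
    z = x + rng.random(x.shape) * c                   # perturbed samples
    d = 1.0 / (1.0 + np.exp(-(z @ w + b)))            # D(z), shape (n_samples,)
    grad = (d * (1.0 - d))[:, None] * w[None, :]      # analytic dD/dz for each sample
    grad_norm = np.linalg.norm(grad, axis=1)          # L2 norm per sample
    return lambda_ * np.mean((grad_norm - k) ** 2)


rng = np.random.default_rng(42)
x_fake = rng.normal(size=(8, 2))                      # stand-in "generated" batch
penalty = gradient_penalty(x_fake, w=np.array([0.5, -1.0]), b=0.0, rng=rng)
```

With w = 0 the gradient vanishes, so under this assumed form the penalty reduces to lambda_ · k², which makes a convenient sanity check.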

load_model

PyTorchDiscriminator.load_model(load_dir)

Load discriminator model

Parameters

load_dir (str) – File containing the stored PyTorch discriminator model to load.

loss

PyTorchDiscriminator.loss(x, y, weights=None)

Loss function

Parameters

  • x (torch.Tensor) – Discriminator output.
  • y (torch.Tensor) – Labels of the data points.
  • weights (torch.Tensor) – Data weights.

Returns

Loss with respect to the generated data points.

Return type

torch.Tensor
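The docstring does not name the loss. A natural reading for a discriminator whose outputs lie in (0, 1) is binary cross-entropy, optionally weighted per sample; the NumPy sketch below assumes that (the PyTorch counterpart would be torch.nn.BCELoss, whose weight argument rescales each element's loss). The name weighted_bce_loss, the clipping constant eps, and the choice to sum when weights are given and average otherwise are assumptions.

```python
import numpy as np


def weighted_bce_loss(x, y, weights=None, eps=1e-12):
    """Binary cross-entropy between discriminator outputs x and labels y.

    With weights, the per-sample terms are scaled and summed; without, they are averaged.
    """
    x = np.clip(np.asarray(x, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    y = np.asarray(y, dtype=float)
    terms = -(y * np.log(x) + (1.0 - y) * np.log(1.0 - x))
    if weights is None:
        return terms.mean()
    return np.sum(np.asarray(weights, dtype=float) * terms)


# One "real" sample (label 1) and one "generated" sample (label 0), equally weighted.
loss_value = weighted_bce_loss([0.9, 0.2], [1.0, 0.0], weights=[0.5, 0.5])
```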

save_model

PyTorchDiscriminator.save_model(snapshot_dir)

Save discriminator model

Parameters

snapshot_dir (str) – directory path for saving the model
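save_model takes a directory while load_model takes a file, which suggests the model is written to a file inside the snapshot directory. The sketch below mirrors that split using NumPy's .npz format instead of PyTorch serialization (torch.save and torch.load would be the PyTorch counterparts); the helper names save_params and load_params and the file name discriminator.npz are hypothetical.

```python
import os
import tempfile

import numpy as np


def save_params(params, snapshot_dir):
    """Write named parameter arrays to <snapshot_dir>/discriminator.npz (hypothetical name)."""
    path = os.path.join(snapshot_dir, "discriminator.npz")
    np.savez(path, **params)
    return path


def load_params(load_file):
    """Read the parameter arrays back from a single file into memory."""
    with np.load(load_file) as data:
        return {name: data[name] for name in data.files}


params = {"w": np.array([[0.5], [-1.0]]), "b": np.zeros(1)}
with tempfile.TemporaryDirectory() as tmp:
    path = save_params(params, tmp)   # save: directory in, file path out
    restored = load_params(path)      # load: file in
print(all(np.array_equal(params[k], restored[k]) for k in params))  # True
```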

set_seed

PyTorchDiscriminator.set_seed(seed)

Set seed.

Parameters

seed (int) – Random seed.
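Seeding matters because the discriminator's initial weights are drawn at random. The NumPy sketch below shows the effect: two initializations from the same seed are identical. It is an illustration only; init_weights is a hypothetical helper, and a PyTorch model would typically be seeded through torch.manual_seed instead.

```python
import numpy as np


def init_weights(n_features, n_out, seed):
    """Draw initial weights from a generator seeded with `seed` (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    return rng.normal(size=(n_features, n_out))


a = init_weights(2, 1, seed=7)
b = init_weights(2, 1, seed=7)
print(np.array_equal(a, b))  # True
```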

train

PyTorchDiscriminator.train(data, weights, penalty=True, quantum_instance=None, shots=None)

Perform one training step with respect to the discriminator’s parameters.

Parameters

  • data (tuple) – (real_batch, generated_batch), where real_batch (torch.Tensor) is the training data batch and generated_batch (numpy.ndarray) is the generated data batch.
  • weights (tuple) – Weights for the real and the generated data batches, in that order.
  • penalty (bool) – Indicate whether or not penalty function is applied to the loss function.
  • quantum_instance (QuantumInstance) – Quantum Instance (deprecated)
  • shots (int) – Number of shots for hardware or QASM simulator execution. Not used by classical networks (only by quantum ones).

Returns

Dictionary with the discriminator loss (torch.Tensor) and the updated parameters (array).

Return type

dict
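One training step can be sketched as follows, under several assumptions: a hypothetical logistic discriminator in NumPy stands in for the PyTorch network, real samples are labeled 1 and generated samples 0, the per-sample weights from the weights tuple scale a summed binary cross-entropy, and the update is plain gradient descent with an illustrative learning rate lr. The optional penalty term and the quantum_instance and shots arguments are omitted, and the dictionary keys "loss" and "params" are illustrative.

```python
import numpy as np


def train_step(w, b, data, weights, lr=0.1):
    """One gradient-descent step on a logistic discriminator (sketch, no penalty term).

    data = (real_batch, generated_batch); weights = (real_weights, generated_weights).
    Real samples are labeled 1, generated samples 0.
    """
    real_batch, generated_batch = (np.asarray(a, dtype=float) for a in data)
    real_w, gen_w = (np.asarray(a, dtype=float) for a in weights)
    x = np.vstack([real_batch, generated_batch])
    y = np.concatenate([np.ones(len(real_batch)), np.zeros(len(generated_batch))])
    sw = np.concatenate([real_w, gen_w])                  # per-sample weights
    d = 1.0 / (1.0 + np.exp(-(x @ w + b)))                # D(x) in (0, 1)
    dc = np.clip(d, 1e-12, 1.0 - 1e-12)                   # avoid log(0)
    loss = -np.sum(sw * (y * np.log(dc) + (1.0 - y) * np.log(1.0 - dc)))
    g = sw * (d - y)                                      # d(loss)/d(logits)
    w_new = w - lr * (x.T @ g)                            # gradient step on weights
    b_new = b - lr * np.sum(g)                            # gradient step on bias
    return {"loss": loss, "params": (w_new, b_new)}


rng = np.random.default_rng(0)
real = rng.normal(loc=1.0, size=(16, 2))
fake = rng.normal(loc=-1.0, size=(16, 2))
uniform = np.full(16, 1.0 / 32)                           # weights sum to 1 overall
w, b = np.zeros(2), 0.0
for _ in range(50):
    out = train_step(w, b, (real, fake), (uniform, uniform))
    w, b = out["params"]
```

With uniform weights summing to 1 over the combined batch, the starting loss at w = 0, b = 0 is exactly log 2 (every output is 0.5), so a loss below log 2 after a few steps indicates the updates are making progress.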
