This page is from an old version of Qiskit SDK and does not exist in the latest version. We recommend you migrate to the latest version. See the release notes for more information.

NumPyDiscriminator

class NumPyDiscriminator(n_features=1, n_out=1)

Discriminator based on NumPy

Parameters

  • n_features (int) – Dimension of input data vector.
  • n_out (int) – Dimension of the discriminator’s output vector.

Attributes

discriminator_net

Get discriminator

Returns

discriminator object

Return type

DiscriminatorNet


Methods

get_label

NumPyDiscriminator.get_label(x, detach=False)

Get the labels of data samples, i.e., whether each sample is real or fake.

Parameters

  • x (numpy.ndarray) – Discriminator input, i.e. data sample.
  • detach (bool) – Deprecated for the NumPy network.

Returns

Discriminator output, i.e. data label

Return type

numpy.ndarray
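The following minimal sketch makes the input/output contract of get_label concrete. It assumes a hypothetical one-layer discriminator: a linear map from n_features inputs to n_out outputs followed by a sigmoid, so every output lies in (0, 1) and can be read as the probability that the sample is real. This page does not document the actual network architecture. `get_label_sketch`, `weights`, and `bias` are illustrative names, and plain Python lists stand in for numpy.ndarray so the sketch runs without NumPy.

```python
import math


def sigmoid(z):
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def get_label_sketch(x, weights, bias):
    """Hypothetical forward pass of a one-layer discriminator.

    x       -- one data sample with n_features entries
    weights -- n_out rows, each with n_features entries (assumed layout)
    bias    -- n_out entries
    Returns n_out values in (0, 1): near 1 means "real", near 0 means "fake".
    """
    return [
        sigmoid(sum(w_i * x_i for w_i, x_i in zip(row, x)) + b)
        for row, b in zip(weights, bias)
    ]


# n_features=2, n_out=1 with hypothetical parameters.
label = get_label_sketch([0.5, -1.0], weights=[[2.0, 1.0]], bias=[0.0])
print(label)  # [0.5] because 2.0*0.5 + 1.0*(-1.0) + 0.0 = 0 and sigmoid(0) = 0.5
```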

load_model

NumPyDiscriminator.load_model(load_dir)

Load discriminator model

Parameters

load_dir (str) – directory containing the stored discriminator model to be loaded (as written by save_model)

loss

NumPyDiscriminator.loss(x, y, weights=None)

Loss function.

Parameters

  • x (numpy.ndarray) – Sample label (equivalent to the discriminator output).
  • y (numpy.ndarray) – Target label.
  • weights (numpy.ndarray) – Customized scaling for each sample (optional).

Returns

Value of the loss function

Return type

float
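The parameter descriptions match a (weighted) binary cross-entropy between the discriminator output x and the target labels y, roughly L = -Σ w_i [y_i log x_i + (1 - y_i) log(1 - x_i)]. The sketch below assumes that form. It averages the per-sample terms when weights is omitted and takes their weighted sum otherwise. The exact formula, the clipping constant `eps`, and the name `bce_loss_sketch` are assumptions, not the library's documented implementation.

```python
import math


def bce_loss_sketch(x, y, weights=None, eps=1e-7):
    """Hypothetical (weighted) binary cross-entropy.

    x       -- discriminator outputs in (0, 1)
    y       -- target labels (1.0 = real, 0.0 = fake)
    weights -- optional per-sample scaling; if omitted, the mean is used
    eps     -- clipping constant (an assumption) that keeps log() finite
    """
    terms = []
    for x_i, y_i in zip(x, y):
        x_i = min(max(x_i, eps), 1.0 - eps)  # avoid log(0)
        terms.append(y_i * math.log(x_i) + (1.0 - y_i) * math.log(1.0 - x_i))
    if weights is None:
        return -sum(terms) / len(terms)
    return -sum(w_i * t for w_i, t in zip(weights, terms))


# A confident correct prediction costs little; a confident wrong one costs a lot.
print(bce_loss_sketch([0.9], [1.0]))  # -log(0.9), about 0.105
print(bce_loss_sketch([0.1], [1.0]))  # -log(0.1), about 2.303
```

Under this assumption, weights lets some samples count more than others, which is how the real and generated batches can be scaled separately during training.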

save_model

NumPyDiscriminator.save_model(snapshot_dir)

Save discriminator model

Parameters

snapshot_dir (str) – directory path for saving the model
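save_model and load_model work as a pair: parameters written to a directory can later be read back into a discriminator. The sketch below shows that round trip under stated assumptions. The parameters are a plain dict serialized as JSON into a hypothetical file named discriminator_params.json. The real class uses its own storage format, which this page does not document, and `save_model_sketch` / `load_model_sketch` are illustrative helpers.

```python
import json
import os
import tempfile

PARAMS_FILE = "discriminator_params.json"  # hypothetical file name


def save_model_sketch(params, snapshot_dir):
    """Write the parameter dict into snapshot_dir and return the file path."""
    os.makedirs(snapshot_dir, exist_ok=True)
    path = os.path.join(snapshot_dir, PARAMS_FILE)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(params, f)
    return path


def load_model_sketch(load_dir):
    """Read back the parameter dict previously written by save_model_sketch."""
    with open(os.path.join(load_dir, PARAMS_FILE), encoding="utf-8") as f:
        return json.load(f)


params = {"weights": [[2.0, 1.0]], "bias": [0.0]}
with tempfile.TemporaryDirectory() as tmp:
    save_model_sketch(params, tmp)
    restored = load_model_sketch(tmp)
print(restored == params)  # True
```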

set_seed

NumPyDiscriminator.set_seed(seed)

Set seed.

Parameters

seed (int) – Seed value.
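The page does not say which random state set_seed controls. Assuming it seeds the random initialization of the network weights, its purpose is reproducibility: the same seed yields the same starting parameters. The sketch below illustrates that idea with a dedicated random.Random generator. `init_weights_sketch` is a hypothetical helper, not part of the class.

```python
import random


def init_weights_sketch(n_features, n_out, seed):
    """Draw hypothetical initial weights from a generator seeded with `seed`."""
    rng = random.Random(seed)  # private generator: leaves the global random state alone
    return [[rng.uniform(-1.0, 1.0) for _ in range(n_features)] for _ in range(n_out)]


a = init_weights_sketch(n_features=3, n_out=1, seed=42)
b = init_weights_sketch(n_features=3, n_out=1, seed=42)
c = init_weights_sketch(n_features=3, n_out=1, seed=7)
print(a == b)  # True: same seed, same weights
print(a == c)  # False: different seed, different weights
```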

train

NumPyDiscriminator.train(data, weights, penalty=False, quantum_instance=None, shots=None)

Perform one training step with respect to the discriminator’s parameters.

Parameters

  • data (tuple(numpy.ndarray, numpy.ndarray)) – Tuple (real_batch, generated_batch), where real_batch is the training data batch and generated_batch is the generated data batch.
  • weights (tuple) – Pair of per-sample weights (probabilities) for the real batch and the generated batch, respectively.
  • penalty (bool) – Deprecated for classical networks.
  • quantum_instance (QuantumInstance) – Deprecated for classical networks.
  • shots (int) – Number of shots for hardware or qasm execution. Ignored for classical networks.

Returns

Dictionary containing the discriminator loss and the updated parameters.

Return type

dict
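The page does not describe the discriminator's network or optimizer. As an assumption, the sketch below uses a logistic discriminator with n_out=1 and plain gradient descent to show what "one training step" means here: real samples are pushed toward label 1 and generated samples toward label 0, each scaled by its weight from the weights tuple. The names `train_step_sketch` and `learning_rate`, and the returned "loss"/"params" keys, are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_step_sketch(params, data, weights, learning_rate=0.1):
    """One hypothetical gradient-descent step for a logistic discriminator.

    params  -- {"w": [n_features floats], "b": float}
    data    -- (real_batch, generated_batch), each a list of samples
    weights -- (real weights, generated weights), one weight per sample
    Returns {"loss": weighted cross-entropy before the step, "params": updated params}.
    """
    real_batch, generated_batch = data
    real_w, generated_w = weights
    # Real samples get target 1, generated samples get target 0.
    samples = [(x, 1.0, s) for x, s in zip(real_batch, real_w)]
    samples += [(x, 0.0, s) for x, s in zip(generated_batch, generated_w)]

    w, b = params["w"], params["b"]
    grad_w = [0.0] * len(w)
    grad_b = 0.0
    loss = 0.0
    for x, y, s in samples:
        p = sigmoid(sum(w_i * x_i for w_i, x_i in zip(w, x)) + b)
        loss -= s * (y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
        # For a sigmoid output with cross-entropy, dLoss/dz = s * (p - y).
        err = s * (p - y)
        grad_w = [g + err * x_i for g, x_i in zip(grad_w, x)]
        grad_b += err

    new_params = {
        "w": [w_i - learning_rate * g for w_i, g in zip(w, grad_w)],
        "b": b - learning_rate * grad_b,
    }
    return {"loss": loss, "params": new_params}


params = {"w": [0.0], "b": 0.0}
data = ([[1.0], [2.0]], [[-1.0], [-2.0]])  # real vs. generated samples
weights = ([0.5, 0.5], [0.5, 0.5])
first = train_step_sketch(params, data, weights)
second = train_step_sketch(first["params"], data, weights)
print(first["loss"], second["loss"])  # the second loss is lower after one update
```

Returning both the loss and the updated parameters mirrors the documented dict return type. The loss lets a caller track training progress, and the parameters let the next step continue from the update.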
