diff --git a/tensorflow_quantum/core/ops/batch_util_test.py b/tensorflow_quantum/core/ops/batch_util_test.py index 074b01fcb..c702338d6 100644 --- a/tensorflow_quantum/core/ops/batch_util_test.py +++ b/tensorflow_quantum/core/ops/batch_util_test.py @@ -308,6 +308,28 @@ def test_no_circuit(self, sim): self.assertDTypeEqual(results, np.int8) self.assertEqual(np.zeros(shape=(0, 0, 0)).shape, results.shape) + def test_pauli_sum_collector_collect(self): + """Test the collect method of TFQPauliSumCollector.""" + qubit = cirq.GridQubit(0, 0) + circuit = cirq.Circuit(cirq.X(qubit)) + samples_per_term = 100 + sampler = cirq.Simulator() + + test_cases = [ + ("Pauli observable (Z)", cirq.PauliSum.wrap(cirq.Z(qubit)), -1.0), + ("Identity observable", cirq.PauliSum.wrap(cirq.I(qubit)), 1.0), + ("Mixed observable (Z + 2.0*I)", + cirq.Z(qubit) + 2.0 * cirq.I(qubit), 1.0), + ] + + for name, observable, expected_energy in test_cases: + with self.subTest(name): + collector = batch_util.TFQPauliSumCollector( + circuit, observable, samples_per_term=samples_per_term) + collector.collect(sampler) + self.assertAlmostEqual(collector.estimated_energy(), + expected_energy) + if __name__ == '__main__': tf.test.main()