From 7bce77c853c57f3bce540528bc0a20ea72ae172e Mon Sep 17 00:00:00 2001 From: mhucka <1450019+mhucka@users.noreply.github.com> Date: Fri, 13 Mar 2026 05:27:05 +0000 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=A7=AA=20Add=20unit=20tests=20for=20T?= =?UTF-8?q?FQPauliSumCollector.collect?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change implements unit tests for the `collect` method in `TFQPauliSumCollector`, addressing a testing gap in `batch_util_test.py`. The new tests cover: - Standard Pauli Z observable (expected energy -1.0) - Identity observable (expected energy 1.0) - Mixed observable (Z + 2.0*I) (expected energy 1.0) This improves the reliability and coverage of the batch utility module. --- .../core/ops/batch_util_test.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tensorflow_quantum/core/ops/batch_util_test.py b/tensorflow_quantum/core/ops/batch_util_test.py index 074b01fcb..fbcda7cf5 100644 --- a/tensorflow_quantum/core/ops/batch_util_test.py +++ b/tensorflow_quantum/core/ops/batch_util_test.py @@ -308,6 +308,39 @@ def test_no_circuit(self, sim): self.assertDTypeEqual(results, np.int8) self.assertEqual(np.zeros(shape=(0, 0, 0)).shape, results.shape) + def test_pauli_sum_collector_collect(self): + """Test the collect method of TFQPauliSumCollector.""" + qubit = cirq.GridQubit(0, 0) + circuit = cirq.Circuit(cirq.X(qubit)) + samples_per_term = 100 + sampler = cirq.Simulator() + + # Case 1: Standard Pauli observable (Z). Expect -1.0. + observable1 = cirq.PauliSum.wrap(cirq.Z(qubit)) + collector1 = batch_util.TFQPauliSumCollector( + circuit, observable1, samples_per_term=samples_per_term) + collector1.collect(sampler) + + pauli_string1 = list(observable1)[0] + pauli_string1 = pauli_string1 / pauli_string1.coefficient + self.assertEqual(collector1._zeros[pauli_string1], 0) + self.assertEqual(collector1._ones[pauli_string1], samples_per_term) + self.assertAlmostEqual(collector1.estimated_energy(), -1.0) + + # Case 2: Identity observable. Expect 1.0. + observable2 = cirq.PauliSum.wrap(cirq.I(qubit)) + collector2 = batch_util.TFQPauliSumCollector( + circuit, observable2, samples_per_term=samples_per_term) + collector2.collect(sampler) + self.assertAlmostEqual(collector2.estimated_energy(), 1.0) + + # Case 3: Mixed observable (Z + 2.0*I). Expect -1.0 + 2.0 = 1.0. + observable3 = cirq.Z(qubit) + 2.0 * cirq.I(qubit) + collector3 = batch_util.TFQPauliSumCollector( + circuit, observable3, samples_per_term=samples_per_term) + collector3.collect(sampler) + self.assertAlmostEqual(collector3.estimated_energy(), 1.0) + if __name__ == '__main__': tf.test.main() From e98571f9528dab77ff15f261eea512150c8cf981 Mon Sep 17 00:00:00 2001 From: Michael Hucka Date: Fri, 13 Mar 2026 09:55:36 -0700 Subject: [PATCH 2/3] Refactor code to simplify and make extendable Good suggestion from Gemini Code Assist. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../core/ops/batch_util_test.py | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/tensorflow_quantum/core/ops/batch_util_test.py b/tensorflow_quantum/core/ops/batch_util_test.py index fbcda7cf5..c034bf939 100644 --- a/tensorflow_quantum/core/ops/batch_util_test.py +++ b/tensorflow_quantum/core/ops/batch_util_test.py @@ -315,31 +315,18 @@ def test_pauli_sum_collector_collect(self): samples_per_term = 100 sampler = cirq.Simulator() - # Case 1: Standard Pauli observable (Z). Expect -1.0. - observable1 = cirq.PauliSum.wrap(cirq.Z(qubit)) - collector1 = batch_util.TFQPauliSumCollector( - circuit, observable1, samples_per_term=samples_per_term) - collector1.collect(sampler) - - pauli_string1 = list(observable1)[0] - pauli_string1 = pauli_string1 / pauli_string1.coefficient - self.assertEqual(collector1._zeros[pauli_string1], 0) - self.assertEqual(collector1._ones[pauli_string1], samples_per_term) - self.assertAlmostEqual(collector1.estimated_energy(), -1.0) - - # Case 2: Identity observable. Expect 1.0. - observable2 = cirq.PauliSum.wrap(cirq.I(qubit)) - collector2 = batch_util.TFQPauliSumCollector( - circuit, observable2, samples_per_term=samples_per_term) - collector2.collect(sampler) - self.assertAlmostEqual(collector2.estimated_energy(), 1.0) - - # Case 3: Mixed observable (Z + 2.0*I). Expect -1.0 + 2.0 = 1.0. - observable3 = cirq.Z(qubit) + 2.0 * cirq.I(qubit) - collector3 = batch_util.TFQPauliSumCollector( - circuit, observable3, samples_per_term=samples_per_term) - collector3.collect(sampler) - self.assertAlmostEqual(collector3.estimated_energy(), 1.0) + test_cases = [ + ("Standard Pauli observable (Z)", cirq.PauliSum.wrap(cirq.Z(qubit)), -1.0), + ("Identity observable", cirq.PauliSum.wrap(cirq.I(qubit)), 1.0), + ("Mixed observable (Z + 2.0*I)", cirq.Z(qubit) + 2.0 * cirq.I(qubit), 1.0), + ] + + for name, observable, expected_energy in test_cases: + with self.subTest(name): + collector = batch_util.TFQPauliSumCollector( + circuit, observable, samples_per_term=samples_per_term) + collector.collect(sampler) + self.assertAlmostEqual(collector.estimated_energy(), expected_energy) if __name__ == '__main__': From 4653e61be247424391a0f65ce9f2028ff5419a2d Mon Sep 17 00:00:00 2001 From: mhucka Date: Mon, 16 Mar 2026 17:56:13 +0000 Subject: [PATCH 3/3] Incorporate suggestion from Gemini Code Assist --- tensorflow_quantum/core/ops/batch_util_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow_quantum/core/ops/batch_util_test.py b/tensorflow_quantum/core/ops/batch_util_test.py index c034bf939..c702338d6 100644 --- a/tensorflow_quantum/core/ops/batch_util_test.py +++ b/tensorflow_quantum/core/ops/batch_util_test.py @@ -316,9 +316,10 @@ def test_pauli_sum_collector_collect(self): sampler = cirq.Simulator() test_cases = [ - ("Standard Pauli observable (Z)", cirq.PauliSum.wrap(cirq.Z(qubit)), -1.0), + ("Pauli observable (Z)", cirq.PauliSum.wrap(cirq.Z(qubit)), -1.0), ("Identity observable", cirq.PauliSum.wrap(cirq.I(qubit)), 1.0), - ("Mixed observable (Z + 2.0*I)", cirq.Z(qubit) + 2.0 * cirq.I(qubit), 1.0), + ("Mixed observable (Z + 2.0*I)", + cirq.Z(qubit) + 2.0 * cirq.I(qubit), 1.0), ] for name, observable, expected_energy in test_cases: @@ -326,7 +327,8 @@ def test_pauli_sum_collector_collect(self): collector = batch_util.TFQPauliSumCollector( circuit, observable, samples_per_term=samples_per_term) collector.collect(sampler) - self.assertAlmostEqual(collector.estimated_energy(), expected_energy) + self.assertAlmostEqual(collector.estimated_energy(), + expected_energy) if __name__ == '__main__':