diff --git a/toqito/matrix_ops/__init__.py b/toqito/matrix_ops/__init__.py index 746c49472..ccde0a897 100644 --- a/toqito/matrix_ops/__init__.py +++ b/toqito/matrix_ops/__init__.py @@ -7,3 +7,5 @@ from toqito.matrix_ops.vectors_to_gram_matrix import vectors_to_gram_matrix from toqito.matrix_ops.calculate_vector_matrix_dimension import calculate_vector_matrix_dimension from toqito.matrix_ops.to_density_matrix import to_density_matrix +from toqito.matrix_ops.perturb_vectors import perturb_vectors + diff --git a/toqito/matrix_ops/perturb_vectors.py b/toqito/matrix_ops/perturb_vectors.py new file mode 100644 index 000000000..eba25d31e --- /dev/null +++ b/toqito/matrix_ops/perturb_vectors.py @@ -0,0 +1,36 @@ +"""Perturb vectors is used to add a small random number to each element of a vector. + +A random value is added sampled from a normal distribution scaled by `eps`. +""" + +import numpy as np + + +def perturb_vectors(vectors: list[np.ndarray], eps: float = 0.1) -> list[np.ndarray]: + """Perturb the vectors by adding a small random number to each element. + + :param vectors: List of vectors to perturb. + :param eps: Amount by which to perturb vectors. + :return: Resulting list of perturbed vectors by a factor of epsilon. + + Example: + ========== + + >>> from toqito.matrix_ops import perturb_vectors + >>> import numpy as np + >>> vectors = [np.array([1.0, 2.0]), np.array([3.0, 4.0])] + >>> perturb_vectors(vectors, eps=0.1) # doctest: +SKIP + array([[0.47687587, 0.87897065], + [0.58715549, 0.80947417]]) + + """ + perturbed_vectors: list[np.ndarray] = [] + for i, v in enumerate(vectors): + if eps == 0: + perturbed_vectors.append(v) + else: + perturbed_vectors.append(v + np.random.randn(v.shape[0]) * eps) + + # Normalize the vectors after perturbing them. + perturbed_vectors[i] = perturbed_vectors[i] / np.linalg.norm(perturbed_vectors[i]) + return np.array(perturbed_vectors) diff --git a/toqito/matrix_ops/tests/test_perturb_vector.py b/toqito/matrix_ops/tests/test_perturb_vector.py new file mode 100644 index 000000000..ee0182efd --- /dev/null +++ b/toqito/matrix_ops/tests/test_perturb_vector.py @@ -0,0 +1,67 @@ +"""Test perturb vectors.""" + +import numpy as np +import pytest + +from toqito.matrix_ops import perturb_vectors + + +@pytest.mark.parametrize( + "vectors, eps, expected_length", + [ + # Test with three vectors along the axes + ([np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])], 0.1, 3), + # Test with one vector and eps=0 + ([np.array([1, 1, 0])], 0.0, 1), + # Test with two vectors and a different perturbation value + ([np.array([1, 0, 0]), np.array([0, 1, 0])], 0.2, 2), + ], +) +def test_output_size(vectors, eps, expected_length): + """Test that the function returns the same number of vectors as input.""" + perturbed_vectors = perturb_vectors(vectors, eps) + assert len(perturbed_vectors) == expected_length + + +@pytest.mark.parametrize( + "vectors, eps", + [ + ([np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])], 0.1), + ([np.array([1, 1, 1])], 0.1), + ], +) +def test_normalization(vectors, eps): + """Test that each perturbed vector is normalized.""" + perturbed_vectors = perturb_vectors(vectors, eps) + for pv in perturbed_vectors: + norm = np.linalg.norm(pv) + assert np.isclose(norm, 1.0, atol=1e-5) + + +@pytest.mark.parametrize( + "vectors, eps", + [ + ([np.array([1, 0, 0]), np.array([0, 1, 0])], 0.1), + ([np.array([1, 1, 1])], 0.2), + ], +) +def test_perturbation_effect(vectors, eps): + """Test that the perturbed vectors are different from the original vectors.""" + perturbed_vectors = perturb_vectors(vectors, eps) + for i in range(len(vectors)): + assert not np.array_equal(vectors[i], perturbed_vectors[i]) + + +@pytest.mark.parametrize( + "vectors", + [ + ([np.array([1, 0, 0]), np.array([0, 1, 0])]), + ([np.array([1, 1, 1])]), + ], +) +def test_zero_perturbation(vectors): + """Test that if eps = 0, the vectors remain the same.""" + perturbed_vectors = perturb_vectors(vectors, eps=0.0) + for i in range(len(vectors)): + assert np.allclose(vectors[i], perturbed_vectors[i]), f"Vector {i} does not match." +