Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
Add a bit of typing to manifold code
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasdiez authored and Matthias Koeppe committed Mar 28, 2022
1 parent 43474c9 commit 51ee92b
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/sage/manifolds/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -3037,7 +3037,7 @@ def plot(self, chart=None, ambient_coords=None, mapping=None,
from sage.plot.graphics import Graphics
from sage.plot.line import line
from sage.manifolds.continuous_map import ContinuousMap
from .utilities import set_axes_labels
from sage.manifolds.utilities import set_axes_labels

# Extract the kwds options
max_range = kwds['max_range']
Expand Down
3 changes: 2 additions & 1 deletion src/sage/manifolds/differentiable/degenerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sage.rings.infinity import infinity
from sage.manifolds.structure import DegenerateStructure
from sage.manifolds.differentiable.manifold import DifferentiableManifold
from sage.manifolds.differentiable.metric import DegenerateMetric

###############################################################################

Expand Down Expand Up @@ -143,7 +144,7 @@ def __init__(self, n, name, metric_name=None, signature=None,
self._metric_latex_name = metric_latex_name

def metric(self, name=None, signature=None, latex_name=None,
dest_map=None):
dest_map=None) -> DegenerateMetric:
r"""
Return the metric giving the null manifold structure to the
manifold, or define a new metric tensor on the manifold.
Expand Down
3 changes: 2 additions & 1 deletion src/sage/manifolds/differentiable/degenerate_submanifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@
TangentTensor)
from sage.manifolds.differentiable.differentiable_submanifold import \
DifferentiableSubmanifold
from sage.manifolds.differentiable.metric import DegenerateMetric
from sage.manifolds.differentiable.vectorfield_module import VectorFieldModule
from sage.rings.infinity import infinity
from sage.matrix.constructor import matrix
Expand Down Expand Up @@ -590,7 +591,7 @@ def screen(self, name, screen, rad, latex_name=None):
self._default_screen = self._screens[name]
return self._screens[name]

def induced_metric(self):
def induced_metric(self) -> DegenerateMetric:
r"""
Return the pullback of the ambient metric.
Expand Down
17 changes: 10 additions & 7 deletions src/sage/manifolds/differentiable/diff_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,14 @@
# *****************************************************************************

from __future__ import annotations
from typing import Union, TYPE_CHECKING
from typing import Optional, Union, TYPE_CHECKING
from sage.misc.cachefunc import cached_method
from sage.tensor.modules.free_module_alt_form import FreeModuleAltForm
from sage.manifolds.differentiable.tensorfield import TensorField
from sage.manifolds.differentiable.tensorfield_paral import TensorFieldParal

if TYPE_CHECKING:
from sage.manifolds.differentiable.vectorfield_module import VectorFieldModule
from sage.manifolds.differentiable.metric import PseudoRiemannianMetric
from sage.manifolds.differentiable.symplectic_form import SymplecticForm

Expand Down Expand Up @@ -384,7 +385,7 @@ def _del_derived(self):
self.exterior_derivative.clear_cache()

@cached_method
def exterior_derivative(self):
def exterior_derivative(self) -> DiffForm:
r"""
Compute the exterior derivative of ``self``.
Expand Down Expand Up @@ -590,7 +591,7 @@ def wedge(self, other: DiffForm) -> DiffForm:
other_r._restrictions[dom])
return resu

def degree(self):
def degree(self) -> int:
r"""
Return the degree of ``self``.
Expand All @@ -615,7 +616,9 @@ def degree(self):

def hodge_dual(
self,
nondegenerate_tensor: Union[PseudoRiemannianMetric, SymplecticForm, None] = None,
nondegenerate_tensor: Union[
PseudoRiemannianMetric, SymplecticForm, None
] = None,
) -> DiffForm:
r"""
Compute the Hodge dual of the differential form with respect to some non-degenerate
Expand Down Expand Up @@ -1228,8 +1231,8 @@ class DiffFormParal(FreeModuleAltForm, TensorFieldParal, DiffForm):
no symmetry; no antisymmetry
"""
def __init__(self, vector_field_module, degree, name=None,
latex_name=None):
def __init__(self, vector_field_module: VectorFieldModule, degree: int, name: Optional[str] = None,
latex_name: Optional[str] = None):
r"""
Construct a differential form.
Expand Down Expand Up @@ -1382,7 +1385,7 @@ def __call__(self, *args):
return TensorFieldParal.__call__(self, *args)

@cached_method
def exterior_derivative(self):
def exterior_derivative(self) -> DiffFormParal:
r"""
Compute the exterior derivative of ``self``.
Expand Down
7 changes: 6 additions & 1 deletion src/sage/manifolds/differentiable/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,12 +443,15 @@
from typing import Optional, TYPE_CHECKING
from sage.categories.manifolds import Manifolds
from sage.categories.homset import Hom
from sage.manifolds.differentiable.diff_map import DiffMap
from sage.rings.cc import CC
from sage.rings.real_mpfr import RR
from sage.rings.infinity import infinity, minus_infinity
from sage.rings.integer import Integer
from sage.manifolds.manifold import TopologicalManifold
from sage.manifolds.differentiable.mixed_form_algebra import MixedFormAlgebra
from sage.manifolds.differentiable.vectorfield_module import VectorFieldModule
from sage.manifolds.differentiable.metric import PseudoRiemannianMetric

if TYPE_CHECKING:
from sage.manifolds.differentiable.vectorfield_module import VectorFieldModule
Expand Down Expand Up @@ -3945,7 +3948,9 @@ def affine_connection(self, name, latex_name=None):
AffineConnection
return AffineConnection(self, name, latex_name)

def metric(self, name, signature=None, latex_name=None, dest_map=None):
def metric(self, name: str, signature: Optional[int] = None,
latex_name: Optional[str] = None,
dest_map: Optional[DiffMap] = None) -> PseudoRiemannianMetric:
r"""
Define a pseudo-Riemannian metric on the manifold.
Expand Down
14 changes: 10 additions & 4 deletions src/sage/manifolds/differentiable/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,18 @@
# the License, or (at your option) any later version.
# https://www.gnu.org/licenses/
# *****************************************************************************

from __future__ import annotations
from typing import TYPE_CHECKING
from sage.rings.integer import Integer

from typing import TYPE_CHECKING, overload

from sage.manifolds.differentiable.tensorfield import TensorField
from sage.manifolds.differentiable.tensorfield_paral import TensorFieldParal
from sage.rings.integer import Integer

if TYPE_CHECKING:
from sage.manifolds.differentiable.diff_form import DiffForm


class PseudoRiemannianMetric(TensorField):
r"""
Pseudo-Riemannian metric with values on an open subset of a
Expand Down Expand Up @@ -1624,7 +1626,11 @@ def sqrt_abs_det(self, frame=None):
self._sqrt_abs_dets[frame] = resu
return self._sqrt_abs_dets[frame]

def volume_form(self, contra: int = 0) -> TensorField:
@overload
def volume_form(self) -> DiffForm: ...
@overload
def volume_form(self, contra: int) -> TensorField: ...
def volume_form(self, contra=0):
r"""
Volume form (Levi-Civita tensor) `\epsilon` associated with the metric.
Expand Down
57 changes: 41 additions & 16 deletions src/sage/manifolds/differentiable/tensorfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,23 @@
# *****************************************************************************

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple, Union

from sage.rings.integer import Integer
from sage.rings.integer_ring import ZZ
from sage.structure.element import ModuleElementWithMutability
from sage.tensor.modules.free_module_tensor import FreeModuleTensor
from sage.tensor.modules.tensor_with_indices import TensorWithIndices

from typing import Optional, TYPE_CHECKING, Tuple, Union

if TYPE_CHECKING:
from sage.manifolds.differentiable.manifold import DifferentiableManifold
from sage.manifolds.differentiable.diff_map import DiffMap
from sage.manifolds.differentiable.manifold import DifferentiableManifold
from sage.manifolds.differentiable.metric import PseudoRiemannianMetric
from sage.manifolds.differentiable.symplectic_form import SymplecticForm
from sage.manifolds.differentiable.poisson_tensor import PoissonTensorField
from sage.manifolds.differentiable.symplectic_form import SymplecticForm
from sage.manifolds.differentiable.vectorfield_module import VectorFieldModule
from sage.tensor.modules.comp import Components


TensorType = Tuple[int, int]
Expand Down Expand Up @@ -402,8 +405,17 @@ class TensorField(ModuleElementWithMutability):
ValueError: the name of an immutable element cannot be changed
"""
def __init__(self, vector_field_module, tensor_type, name=None,
latex_name=None, sym=None, antisym=None, parent=None):

def __init__(
self,
vector_field_module: VectorFieldModule,
tensor_type: TensorType,
name: Optional[str] = None,
latex_name: Optional[str] = None,
sym=None,
antisym=None,
parent=None,
):
r"""
Construct a tensor field.
Expand Down Expand Up @@ -840,7 +852,7 @@ def domain(self) -> DifferentiableManifold:
"""
return self._domain

def base_module(self):
def base_module(self) -> VectorFieldModule:
r"""
Return the vector field module on which ``self`` acts as a tensor.
Expand All @@ -867,7 +879,7 @@ def base_module(self):
"""
return self._vmodule

def tensor_type(self) -> Tuple[int, int]:
def tensor_type(self) -> TensorType:
r"""
Return the tensor type of ``self``.
Expand Down Expand Up @@ -968,7 +980,7 @@ def set_immutable(self):
rst.set_immutable()
super().set_immutable()

def set_restriction(self, rst):
def set_restriction(self, rst: TensorField):
r"""
Define a restriction of ``self`` to some subdomain.
Expand Down Expand Up @@ -1035,7 +1047,9 @@ def set_restriction(self, rst):
latex_name=self._latex_name)
self._is_zero = False # a priori

def restrict(self, subdomain: DifferentiableManifold, dest_map: Optional[DiffMap] = None) -> TensorField:
def restrict(
self, subdomain: DifferentiableManifold, dest_map: Optional[DiffMap] = None
) -> TensorField:
r"""
Return the restriction of ``self`` to some subdomain.
Expand Down Expand Up @@ -1385,7 +1399,7 @@ def _add_comp_unsafe(self, basis=None):
rst = self.restrict(basis._domain, dest_map=basis._dest_map)
return rst._add_comp_unsafe(basis)

def add_comp(self, basis=None):
def add_comp(self, basis=None) -> Components:
r"""
Return the components of ``self`` in a given vector frame
for assignment.
Expand Down Expand Up @@ -3718,7 +3732,13 @@ def at(self, point):
if point in dom:
return rst.at(point)

def up(self, non_degenerate_form: Union['PseudoRiemannianMetric', 'SymplecticForm', 'PoissonTensorField'], pos: Optional[int] = None) -> 'TensorField':
def up(
self,
non_degenerate_form: Union[
"PseudoRiemannianMetric", "SymplecticForm", "PoissonTensorField"
],
pos: Optional[int] = None,
) -> "TensorField":
r"""
Compute a dual of the tensor field by raising some index with the
given tensor field (usually, a pseudo-Riemannian metric, a symplectic form or a Poisson tensor).
Expand Down Expand Up @@ -3857,14 +3877,15 @@ def up(self, non_degenerate_form: Union['PseudoRiemannianMetric', 'SymplecticFor
k = result._tensor_type[0]
result = result.up(non_degenerate_form, k)
return result
if pos<n_con or pos>self._tensor_rank-1:

if pos < n_con or pos > self._tensor_rank - 1:
print("pos = {}".format(pos))
raise ValueError("position out of range")

from sage.manifolds.differentiable.metric import PseudoRiemannianMetric
from sage.manifolds.differentiable.symplectic_form import SymplecticForm
from sage.manifolds.differentiable.poisson_tensor import PoissonTensorField

if isinstance(non_degenerate_form, PseudoRiemannianMetric):
return self.contract(pos, non_degenerate_form.inverse(), 1)
elif isinstance(non_degenerate_form, SymplecticForm):
Expand All @@ -3874,7 +3895,11 @@ def up(self, non_degenerate_form: Union['PseudoRiemannianMetric', 'SymplecticFor
else:
raise ValueError("The non-degenerate form has to be a metric, a symplectic form or a Poisson tensor field")

def down(self, non_degenerate_form: Union[PseudoRiemannianMetric, SymplecticForm], pos: Optional[int] = None) -> TensorField:
def down(
self,
non_degenerate_form: Union[PseudoRiemannianMetric, SymplecticForm],
pos: Optional[int] = None,
) -> TensorField:
r"""
Compute a dual of the tensor field by lowering some index with a
given non-degenerate form (pseudo-Riemannian metric or symplectic form).
Expand Down Expand Up @@ -4009,7 +4034,7 @@ def down(self, non_degenerate_form: Union[PseudoRiemannianMetric, SymplecticForm
result = self
for p in range(n_con):
k = result._tensor_type[0]
result = result.down(non_degenerate_form, k-1)
result = result.down(non_degenerate_form, k - 1)
return result
if not isinstance(pos, (int, Integer)):
raise TypeError("the argument 'pos' must be an integer")
Expand Down
11 changes: 9 additions & 2 deletions src/sage/manifolds/differentiable/tensorfield_paral.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,19 @@
# https://www.gnu.org/licenses/
# *****************************************************************************

from sage.tensor.modules.free_module_tensor import FreeModuleTensor
from __future__ import annotations

from typing import TYPE_CHECKING

from sage.manifolds.chart import Chart
from sage.manifolds.differentiable.tensorfield import TensorField
from sage.parallel.decorate import parallel
from sage.parallel.parallelism import Parallelism
from sage.symbolic.ring import SR
from sage.tensor.modules.free_module_tensor import FreeModuleTensor

if TYPE_CHECKING:
from sage.tensor.modules.comp import Components

class TensorFieldParal(FreeModuleTensor, TensorField):
r"""
Expand Down Expand Up @@ -1122,7 +1129,7 @@ class :class:`~sage.tensor.modules.comp.Components`; if such
# The add_comp operation is performed on the subdomain:
return rst.add_comp(basis=basis)

def comp(self, basis=None, from_basis=None):
def comp(self, basis=None, from_basis=None) -> Components:
r"""
Return the components in a given vector frame.
Expand Down
Loading

0 comments on commit 51ee92b

Please sign in to comment.