(docstrings) apply @deltheil suggestions
Laurent2916 committed Feb 2, 2024
1 parent 84d5796 commit 7307a36
Showing 6 changed files with 26 additions and 17 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -43,7 +43,8 @@ conversion = [
"tqdm>=4.62.3",
]
doc = [
"black>=24.1.1", # required by mkdocs to format the signatures
# required by mkdocs to format the signatures
"black>=24.1.1",
"mkdocs-material>=9.5.6",
"mkdocstrings[python]>=0.24.0",
"mkdocs-literate-nav>=0.6.1",
18 changes: 12 additions & 6 deletions src/refiners/fluxion/adapters/lora.py
@@ -9,7 +9,7 @@


class Lora(fl.Chain, ABC):
"""Low-rank approximation (LoRA) layer.
"""Low-Rank Adaptation (LoRA) layer.
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
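As background for the renamed docstring: a LoRA layer approximates a weight update with the product of two low-rank matrices. The sketch below illustrates the idea in plain PyTorch; the names `down`, `up`, `rank` and `scale`, and the shapes, are illustrative rather than the Refiners implementation.

```python
# Low-rank idea behind a LoRA layer (illustrative sketch, not the Refiners code).
# A frozen weight W is adapted by adding scale * (up @ down), where the two
# factors share a small inner dimension `rank`.
import torch

in_features, out_features, rank, scale = 320, 640, 16, 1.0
x = torch.randn(2, in_features)

W = torch.randn(out_features, in_features)    # frozen base weight
down = torch.randn(rank, in_features) * 0.01  # trainable down projection
up = torch.zeros(out_features, rank)          # trainable up projection, zero-initialized

base = x @ W.T
adapted = base + scale * ((x @ down.T) @ up.T)  # same as x @ (W + scale * up @ down).T
assert adapted.shape == (2, out_features)
assert torch.allclose(adapted, base)  # zero-initialized up => adapter starts as a no-op
```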
@@ -156,7 +156,7 @@ def auto_attach(
return LoraAdapter(layer, self), parent

def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
"""Load the weights of the LoRA.
"""Load the (pre-trained) weights of the LoRA.
Args:
down_weight: The down weight.
@@ -169,7 +169,7 @@ def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:


class LinearLora(Lora):
"""Low-rank approximation (LoRA) layer for linear layers.
"""Low-Rank Adaptation (LoRA) layer for linear layers.
This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
"""
@@ -255,7 +255,7 @@ def is_compatible(self, layer: fl.WeightedModule, /) -> bool:


class Conv2dLora(Lora):
"""Low-rank approximation (LoRA) layer for 2D convolutional layers.
"""Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
"""
@@ -391,12 +391,12 @@ def names(self) -> list[str]:

@property
def loras(self) -> dict[str, Lora]:
"""The LoRA layers."""
"""The LoRA layers indexed by name."""
return {lora.name: lora for lora in self.layers(Lora)}

@property
def scales(self) -> dict[str, float]:
"""The scales of the LoRA layers."""
"""The scales of the LoRA layers indexed by names."""
return {lora.name: lora.scale for lora in self.layers(Lora)}

@scales.setter
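A hypothetical usage of the two properties and the setter shown above, assuming `adapter` is an already-built `LoraAdapter` (its construction is outside this diff):

```python
# Hypothetical usage; `adapter` is assumed to be an existing LoraAdapter.
for name, lora in adapter.loras.items():  # dict[str, Lora], keyed by LoRA name
    print(name, lora.scale)

print(adapter.scales)  # dict[str, float], keyed by LoRA name
adapter.scales = {name: 0.5 for name in adapter.scales}  # set every scale via the setter
```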
@@ -407,6 +407,9 @@ def scale(self, values: dict[str, float]) -> None:
def add_lora(self, lora: Lora, /) -> None:
"""Add a LoRA layer to the adapter.
+ Raises:
+     AssertionError: If the adapter already contains a LoRA layer with the same name.
Args:
lora: The LoRA layer to add.
"""
@@ -416,6 +419,9 @@ def add_lora(self, lora: Lora, /) -> None:
def remove_lora(self, name: str, /) -> Lora | None:
"""Remove a LoRA layer from the adapter.
+ Note:
+     If the adapter doesn't contain a LoRA layer with the given name, nothing happens and `None` is returned.
Args:
name: The name of the LoRA layer to remove.
"""
4 changes: 2 additions & 2 deletions src/refiners/fluxion/context.py
@@ -23,7 +23,7 @@ def set_context(self, key: str, value: Context) -> None:
self.contexts[key] = value

def get_context(self, key: str) -> Any:
"""Retreive a value from the context.
"""Retrieve a value from the context.
Args:
key: The key of the context.
@@ -34,7 +34,7 @@ def get_context(self, key: str) -> Any:
return self.contexts.get(key)

def update_contexts(self, new_contexts: Contexts) -> None:
"""Update the contexts with new contexts.
"""Update or set the contexts with new contexts.
Args:
new_contexts: The new contexts.
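The enclosing provider class is not visible in this excerpt, so the sketch below mirrors only the three methods touched here, using a plain dict and one plausible reading of "update or set" (merge into an existing context, otherwise create it):

```python
# Plain-dict sketch of the behaviour these docstrings describe; the names and
# the exact merge semantics of update_contexts are assumptions.
contexts: dict[str, dict] = {}

def set_context(key: str, value: dict) -> None:
    contexts[key] = value

def get_context(key: str):
    return contexts.get(key)  # None when the key is absent

def update_contexts(new_contexts: dict[str, dict]) -> None:
    for key, value in new_contexts.items():
        if key in contexts:
            contexts[key].update(value)  # update an existing context
        else:
            contexts[key] = value        # or set a new one

set_context("reshape", {"height": 32, "width": 32})
update_contexts({"reshape": {"width": 64}, "diffusion": {"step": 0}})
assert get_context("reshape") == {"height": 32, "width": 64}
assert get_context("missing") is None
```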
2 changes: 1 addition & 1 deletion src/refiners/fluxion/layers/activations.py
@@ -55,7 +55,7 @@ class ReLU(Activation):
output = relu(tensor)
expected_output = torch.tensor([[0.0, 0.0, 1.0]])
- assert torch.allclose(output, expected_output)
+ assert torch.equal(output, expected_output)
```
"""

10 changes: 6 additions & 4 deletions src/refiners/fluxion/layers/attentions.py
@@ -20,7 +20,9 @@ def scaled_dot_product_attention(
) -> Float[Tensor, "batch source_sequence_length dim"]:
"""Scaled Dot Product Attention.
- Optimization depends on which pytorch backend is used.
+ Note:
+     Optimization depends on which PyTorch backend is used.
See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
See also [torch.nn.functional.scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""
@@ -213,7 +215,7 @@ class Attention(Chain):
which transforms the 3 inputs into Query, Key and Value
- a [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
- a [`Linear`][refiners.fluxion.layers.linear.Linear] layer,
- which further transforms the output of the
+ which projects the output of the
[`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
Receives:
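For orientation, a single-head, plain-PyTorch sketch of the pipeline the bullet list above describes (not the Refiners `Attention` chain itself; sizes are illustrative):

```python
# Query/Key/Value projections, scaled dot-product attention, output projection.
import torch
from torch import nn
import torch.nn.functional as F

dim = 64
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
to_out = nn.Linear(dim, dim, bias=False)

x = torch.randn(2, 16, dim)                     # (batch, sequence, dim)
q, k, v = to_q(x), to_k(x), to_v(x)             # the three parallel Linear layers
attn = F.scaled_dot_product_attention(q, k, v)  # the ScaledDotProductAttention step
out = to_out(attn)                              # the final projection Linear
assert out.shape == x.shape
```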
@@ -461,7 +463,7 @@ def _tensor_2d_to_sequence(
) -> Float[Tensor, "batch height*width channels"]:
"""Transform a 2D Tensor into a sequence.
- The height and width of the input Tensor are stored in the context,
+ The height and width of the input Tensor are stored in a `"reshape"` context,
so that the output Tensor can be transformed back into a 2D Tensor in the `sequence_to_tensor_2d` method.
"""
height, width = x.shape[-2:]
@@ -480,7 +482,7 @@ def _sequence_to_tensor_2d(
) -> Float[Tensor, "batch channels height width"]:
"""Transform a sequence into a 2D Tensor.
- The height and width of the output Tensor are retrieved from the context,
+ The height and width of the output Tensor are retrieved from the `"reshape"` context,
which was set in the `tensor_2d_to_sequence` method.
"""
height, width = self.use_context("reshape").values()
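The round trip the two helpers implement, sketched with plain tensor ops (here height and width are kept in local variables; Refiners stores them in the `"reshape"` context, as the docstrings now say):

```python
# (batch, channels, H, W) -> (batch, H*W, channels) -> back again.
import torch

x = torch.randn(2, 64, 8, 8)        # (batch, channels, height, width)
height, width = x.shape[-2:]

seq = x.flatten(2).transpose(1, 2)  # (batch, height*width, channels)
back = seq.transpose(1, 2).reshape(2, 64, height, width)
assert torch.equal(back, x)
```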
6 changes: 3 additions & 3 deletions src/refiners/fluxion/layers/basics.py
@@ -17,7 +17,7 @@ class Identity(Module):
tensor = torch.randn(10, 10)
output = identity(tensor)
- assert torch.allclose(tensor, output)
+ assert torch.equal(tensor, output)
```
"""

@@ -51,9 +51,9 @@ class GetArg(Module):
torch.randn(20, 20),
torch.randn(30, 30),
)
- output = get_arg(inputs)
+ output = get_arg(*inputs)
- assert torch.allclose(tensor[1], output)
+ assert id(inputs[1]) == id(output)
```
"""

