Implement ModularDNN model #235

Merged · 113 commits merged Dec 12, 2023

Commits
966b9c9
Change order of concatenation and subtraction
stefan-apollo Nov 28, 2023
ac7367e
Changed in both functions
stefan-apollo Nov 28, 2023
790752c
Implemented new basis November A (previously called new norm), to be
stefan-apollo Nov 28, 2023
63bdc1c
Hacky implementation of new edges formula
stefan-apollo Nov 28, 2023
491fb17
Refactored integradted_gradient_jacobian and edge_norm
stefan-apollo Nov 29, 2023
7b75387
Adjusted tests because new integrated gradient function includes the …
stefan-apollo Nov 29, 2023
dff29ef
Merge remote-tracking branch 'origin/main' into feature/support_diffe…
stefan-apollo Nov 29, 2023
5367ddc
Change default attribution; better docs; move minus sign
stefan-apollo Nov 29, 2023
e3d550f
MNIST bugfix
stefan-apollo Nov 29, 2023
c389d5f
Implemented new edges formula for MLPs/MNIST
stefan-apollo Nov 29, 2023
dbe9aed
Fixed MNIST for real this time
stefan-apollo Nov 29, 2023
eb2cf99
Update .vscode
stefan-apollo Nov 29, 2023
0a4d35f
Comment for test
stefan-apollo Nov 29, 2023
f375170
Pass through arguments, add docs
stefan-apollo Nov 29, 2023
88d3dbb
Merge remote-tracking branch 'origin/main' into feature/support_diffe…
stefan-apollo Nov 29, 2023
a953ccf
Allow for config options to select new features
stefan-apollo Nov 29, 2023
d2461e5
Add run build tests for new methods (CI slow-down?)
stefan-apollo Nov 29, 2023
e505964
Implemented tests and fixed input_pos != out_pos
stefan-apollo Nov 29, 2023
4a06b76
Added all test combinations
stefan-apollo Nov 30, 2023
0b02a66
Skip fp32 Cs test for new basis
stefan-apollo Nov 30, 2023
c247fa3
Moved module_hat to linalg
stefan-apollo Nov 30, 2023
2aca8f9
Removed integral_boundary_relative_epsilon from everywhere
stefan-apollo Nov 30, 2023
ff50a6d
Remove unneccessary device moving
stefan-apollo Nov 30, 2023
889eb30
Separated functions; also fixed normalization
stefan-apollo Nov 30, 2023
846ac7d
Removed stray copilot kwarg
stefan-apollo Nov 30, 2023
3478830
Fix tests
stefan-apollo Nov 30, 2023
a56dcb8
Rename ig_formula -> basis_formula
stefan-apollo Nov 30, 2023
124588a
Finish rename
stefan-apollo Nov 30, 2023
d4d8de1
Fixed tests
stefan-apollo Nov 30, 2023
6ccfacf
Added config entries for mnist
stefan-apollo Nov 30, 2023
6282cf0
Move module_hat for easier diff / no functional change
stefan-apollo Nov 30, 2023
2edaf90
Fix docstrings
stefan-apollo Nov 30, 2023
4312ea2
Move einsum line
stefan-apollo Nov 30, 2023
9b39da3
Write einsum pattern directly
stefan-apollo Nov 30, 2023
c32f3b3
Parameterize tests
stefan-apollo Nov 30, 2023
3f85669
renamed test
stefan-apollo Nov 30, 2023
448929c
Improved docs
stefan-apollo Nov 30, 2023
38f7b0d
Merge remote-tracking branch 'origin/main' into feature/support_diffe…
stefan-apollo Dec 1, 2023
de20c74
Remove duplicate normalization
danbraunai-apollo Dec 1, 2023
eabc54f
Remove variable_position_dimension variable everywhere
danbraunai-apollo Dec 1, 2023
664ab73
Fix accidental deletion of i_grad from einsum
danbraunai-apollo Dec 1, 2023
65636fb
Simplify edge hook function
danbraunai-apollo Dec 1, 2023
d6270ce
Reduce mnist dataset size in tests
danbraunai-apollo Dec 1, 2023
c2cd73e
Tests for rotate final layer
stefan-apollo Dec 1, 2023
0523dab
Mod add tests too slow
stefan-apollo Dec 1, 2023
4625209
Stray line in tests
stefan-apollo Dec 1, 2023
0ef4bc4
Test naming
stefan-apollo Dec 1, 2023
e64dd63
Remove all mod add tests
stefan-apollo Dec 1, 2023
eeef53b
Implement ModularDNN following Jakes script
stefan-apollo Dec 1, 2023
65a79cc
Merge remote-tracking branch 'origin/main' into feature/implement_mod…
stefan-apollo Dec 1, 2023
48aa4bd
Get it running
stefan-apollo Dec 1, 2023
8478ae8
Clarified docs
stefan-apollo Dec 1, 2023
92981a5
Working plotting script
stefan-apollo Dec 1, 2023
3181f0b
Merge remote-tracking branch 'origin/main' into feature/implement_mod…
stefan-apollo Dec 4, 2023
02bba38
SVD for mod dnn
stefan-apollo Dec 4, 2023
3446703
mlp svd
stefan-apollo Dec 4, 2023
c0ecec6
Merge remote-tracking branch 'origin/main' into feature/implement_mod…
stefan-apollo Dec 4, 2023
20464bc
Adjusted for merge
stefan-apollo Dec 4, 2023
c31b390
Custom plot
stefan-apollo Dec 4, 2023
f41d285
Configs
stefan-apollo Dec 4, 2023
e7a7b6d
Refactor
stefan-apollo Dec 4, 2023
ef8e463
Small fixes
stefan-apollo Dec 4, 2023
3fcbe88
Minor change
stefan-apollo Dec 4, 2023
b9da8ab
Jake mapping
stefan-apollo Dec 4, 2023
7a79796
Neuron basis
stefan-apollo Dec 4, 2023
0d889ae
Test linear
stefan-apollo Dec 4, 2023
4c01b54
Implement Identity act fn@
stefan-apollo Dec 4, 2023
fc81052
Implemented diagonal edges test, and DNN rotation invariance text
stefan-apollo Dec 4, 2023
b8229cf
Notebook changes
stefan-apollo Dec 4, 2023
dc0da16
Fix dtype error
stefan-apollo Dec 8, 2023
c7c2240
Merge remote-tracking branch 'origin/main' into feature/implement_mod…
stefan-apollo Dec 8, 2023
d321709
Tests for dnn build (find e.g. dtype issue)
stefan-apollo Dec 8, 2023
e1f41a5
Fix settings
stefan-apollo Dec 8, 2023
8828586
Fix plotting file
stefan-apollo Dec 8, 2023
3590a80
Rename to BlockDiagonalDNN
stefan-apollo Dec 8, 2023
01e9ebf
Rename to BlockDiagonalDNN
stefan-apollo Dec 8, 2023
e4a7096
Revert "Rename to BlockDiagonalDNN"
stefan-apollo Dec 8, 2023
a5bab35
Revert "Rename to BlockDiagonalDNN"
stefan-apollo Dec 8, 2023
c56ec52
isort
stefan-apollo Dec 8, 2023
83e752e
Init
stefan-apollo Dec 8, 2023
4228fc6
typing
stefan-apollo Dec 8, 2023
19206e2
Random commit to restart CI
stefan-apollo Dec 8, 2023
e501ac8
Re-organise modular mlp
danbraunai-apollo Dec 11, 2023
0284b57
Merge modular_mlp_build with mlp_rib_build
danbraunai-apollo Dec 11, 2023
b70459a
Replicate old results in mlp_rib_build
danbraunai-apollo Dec 11, 2023
5e53560
Fix configs
danbraunai-apollo Dec 11, 2023
bde5077
Remove modular_mlp_build experiment
danbraunai-apollo Dec 11, 2023
4242d5f
Fix vscode settings and launch profile
danbraunai-apollo Dec 11, 2023
5d5fa86
Remove svd block diagonal config
danbraunai-apollo Dec 11, 2023
a196914
Remove duplicate type
danbraunai-apollo Dec 11, 2023
4553338
Don't fold bias inside ModularMLP.__init__
danbraunai-apollo Dec 12, 2023
cdbea96
Change MLPConfig dtype from str to torch.dtype
danbraunai-apollo Dec 12, 2023
d789202
Use torch.full instead of scalar * torch.ones
danbraunai-apollo Dec 12, 2023
cebf9af
Use validato in ModularMLPConfig
danbraunai-apollo Dec 12, 2023
fdb21c0
Add serializer and validator to MLP dtype in config
danbraunai-apollo Dec 12, 2023
2849565
Use torch.block_diag()
danbraunai-apollo Dec 12, 2023
8e2099f
Remove unnecessary path updating
danbraunai-apollo Dec 12, 2023
0c21bd6
Remove mark slow from test_modular_mlp_build_graph
danbraunai-apollo Dec 12, 2023
3a18054
Put test_lambdas logic in graph_build_test
danbraunai-apollo Dec 12, 2023
32e7725
Use logger instead of print
danbraunai-apollo Dec 12, 2023
5589d67
Always output model config in mlp_rib_build
danbraunai-apollo Dec 12, 2023
7e8cd8f
Fix validators in first_block fields in config
danbraunai-apollo Dec 12, 2023
3e69ed8
Update BlockVectorDataset.generate_data docstring
danbraunai-apollo Dec 12, 2023
73bc305
Ensure that U is stored for every layer
danbraunai-apollo Dec 12, 2023
bae23b3
Clean up ModularMLPConfig
danbraunai-apollo Dec 12, 2023
1966be0
Give warnings about reproducibility
danbraunai-apollo Dec 12, 2023
c6de03a
Clean up test_modular_mlp_diagonal_edges_when_linear
danbraunai-apollo Dec 12, 2023
c52b9ff
Clean comments in test_modular_mlp_diagonal_edges_when_linear
danbraunai-apollo Dec 12, 2023
9b79a59
Merge main into feature/implement_modular_dnn
danbraunai-apollo Dec 12, 2023
0c4b814
Add missing import
danbraunai-apollo Dec 12, 2023
1ae3103
Fix long docstring
danbraunai-apollo Dec 12, 2023
c5e0918
Prevent identical layers in modular mlp
danbraunai-apollo Dec 12, 2023
de514f2
Fix MLPConfig loading
danbraunai-apollo Dec 12, 2023
9 changes: 9 additions & 0 deletions .vscode/launch.json
@@ -94,6 +94,15 @@
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "modular mlp rib build",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/experiments/mlp_rib_build/run_mlp_rib_build.py",
"args": "${workspaceFolder}/experiments/mlp_rib_build/block_diagonal.yaml",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "train modadd",
"type": "python",
5 changes: 3 additions & 2 deletions experiments/lm_rib_build/run_lm_rib_build.py
@@ -133,10 +133,11 @@ class Config(BaseModel):
description="The type of evaluation to perform on the model before building the graph."
"If None, skip evaluation.",
)
basis_formula: Literal["(1-alpha)^2", "(1-0)*alpha", "svd"] = Field(
basis_formula: Literal["(1-alpha)^2", "(1-0)*alpha", "svd", "neuron"] = Field(
"(1-0)*alpha",
description="The integrated gradient formula to use to calculate the basis. If 'svd', will"
"use Us as Cs, giving the eigendecomposition of the gram matrix.",
"use Us as Cs, giving the eigendecomposition of the gram matrix. If 'neuron', will use "
"the neuron-basis.",
)
edge_formula: Literal["functional", "squared"] = Field(
"functional",
28 changes: 28 additions & 0 deletions experiments/mlp_rib_build/block_diagonal.yaml
@@ -0,0 +1,28 @@
exp_name: block_diagonal
node_layers:
- layers.0
- layers.1
- layers.2
- layers.3
- layers.4
- output
dataset:
name: block_vector
size: 10000
length: 4
data_variances: [1, 1]
data_perfect_correlation: false
seed: 0
modular_mlp_config:
n_hidden_layers: 4
width: 4
weight_variances: [2, 2]
weight_equal_columns: false
bias: 0
seed: 0
dtype: float64
batch_size: 256
n_intervals: 0
truncation_threshold: 1e-15
rotate_final_node_layer: false
basis_formula: (1-0)*alpha
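
For orientation, here is a minimal sketch of running the build with this config, mirroring the "modular mlp rib build" launch profile added to .vscode/launch.json above. The import path and the returned keys are assumptions inferred from this diff, not something the PR itself documents.

    # Minimal sketch, assuming the repo root is on PYTHONPATH and the entry point matches
    # the launch profile above (experiments/mlp_rib_build/run_mlp_rib_build.py).
    from experiments.mlp_rib_build.run_mlp_rib_build import main

    # Equivalent to the CLI call used in the launch profile:
    #   python experiments/mlp_rib_build/run_mlp_rib_build.py experiments/mlp_rib_build/block_diagonal.yaml
    results = main("experiments/mlp_rib_build/block_diagonal.yaml", force=True)
    # The results dict (see run_mlp_rib_build.py below) should contain entries such as
    # "edges" and "config".
    print(list(results.keys()))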
37 changes: 31 additions & 6 deletions experiments/mlp_rib_build/plot_mlp_graph.py
@@ -1,41 +1,66 @@
"""Plot an interaction graph given a results file contain the graph edges.

# TODO: Merge with experiments/lm_rib_build/plot_lm_graph.py
Usage:
python plot_mlp_graph.py <path/to/results_pt_file>

The results_pt_file should be the output of the run_mlp_rib_build.py script.
"""
import csv
from pathlib import Path
from typing import Optional, Union

import fire
import torch

from rib.log import logger
from rib.plotting import plot_interaction_graph
from rib.utils import check_outfile_overwrite


def main(results_file: str, force: bool = True) -> None:
def main(
results_file: str,
nodes_per_layer: Optional[Union[int, list[int]]] = None,
labels_file: Optional[str] = None,
out_file: Optional[Union[str, Path]] = None,
force: bool = False,
) -> None:
"""Plot an interaction graph given a results file contain the graph edges."""
results = torch.load(results_file)
out_dir = Path(__file__).parent / "out"
out_file = out_dir / f"{results['exp_name']}_rib_graph.png"
if out_file is None:
out_file = out_dir / f"{results['exp_name']}_rib_graph.png"
else:
out_file = Path(out_file)

# Input layer is much larger for mnist so include more nodes in it
nodes_per_layer = [40, 10, 10, 10] if nodes_per_layer is None else nodes_per_layer

if not check_outfile_overwrite(out_file, force):
return

# Input layer is much larger so include more nodes in it
nodes_per_layer = [40, 10, 10, 10]
# Ensure that we have edges
assert results["edges"], "The results file does not contain any edges."

layer_names = results["config"]["node_layers"] + ["output"]
# Add labels if provided
if labels_file is not None:
with open(labels_file, "r", newline="") as file:
reader = csv.reader(file)
node_labels = list(reader)
else:
node_labels = None

plot_interaction_graph(
raw_edges=results["edges"],
layer_names=layer_names,
layer_names=results["config"]["node_layers"],
exp_name=results["exp_name"],
nodes_per_layer=nodes_per_layer,
out_file=out_file,
node_labels=node_labels,
)

logger.info("Saved plot to %s", out_file)


if __name__ == "__main__":
fire.Fire(main)
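
A hedged usage sketch for the extended plotting entry point above: since the script exposes main via fire, it can also be called directly from Python. The import path and file names below are illustrative assumptions, not taken from the repo.

    # Sketch: call the plotting entry point directly (equivalent to the fire CLI shown in
    # the module docstring).
    from experiments.mlp_rib_build.plot_mlp_graph import main as plot_main

    plot_main(
        results_file="out/block_diagonal_rib_graph.pt",
        nodes_per_layer=[40, 10, 10, 10],  # an int or a per-layer list (see the signature above)
        labels_file=None,                  # optional CSV of node labels
        out_file="out/block_diagonal_rib_graph.png",
        force=True,                        # overwrite an existing plot without prompting
    )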
55 changes: 41 additions & 14 deletions experiments/mlp_rib_build/run_mlp_rib_build.py
@@ -14,42 +14,48 @@

"""

import json
from dataclasses import asdict
from pathlib import Path
from typing import Literal, Optional, Union

import fire
import torch
import yaml
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from torch.utils.data import DataLoader

from rib.data import VisionDatasetConfig
from rib.data import BlockVectorDatasetConfig, VisionDatasetConfig
from rib.data_accumulator import collect_gram_matrices, collect_interaction_edges
from rib.hook_manager import HookedModel
from rib.interaction_algos import calculate_interaction_rotations
from rib.loader import load_dataset, load_mlp
from rib.log import logger
from rib.models.mlp import MLPConfig
from rib.models.modular_mlp import ModularMLPConfig
from rib.types import TORCH_DTYPES, RibBuildResults, RootPath, StrDtype
from rib.utils import check_outfile_overwrite, load_config, set_seed


class Config(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
exp_name: str
mlp_path: RootPath
mlp_path: Optional[RootPath] = Field(
None,
description="Path to the saved MLP model. If None, we expect the MLP class to not be "
"randomly initialized (e.g. like in the ModularMLP class).",
)
batch_size: int
seed: Optional[int] = 0
truncation_threshold: float # Remove eigenvectors with eigenvalues below this threshold.
rotate_final_node_layer: bool # Whether to rotate the output layer to its eigenbasis.
n_intervals: int # The number of intervals to use for integrated gradients.
dtype: StrDtype # Data type of all tensors (except those overriden in certain functions).
node_layers: list[str]
basis_formula: Literal["(1-alpha)^2", "(1-0)*alpha"] = Field(
basis_formula: Literal["(1-alpha)^2", "(1-0)*alpha", "svd", "neuron"] = Field(
"(1-0)*alpha",
description="The integrated gradient formula to use to calculate the basis.",
description="The integrated gradient formula to use to calculate the basis. If 'svd', will"
"use Us as Cs, giving the eigendecomposition of the gram matrix. If 'neuron', will use "
"the neuron-basis.",
)
edge_formula: Literal["functional", "squared"] = Field(
"functional",
@@ -60,27 +66,48 @@ class Config(BaseModel):
description="Directory for the output files. Defaults to `./out/`. If None, no output "
"is written. If a relative path, it is relative to the root of the rib repo.",
)
dataset: VisionDatasetConfig = VisionDatasetConfig()
dataset: Union[VisionDatasetConfig, BlockVectorDatasetConfig] = Field(
VisionDatasetConfig(),
description="The dataset to use to build the graph.",
)
modular_mlp_config: Optional[ModularMLPConfig] = Field(
None,
description="The model to use. If None, we expect mlp_path to be set.",
)

@model_validator(mode="after")
def verify_model_config(self) -> "Config":
"""Verify that model_config is set if modular_mlp_config is not."""
if self.mlp_path is None and self.modular_mlp_config is None:
raise ValueError("model must be set if modular_mlp_config is not.")
return self


def main(config_path_or_obj: Union[str, Config], force: bool = False) -> RibBuildResults:
"""Implement the main algorithm and store the graph to disk."""
config = load_config(config_path_or_obj, config_model=Config)
set_seed(config.seed)

with open(config.mlp_path.parent / "config.yaml", "r") as f:
model_config_dict = yaml.safe_load(f)

if config.out_dir is not None:
config.out_dir.mkdir(parents=True, exist_ok=True)
out_file = config.out_dir / f"{config.exp_name}_rib_graph.pt"
if not check_outfile_overwrite(out_file, force):
raise FileExistsError("Not overwriting output file")

mlp_config: Union[MLPConfig, ModularMLPConfig]
if config.mlp_path is not None:
with open(config.mlp_path.parent / "config.yaml", "r") as f:
model_config_dict = yaml.safe_load(f)
mlp_config = MLPConfig(**model_config_dict["model"])
else:
assert config.modular_mlp_config is not None
mlp_config = config.modular_mlp_config

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = TORCH_DTYPES[config.dtype]
mlp_config = MLPConfig(**model_config_dict["model"])
mlp = load_mlp(mlp_config, config.mlp_path, fold_bias=True, device=device)
mlp = load_mlp(
mlp_config, mlp_path=config.mlp_path, fold_bias=True, device=device, seed=config.seed
)
assert mlp.has_folded_bias, "MLP must have folded bias to run RIB"

all_possible_node_layers = [f"layers.{i}" for i in range(len(mlp.layers))] + ["output"]
@@ -151,8 +178,8 @@ def main(config_path_or_obj: Union[str, Config], force: bool = False) -> RibBuildResults:
"interaction_rotations": interaction_rotations,
"eigenvectors": eigenvectors,
"edges": [(module, E_hats[module].cpu()) for module in E_hats],
"config": json.loads(config.model_dump_json()),
"model_config_dict": model_config_dict,
"config": config.model_dump(),
"model_config_dict": mlp_config.model_dump(),
}

# Save the results (which include torch tensors) to file
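
A minimal sketch of how the new mlp_path / modular_mlp_config validator above behaves. Field values other than those visible in this diff are illustrative, and the import path is an assumption.

    # Sketch: the model_validator requires at least one of mlp_path / modular_mlp_config.
    from pydantic import ValidationError

    from experiments.mlp_rib_build.run_mlp_rib_build import Config  # assumed import path

    try:
        Config(
            exp_name="example",
            mlp_path=None,
            modular_mlp_config=None,  # neither model source given -> validator rejects
            batch_size=256,
            truncation_threshold=1e-15,
            rotate_final_node_layer=False,
            n_intervals=0,
            dtype="float64",
            node_layers=["layers.0", "output"],
        )
    except ValidationError as err:
        print(err)  # reports that mlp_path must be set if modular_mlp_config is not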
112 changes: 109 additions & 3 deletions rib/data.py
@@ -2,10 +2,20 @@
from typing import Literal, Optional

import torch
from jaxtyping import Int
from pydantic import BaseModel, ConfigDict, Field, model_validator
from jaxtyping import Float, Int
from pydantic import (
BaseModel,
ConfigDict,
Field,
ValidationInfo,
field_validator,
model_validator,
)
from torch import Tensor
from torch.utils.data import Dataset
from typing_extensions import Annotated

from rib.types import TORCH_DTYPES, StrDtype


class DatasetConfig(BaseModel):
@@ -121,8 +131,104 @@ def __len__(self) -> int:


class VisionDatasetConfig(DatasetConfig):
source: Literal["custom"] = "custom"
name: Literal["CIFAR10", "MNIST"] = "MNIST"
seed: Optional[int] = 0
return_set_frac: Optional[float] = None # Needed for some reason to avoid mypy errors
return_set_n_samples: Optional[int] = None # Needed for some reason to avoid mypy errors


class BlockVectorDatasetConfig(DatasetConfig):
name: Literal["block_vector"] = "block_vector"
size: int = Field(
1000,
description="Number of samples in the dataset.",
)
length: int = Field(
4,
description="Length of each vector.",
)
first_block_length: Optional[int] = Field(
None,
description="Length of the first block. If None, defaults to length // 2.",
validate_default=True,
)
data_variances: list[float] = Field(
Collaborator: nit: better for this to be a tuple since it must be len 2

Contributor: Agreed. However, tuples don't seem to be supported by default in yaml.

You could add a validator like this:

    @field_validator("data_variances", mode="after")
    @classmethod
    def check_data_variances_len_two(cls, v: list[float]) -> list[float]:
        if len(v) != 2:
            raise ValueError("data_variances must be a list of length 2.")
        return v

Or one that converts the list to a tuple and ensures that it has length two. Haven't made either of these changes because there is a tonne of validation that we haven't really bothered to do in these configs; I don't think it's very high priority as most of them should just raise errors later.

Collaborator: oh yeah, that's fine then.

[1.0, 1.0],
description="Variance of the two blocks of the vectors.",
)
data_perfect_correlation: bool = Field(
False,
description="Whether to make the data within each block perfectly correlated.",
)
dtype: StrDtype = "float64"
seed: Optional[int] = 0

@field_validator("first_block_length", mode="after")
@classmethod
def set_first_block_length(cls, v: Optional[int], info: ValidationInfo) -> int:
if v is None:
return info.data["length"] // 2
return v


class BlockVectorDataset(Dataset):
def __init__(
self,
dataset_config: BlockVectorDatasetConfig,
):
"""Generate a dataset of random normal vectors.

The components in `[:first_block_length]` have variance `data_variances[0]`, while the
components in `[first_block_length:length]` have variance `data_variances[1]`.
If `data_perfect_correlation` is true, the entries in each block are identical. Otherwise
they have no correlation.
"""
self.cfg = dataset_config
self.data = self.generate_data()
# Not needed, just here for Dataset class
self.labels = torch.nan * torch.ones(self.cfg.size)

def __len__(self):
return self.cfg.size

def __getitem__(self, idx):
return self.data[idx], self.labels[idx]

def generate_data(self) -> Float[Tensor, "size length"]:
"""Generate a dataset of vectors with two blocks of variance.

Warning, changing the structure of this function may break reproducibility.

Returns:
A dataset of vectors with two blocks of variance.
"""
dtype = TORCH_DTYPES[self.cfg.dtype]
size = self.cfg.size
length = self.cfg.length
first_block_length = self.cfg.first_block_length
data_variances = self.cfg.data_variances
data_perfect_correlation = self.cfg.data_perfect_correlation

first_block_length = first_block_length or length // 2
second_block_length = length - first_block_length
data = torch.empty((size, length), dtype=dtype)

if self.cfg.seed is not None:
torch.manual_seed(self.cfg.seed)

if not data_perfect_correlation:
data[:, 0:first_block_length] = data_variances[0] * torch.randn(
size, first_block_length, dtype=dtype
)
data[:, first_block_length:] = data_variances[1] * torch.randn(
size, second_block_length, dtype=dtype
)
else:
data[:, 0:first_block_length] = data_variances[0] * torch.randn(
size, 1, dtype=dtype
).repeat(1, first_block_length)
data[:, first_block_length:] = data_variances[1] * torch.randn(
size, 1, dtype=dtype
).repeat(1, second_block_length)

return data
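
To make the pieces above concrete, a small usage sketch of the new dataset follows; the values mirror block_diagonal.yaml earlier in this PR, and the import path from rib.data is assumed from this diff.

    # Sketch: build the block-vector dataset used by the block_diagonal.yaml config above.
    from torch.utils.data import DataLoader

    from rib.data import BlockVectorDataset, BlockVectorDatasetConfig  # assumed import path

    cfg = BlockVectorDatasetConfig(
        size=10000,
        length=4,
        data_variances=[1.0, 1.0],
        data_perfect_correlation=False,
        dtype="float64",
        seed=0,
    )
    dataset = BlockVectorDataset(cfg)
    x, y = dataset[0]
    print(x.shape)  # torch.Size([4]); y is a NaN placeholder label

    loader = DataLoader(dataset, batch_size=256, shuffle=False)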