diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 1016d12cfa..d16248bdbd 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -29,4 +29,5 @@ jobs:
       - name: ruff
         run: |
           ruff --version
-          ruff check src
+          ruff check src # tests have a lot of issues, TODO
+          ruff format --check src # tests
diff --git a/ruff.toml b/ruff.toml
index 1a3a4eeab0..30bdb80a4a 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -1,4 +1,4 @@
-include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py"]
+include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py", "tests/**/*.py"]
 line-length = 88
 
 [lint]
diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py
index 26da8c7cd2..cd3960998e 100644
--- a/src/fairchem/core/_cli.py
+++ b/src/fairchem/core/_cli.py
@@ -49,7 +49,9 @@ def checkpoint(self, *args, **kwargs):
         self.config["timestamp_id"] = self.trainer.timestamp_id
         if self.trainer.logger is not None:
             self.trainer.logger.mark_preempting()
-        logging.info(f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}')
+        logging.info(
+            f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}'
+        )
         return DelayedSubmission(new_runner, self.config)
@@ -57,7 +59,9 @@ def runner_wrapper(config: dict):
     Runner()(config)
 
 
-def main(args: argparse.Namespace | None = None, override_args: list[str] | None = None):
+def main(
+    args: argparse.Namespace | None = None, override_args: list[str] | None = None
+):
     """Run the main fairchem program."""
     setup_logging()
 
@@ -66,7 +70,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
         args, override_args = parser.parse_known_args()
 
     # TODO: rename num_gpus -> num_ranks everywhere
-    assert args.num_gpus > 0, "num_gpus is used to determine number ranks, so it must be at least 1"
+    assert (
+        args.num_gpus > 0
+    ), "num_gpus is used to determine the number of ranks, so it must be at least 1"
     config = build_config(args, override_args)
 
     if args.submit:  # Run on cluster
@@ -98,9 +104,7 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
     else:  # Run locally on a single node, n-processes
         if args.num_gpus > 1:
-            logging.info(
-                f"Running in local mode with {args.num_gpus} ranks"
-            )
+            logging.info(f"Running in local mode with {args.num_gpus} ranks")
             # HACK to disable multiprocess dataloading in local mode
             # there is an open issue where LMDB's environment cannot be pickled and used
             # during torch multiprocessing https://github.com/pytorch/examples/issues/526
@@ -119,7 +123,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
             )
             elastic_launch(launch_config, runner_wrapper)(config)
         else:
-            logging.info("Running in local mode without elastic launch (single gpu only)")
+            logging.info(
+                "Running in local mode without elastic launch (single gpu only)"
+            )
             os.environ["MASTER_ADDR"] = "localhost"
             os.environ["LOCAL_RANK"] = "0"
             os.environ["RANK"] = "0"
diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py
index e54ae7d969..2e232ae7ba 100644
--- a/src/fairchem/core/common/distutils.py
+++ b/src/fairchem/core/common/distutils.py
@@ -22,6 +22,7 @@
 T = TypeVar("T")
 DISTRIBUTED_PORT = 13356
 
+
 def os_environ_get_or_throw(x: str) -> str:
     if x not in os.environ:
         raise RuntimeError(f"Could not find {x} in ENV variables")
@@ -68,7 +69,9 @@ def setup(config) -> None:
         )  # ensures GPU0 does not have extra context/higher peak memory
-        logging.info(f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}")
+        logging.info(
+            f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}"
+        )
         torch.cuda.set_device(config["local_rank"])
 
         dist.init_process_group(
@@ -104,13 +107,20 @@ def setup(config) -> None:
         )
     else:
         if not os.environ.get("MASTER_ADDR"):
-            assert config["world_size"] == 1, "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
+            assert (
+                config["world_size"] == 1
+            ), "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
             os.environ["MASTER_ADDR"] = "localhost"
             os.environ["MASTER_PORT"] = str(get_free_port())
             os.environ["LOCAL_RANK"] = "0"
             os.environ["RANK"] = "0"
         config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
-        dist.init_process_group(backend=config["distributed_backend"], rank=int(os.environ.get("RANK")), world_size=config["world_size"], timeout=timeout)
+        dist.init_process_group(
+            backend=config["distributed_backend"],
+            rank=int(os.environ.get("RANK")),
+            world_size=config["world_size"],
+            timeout=timeout,
+        )
 
 
 def cleanup() -> None:
diff --git a/src/fairchem/core/common/logger.py b/src/fairchem/core/common/logger.py
index 13d370dcb6..97199c15ce 100644
--- a/src/fairchem/core/common/logger.py
+++ b/src/fairchem/core/common/logger.py
@@ -60,6 +60,7 @@ def log_summary(self, summary_dict: dict[str, Any]) -> None:
     def log_artifact(self, name: str, type: str, file_location: str) -> None:
         pass
 
+
 @registry.register_logger("wandb")
 class WandBLogger(Logger):
     def __init__(self, config) -> None:
@@ -115,6 +116,7 @@ def log_artifact(self, name: str, type: str, file_location: str) -> None:
         art.add_file(file_location)
         art.save()
 
+
 @registry.register_logger("tensorboard")
 class TensorboardLogger(Logger):
     def __init__(self, config) -> None:
diff --git a/src/fairchem/core/common/profiler_utils.py b/src/fairchem/core/common/profiler_utils.py
index 0828cb6737..5d32d55d5e 100644
--- a/src/fairchem/core/common/profiler_utils.py
+++ b/src/fairchem/core/common/profiler_utils.py
@@ -10,6 +10,7 @@
 if TYPE_CHECKING:
     from fairchem.core.common.logger import Logger
 
+
 def get_default_profiler_handler(run_id: str, output_dir: str, logger: Logger):
     """Get a standard callback handle for the pytorch profiler"""
 
@@ -20,9 +21,13 @@ def trace_handler(p):
         print(f"Saving trace in {output_path}")
         p.export_chrome_trace(output_path)
         if logger:
-            logger.log_artifact(name=trace_name, type="profile", file_location=output_path)
+            logger.log_artifact(
+                name=trace_name, type="profile", file_location=output_path
+            )
+
     return trace_handler
 
+
 def get_profile_schedule(wait: int = 5, warmup: int = 5, active: int = 2):
     """Get a profile schedule and total number of steps to run
     check pytorch docs on the meaning of these paramters:
diff --git a/src/fairchem/core/common/slurm.py b/src/fairchem/core/common/slurm.py
index 37023de1bc..f06af3adf4 100644
--- a/src/fairchem/core/common/slurm.py
+++ b/src/fairchem/core/common/slurm.py
@@ -6,13 +6,17 @@
 from submitit.core.utils import JobPaths
 
 
-def add_timestamp_id_to_submission_pickle(slurm_folder: str, slurm_job_id: str, timestamp_id: str):
+def add_timestamp_id_to_submission_pickle(
+    slurm_folder: str, slurm_job_id: str, timestamp_id: str
+):
     # Try to put the timestamp-id into the original submission pickle's config
     # so that if the node crashes, it can be pick up the correct run to resume
     #
     # we need to do this after the job has started because the timestamp-id is generated at runtime
     # instead a-priori before the submission starts (ie: if we had a db to store a global job unique job)
-    submission_pickle_path = JobPaths(folder=slurm_folder, job_id=slurm_job_id).submitted_pickle
+    submission_pickle_path = JobPaths(
+        folder=slurm_folder, job_id=slurm_job_id
+    ).submitted_pickle
     try:
         with open(str(submission_pickle_path), "rb") as f:
             pkl = pickle.load(f)
diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py
index ce86aa782f..c57b8ac4b8 100644
--- a/src/fairchem/core/common/test_utils.py
+++ b/src/fairchem/core/common/test_utils.py
@@ -131,6 +131,7 @@ def spawn_multi_process(
 
     return [mp_output_dict[i] for i in range(config.world_size)]
 
+
 def init_local_distributed_process_group(backend="nccl"):
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(get_free_port())
diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py
index 346987d8e5..b3310d9e8c 100644
--- a/src/fairchem/core/datasets/lmdb_dataset.py
+++ b/src/fairchem/core/datasets/lmdb_dataset.py
@@ -211,7 +211,9 @@ def sample_property_metadata(self, num_samples: int = 100):
         }
 
 
-def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False) -> BaseData | dict[str, torch.Tensor]:
+def data_list_collater(
+    data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False
+) -> BaseData | dict[str, torch.Tensor]:
     batch = Batch.from_data_list(data_list)
 
     if not otf_graph:
diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py
index 480ee7d028..ab3c95afac 100644
--- a/src/fairchem/core/models/base.py
+++ b/src/fairchem/core/models/base.py
@@ -254,9 +254,15 @@ def __init__(
         # if finetune_config is provided, then attempt to load the model from the given finetune checkpoint
         starting_model = None
         if finetune_config is not None:
-            starting_model: HydraModel = load_model_and_weights_from_checkpoint(finetune_config["starting_checkpoint"])
-            logging.info(f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)")
-            assert isinstance(starting_model, HydraModel), "Can only finetune starting from other hydra models!"
+            starting_model: HydraModel = load_model_and_weights_from_checkpoint(
+                finetune_config["starting_checkpoint"]
+            )
+            logging.info(
+                f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)"
+            )
+            assert isinstance(
+                starting_model, HydraModel
+            ), "Can only finetune starting from other hydra models!"
 
         if backbone is not None:
             backbone = copy.deepcopy(backbone)
@@ -268,9 +274,13 @@ def __init__(
             )
         elif starting_model is not None:
             self.backbone = starting_model.backbone
-            logging.info(f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}")
+            logging.info(
+                f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}"
+            )
         else:
-            raise RuntimeError("Backbone not specified and not found in the starting checkpoint")
+            raise RuntimeError(
+                "Backbone not specified and not found in the starting checkpoint"
+            )
 
         if heads is not None:
             heads = copy.deepcopy(heads)
@@ -278,7 +288,9 @@ def __init__(
             self.output_heads: dict[str, HeadInterface] = {}
             head_names_sorted = sorted(heads.keys())
-            assert len(set(head_names_sorted)) == len(head_names_sorted), "Head names must be unique!"
+            assert len(set(head_names_sorted)) == len(
+                head_names_sorted
+            ), "Head names must be unique!"
             for head_name in head_names_sorted:
                 head_config = heads[head_name]
                 if "module" not in head_config:
@@ -295,15 +307,23 @@ def __init__(
             self.output_heads = torch.nn.ModuleDict(self.output_heads)
         elif starting_model is not None:
             self.output_heads = starting_model.output_heads
-            logging.info(f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}")
+            logging.info(
+                f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}"
+            )
         else:
-            raise RuntimeError("Heads not specified and not found in the starting checkpoint")
+            raise RuntimeError(
+                "Heads not specified and not found in the starting checkpoint"
+            )
 
     def forward(self, data: Batch):
         # lazily get device from input to use with amp, at least one input must be a tensor to figure out it's device
         if not self.device:
-            device_from_tensors = {x.device.type for x in data.values() if isinstance(x, torch.Tensor)}
-            assert len(device_from_tensors) == 1, f"all inputs must be on the same device, found the following devices {device_from_tensors}"
+            device_from_tensors = {
+                x.device.type for x in data.values() if isinstance(x, torch.Tensor)
+            }
+            assert (
+                len(device_from_tensors) == 1
+            ), f"all inputs must be on the same device, found the following devices {device_from_tensors}"
             self.device = device_from_tensors.pop()
 
         emb = self.backbone(data)
@@ -319,5 +339,3 @@ def forward(self, data: Batch):
                 out[k] = self.output_heads[k](data, emb)
 
         return out
-
-
diff --git a/src/fairchem/core/models/dimenet_plus_plus.py b/src/fairchem/core/models/dimenet_plus_plus.py
index f555448261..55a2975e72 100644
--- a/src/fairchem/core/models/dimenet_plus_plus.py
+++ b/src/fairchem/core/models/dimenet_plus_plus.py
@@ -352,13 +352,16 @@ def forward(
             )
         }
         if self.regress_forces:
-            outputs["forces"] = -1 * (
-                torch.autograd.grad(
-                    outputs["energy"],
-                    data.pos,
-                    grad_outputs=torch.ones_like(outputs["energy"]),
-                    create_graph=True,
-                )[0]
+            outputs["forces"] = (
+                -1
+                * (
+                    torch.autograd.grad(
+                        outputs["energy"],
+                        data.pos,
+                        grad_outputs=torch.ones_like(outputs["energy"]),
+                        create_graph=True,
+                    )[0]
+                )
             )
 
         return outputs
@@ -465,13 +468,16 @@ def forward(self, data):
         outputs = {"energy": energy}
 
         if self.regress_forces:
-            forces = -1 * (
-                torch.autograd.grad(
-                    energy,
-                    data.pos,
-                    grad_outputs=torch.ones_like(energy),
-                    create_graph=True,
-                )[0]
+            forces = (
+                -1
+                * (
+                    torch.autograd.grad(
+                        energy,
+                        data.pos,
+                        grad_outputs=torch.ones_like(energy),
+                        create_graph=True,
+                    )[0]
+                )
             )
             outputs["forces"] = forces
diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
index 2851acbc22..98e21a77f1 100644
--- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
+++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -610,7 +610,7 @@ def no_weight_decay(self) -> set:
 
 @registry.register_model("equiformer_v2_energy_head")
 class EquiformerV2EnergyHead(nn.Module, HeadInterface):
-    def __init__(self, backbone, reduce: str="sum"):
+    def __init__(self, backbone, reduce: str = "sum"):
         super().__init__()
         self.reduce = reduce
         self.avg_num_nodes = backbone.avg_num_nodes
@@ -645,8 +645,9 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
         elif self.reduce == "mean":
             return {"energy": energy / data.natoms}
         else:
-            raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")
-
+            raise ValueError(
+                f"reduce can only be sum or mean, user provided: {self.reduce}"
+            )
 
 
 @registry.register_model("equiformer_v2_force_head")
diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py
index 0b17e34eab..020416344f 100644
--- a/src/fairchem/core/models/escn/escn.py
+++ b/src/fairchem/core/models/escn/escn.py
@@ -537,7 +537,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
 
 @registry.register_model("escn_energy_head")
 class eSCNEnergyHead(nn.Module, HeadInterface):
-    def __init__(self, backbone, reduce = "sum"):
+    def __init__(self, backbone, reduce="sum"):
         super().__init__()
         backbone.energy_block = None
         self.reduce = reduce
@@ -558,7 +558,9 @@ def forward(
         elif self.reduce == "mean":
             return {"energy": energy / data.natoms}
         else:
-            raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")
+            raise ValueError(
+                f"reduce can only be sum or mean, user provided: {self.reduce}"
+            )
 
 
 @registry.register_model("escn_force_head")
diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py
index c1a40ff59c..d60b24a845 100644
--- a/src/fairchem/core/models/escn/escn_exportable.py
+++ b/src/fairchem/core/models/escn/escn_exportable.py
@@ -139,8 +139,12 @@ def __init__(
 
         # Initialize the transformations between spherical and grid representations
         self.SO3_grid = nn.ModuleDict()
-        self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax, self.lmax, resolution=resolution)
-        self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax, self.mmax, resolution=resolution)
+        self.SO3_grid["lmax_lmax"] = SO3_Grid(
+            self.lmax, self.lmax, resolution=resolution
+        )
+        self.SO3_grid["lmax_mmax"] = SO3_Grid(
+            self.lmax, self.mmax, resolution=resolution
+        )
         self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax])
 
         # Initialize the blocks for each layer of the GNN
@@ -157,7 +161,7 @@ def __init__(
                 self.max_num_elements,
                 self.SO3_grid,
                 self.act,
-                self.mappingReduced
+                self.mappingReduced,
             )
             self.layer_blocks.append(block)
 
@@ -185,7 +189,6 @@ def __init__(
             requires_grad=False,
         )
 
-
     def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         pos: torch.Tensor = data["pos"]
         batch_idx: torch.Tensor = data["batch"]
@@ -207,9 +210,7 @@ def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         ###############################################################
         # Compute 3x3 rotation matrix per edge
-        edge_rot_mat = self._init_edge_rot_mat(
-            edge_index, edge_distance_vec
-        )
+        edge_rot_mat = self._init_edge_rot_mat(edge_index, edge_distance_vec)
         wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax).detach()
 
         ###############################################################
@@ -252,7 +253,9 @@ def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere.
         # These values are fed into the output blocks.
-        x_pt = torch.einsum("abc, pb->apc", x_message, self.sphharm_weights).contiguous()
+        x_pt = torch.einsum(
+            "abc, pb->apc", x_message, self.sphharm_weights
+        ).contiguous()
 
         ###############################################################
         # Energy estimation
@@ -335,6 +338,7 @@ def _init_edge_rot_mat(self, edge_index, edge_distance_vec):
     def num_params(self) -> int:
         return sum(p.numel() for p in self.parameters())
 
+
 class LayerBlock(torch.nn.Module):
     """
     Layer block: Perform one layer (message passing and aggregation) of the GNN
@@ -387,7 +391,7 @@ def __init__(
             max_num_elements,
             self.SO3_grid,
             self.act,
-            self.mappingReduced
+            self.mappingReduced,
         )
 
         # Non-linear point-wise comvolution for the aggregated messages
@@ -422,7 +426,11 @@ def forward(
         # Compute point-wise spherical non-linearity on aggregated messages
 
         # Project to grid
-        to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)]
+        to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[
+            :,
+            :,
+            self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax),
+        ]
         x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message)
         # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"])
@@ -435,7 +443,11 @@ def forward(
         x_grid = self.fc3_sphere(x_grid)
 
         # Project back to spherical harmonic coefficients
-        from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)]
+        from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[
+            :,
+            :,
+            self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax),
+        ]
         return torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)
@@ -498,7 +510,7 @@ def __init__(
             self.lmax,
             self.mmax,
             self.act,
-            self.mappingReduced
+            self.mappingReduced,
         )
         self.so2_block_target = SO2Block(
             self.sphere_channels,
@@ -507,7 +519,7 @@ def __init__(
             self.lmax,
             self.mmax,
             self.act,
-            self.mappingReduced
+            self.mappingReduced,
         )
 
     def forward(
@@ -558,7 +570,9 @@ def forward(
         x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target)
 
         # Compute the sum of the incoming neighboring messages for each target node
-        new_embedding = torch.zeros(x.shape, dtype=x_target.dtype, device=x_target.device)
+        new_embedding = torch.zeros(
+            x.shape, dtype=x_target.dtype, device=x_target.device
+        )
         new_embedding.index_add_(0, edge_index[1], x_target)
 
         return new_embedding
@@ -585,7 +599,7 @@ def __init__(
         lmax: int,
         mmax: int,
         act,
-        mappingReduced
+        mappingReduced,
     ) -> None:
         super().__init__()
         self.sphere_channels = sphere_channels
@@ -646,9 +660,7 @@ def forward(
         offset = self.mappingReduced.m_size[0]
         for m in range(1, self.mmax + 1):
             # Get the m order coefficients
-            x_m = x[
-                :, offset : offset + 2 * self.mappingReduced.m_size[m]
-            ].contiguous()
+            x_m = x[:, offset : offset + 2 * self.mappingReduced.m_size[m]].contiguous()
             x_m = x_m.view(num_edges, 2, -1)
             # Perform SO(2) convolution
             x_m = self.so2_conv[m - 1](x_m, x_edge)
diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py
index 34f505d51e..36e9d96cfc 100644
--- a/src/fairchem/core/models/escn/so3.py
+++ b/src/fairchem/core/models/escn/so3.py
@@ -462,8 +462,8 @@ def __init__(self, lmax: int, mmax: int, resolution: int | None = None) -> None:
         else:
             self.long_resolution = 2 * (self.mmax) + 1
         if resolution:
-            self.long_resolution=resolution
-            self.lat_resolution=resolution
+            self.long_resolution = resolution
+            self.lat_resolution = resolution
 
         self.initialized = False
diff --git a/src/fairchem/core/models/escn/so3_exportable.py b/src/fairchem/core/models/escn/so3_exportable.py
index a5189d22df..b2442e2784 100644
--- a/src/fairchem/core/models/escn/so3_exportable.py
+++ b/src/fairchem/core/models/escn/so3_exportable.py
@@ -15,6 +15,8 @@
 # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
 # _Jd is a list of tensors of shape (2l+1, 2l+1)
 __Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
+
+
 @torch.compiler.assume_constant_result
 def get_jd() -> torch.Tensor:
     return __Jd
@@ -29,7 +31,9 @@ def wigner_D(
     lv: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor
 ) -> torch.Tensor:
     _Jd = get_jd()
-    assert lv < len(_Jd), f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
+    assert (
+        lv < len(_Jd)
+    ), f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
     alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
     J = _Jd[lv].to(dtype=alpha.dtype, device=alpha.device)
@@ -49,6 +53,7 @@ def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
     M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
     return M
 
+
 def rotation_to_wigner(
     edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int
 ) -> torch.Tensor:
@@ -100,7 +105,7 @@ def __init__(
         # Compute the degree (l) and order (m) for each entry of the embedding
         l_harmonic = torch.tensor([]).long()
         m_harmonic = torch.tensor([]).long()
-        m_complex = torch.tensor([]).long() 
+        m_complex = torch.tensor([]).long()
 
         self.res_size = torch.zeros([self.num_resolutions]).long().tolist()
@@ -110,12 +115,8 @@ def __init__(
             mmax = min(self.mmax_list[i], l)
             m = torch.arange(-mmax, mmax + 1).long()
             m_complex = torch.cat([m_complex, m], dim=0)
-            m_harmonic = torch.cat(
-                [m_harmonic, torch.abs(m).long()], dim=0
-            )
-            l_harmonic = torch.cat(
-                [l_harmonic, m.fill_(l).long()], dim=0
-            )
+            m_harmonic = torch.cat([m_harmonic, torch.abs(m).long()], dim=0)
+            l_harmonic = torch.cat([l_harmonic, m.fill_(l).long()], dim=0)
             self.res_size[i] = len(l_harmonic) - offset
             offset = len(l_harmonic)
 
@@ -143,57 +144,48 @@ def __init__(
         # save tensors and they will be moved to GPU
         self.register_buffer("l_harmonic", l_harmonic)
         self.register_buffer("m_harmonic", m_harmonic)
-        self.register_buffer("m_complex", m_complex) 
-        self.register_buffer("to_m", to_m) 
+        self.register_buffer("m_complex", m_complex)
+        self.register_buffer("to_m", to_m)
 
         self.pre_compute_coefficient_idx()
 
-
     # Return mask containing coefficients of order m (real and imaginary parts)
     def complex_idx(self, m, lmax, m_complex, l_harmonic):
         """
-        Add `m_complex` and `l_harmonic` to the input arguments 
-        since we cannot use `self.m_complex`. 
+        Add `m_complex` and `l_harmonic` to the input arguments
+        since we cannot use `self.m_complex`.
""" if lmax == -1: lmax = max(self.lmax_list) indices = torch.arange(len(l_harmonic)) # Real part - mask_r = torch.bitwise_and( - l_harmonic.le(lmax), m_complex.eq(m) - ) + mask_r = torch.bitwise_and(l_harmonic.le(lmax), m_complex.eq(m)) mask_idx_r = torch.masked_select(indices, mask_r) mask_idx_i = torch.tensor([]).long() # Imaginary part if m != 0: - mask_i = torch.bitwise_and( - l_harmonic.le(lmax), m_complex.eq(-m) - ) + mask_i = torch.bitwise_and(l_harmonic.le(lmax), m_complex.eq(-m)) mask_idx_i = torch.masked_select(indices, mask_i) return mask_idx_r, mask_idx_i - def pre_compute_coefficient_idx(self): """ - Pre-compute the results of `coefficient_idx()` and access them with `prepare_coefficient_idx()` + Pre-compute the results of `coefficient_idx()` and access them with `prepare_coefficient_idx()` """ lmax = max(self.lmax_list) for l in range(lmax + 1): for m in range(lmax + 1): - mask = torch.bitwise_and( - self.l_harmonic.le(l), self.m_harmonic.le(m) - ) + mask = torch.bitwise_and(self.l_harmonic.le(l), self.m_harmonic.le(m)) indices = torch.arange(len(mask)) mask_indices = torch.masked_select(indices, mask) self.register_buffer(f"coefficient_idx_l{l}_m{m}", mask_indices) - def prepare_coefficient_idx(self): """ - Construct a list of buffers + Construct a list of buffers """ lmax = max(self.lmax_list) coefficient_idx_list = [] @@ -204,35 +196,39 @@ def prepare_coefficient_idx(self): coefficient_idx_list.append(l_list) return coefficient_idx_list - # Return mask containing coefficients less than or equal to degree (l) and order (m) def coefficient_idx(self, lmax: int, mmax: int): if lmax > max(self.lmax_list) or mmax > max(self.lmax_list): - mask = torch.bitwise_and( - self.l_harmonic.le(lmax), self.m_harmonic.le(mmax) - ) + mask = torch.bitwise_and(self.l_harmonic.le(lmax), self.m_harmonic.le(mmax)) indices = torch.arange(len(mask), device=mask.device) return torch.masked_select(indices, mask) else: temp = self.prepare_coefficient_idx() return temp[lmax][mmax] - def pre_compute_rotate_inv_rescale(self): lmax = max(self.lmax_list) for l in range(lmax + 1): for m in range(lmax + 1): mask_indices = self.coefficient_idx(l, m) - rotate_inv_rescale = torch.ones((1, int((l + 1)**2), int((l + 1)**2))) + rotate_inv_rescale = torch.ones( + (1, int((l + 1) ** 2), int((l + 1) ** 2)) + ) for l_sub in range(l + 1): if l_sub <= m: continue - start_idx = l_sub ** 2 + start_idx = l_sub**2 length = 2 * l_sub + 1 rescale_factor = math.sqrt(length / (2 * m + 1)) - rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor + rotate_inv_rescale[ + :, + start_idx : (start_idx + length), + start_idx : (start_idx + length), + ] = rescale_factor rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices] - self.register_buffer(f"rotate_inv_rescale_l{l}_m{m}", rotate_inv_rescale) + self.register_buffer( + f"rotate_inv_rescale_l{l}_m{m}", rotate_inv_rescale + ) def __repr__(self): return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})" @@ -327,7 +323,7 @@ def get_to_grid_mat(self, device=None): return self.to_grid_mat # Compute matrices to transform grid to irreps - def get_from_grid_mat(self,device=None): + def get_from_grid_mat(self, device=None): return self.from_grid_mat # Compute grid from irreps representation diff --git a/src/fairchem/core/models/schnet.py b/src/fairchem/core/models/schnet.py index 878aee746a..968236e89d 100644 --- a/src/fairchem/core/models/schnet.py +++ b/src/fairchem/core/models/schnet.py @@ -114,13 
+114,16 @@ def forward(self, data): outputs = {"energy": energy} if self.regress_forces: - forces = -1 * ( - torch.autograd.grad( - energy, - data.pos, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] + forces = ( + -1 + * ( + torch.autograd.grad( + energy, + data.pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] + ) ) outputs["forces"] = forces diff --git a/src/fairchem/core/modules/scaling/fit.py b/src/fairchem/core/modules/scaling/fit.py index 4bfc4bb62a..c5e86b453e 100644 --- a/src/fairchem/core/modules/scaling/fit.py +++ b/src/fairchem/core/modules/scaling/fit.py @@ -38,7 +38,6 @@ def _train_batch(trainer: BaseTrainer, batch) -> None: def compute_scaling_factors(config, num_batches: int = 16) -> None: - with new_trainer_context(config=config) as ctx: config = ctx.config trainer = ctx.trainer diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index 473448de18..f4b5a757b4 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -219,7 +219,9 @@ def convert(self, atoms: ase.Atoms, sid=None): data.edge_index = edge_index data.cell_offsets = cell_offsets - data.edge_distance_vec = self.get_edge_distance_vec(positions, edge_index, cell, cell_offsets) + data.edge_distance_vec = self.get_edge_distance_vec( + positions, edge_index, cell, cell_offsets + ) del atoms_copy if self.r_energy: diff --git a/src/fairchem/core/tasks/task.py b/src/fairchem/core/tasks/task.py index c10da91239..265220114e 100644 --- a/src/fairchem/core/tasks/task.py +++ b/src/fairchem/core/tasks/task.py @@ -31,13 +31,17 @@ def setup(self, trainer) -> None: # (https://github.com/FAIR-Chem/fairchem/blob/main/src/fairchem/core/_cli.py#L44), then we should attempt to # load that checkpoint if self.config["checkpoint"] is not None: - logging.info(f"Attemping to load user specified checkpoint at {self.config['checkpoint']}") + logging.info( + f"Attemping to load user specified checkpoint at {self.config['checkpoint']}" + ) self.trainer.load_checkpoint(checkpoint_path=self.config["checkpoint"]) # if the supplied checkpoint doesn't exist and there exists a previous checkpoint in the checkpoint path, this # means that the previous job didn't terminate "nicely" (due to node failures, crashes etc), then attempt # to load the last found checkpoint elif os.path.exists(self.chkpt_path): - logging.info(f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkecpoint") + logging.info( + f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkecpoint" + ) self.trainer.load_checkpoint(checkpoint_path=self.chkpt_path) def run(self):