enable ruff formatter #853

Merged: 1 commit, Sep 17, 2024
3 changes: 2 additions & 1 deletion .github/workflows/lint.yml
@@ -29,4 +29,5 @@ jobs:
- name: ruff
run: |
ruff --version
ruff check src
ruff check src # tests have a lot of issues, TODO
ruff format --check src # tests
2 changes: 1 addition & 1 deletion ruff.toml
@@ -1,4 +1,4 @@
include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py"]
include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py", "tests/**/*.py"]
line-length = 88

[lint]
20 changes: 13 additions & 7 deletions src/fairchem/core/_cli.py
@@ -49,15 +49,19 @@ def checkpoint(self, *args, **kwargs):
self.config["timestamp_id"] = self.trainer.timestamp_id
if self.trainer.logger is not None:
self.trainer.logger.mark_preempting()
logging.info(f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}')
logging.info(
f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}'
)
return DelayedSubmission(new_runner, self.config)


def runner_wrapper(config: dict):
Runner()(config)


def main(args: argparse.Namespace | None = None, override_args: list[str] | None = None):
def main(
args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
"""Run the main fairchem program."""
setup_logging()

@@ -66,7 +70,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
args, override_args = parser.parse_known_args()

# TODO: rename num_gpus -> num_ranks everywhere
assert args.num_gpus > 0, "num_gpus is used to determine number ranks, so it must be at least 1"
assert (
args.num_gpus > 0
), "num_gpus is used to determine number ranks, so it must be at least 1"
config = build_config(args, override_args)

if args.submit: # Run on cluster
@@ -98,9 +104,7 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None

else: # Run locally on a single node, n-processes
if args.num_gpus > 1:
logging.info(
f"Running in local mode with {args.num_gpus} ranks"
)
logging.info(f"Running in local mode with {args.num_gpus} ranks")
# HACK to disable multiprocess dataloading in local mode
# there is an open issue where LMDB's environment cannot be pickled and used
# during torch multiprocessing https://github.com/pytorch/examples/issues/526
@@ -119,7 +123,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
)
elastic_launch(launch_config, runner_wrapper)(config)
else:
logging.info("Running in local mode without elastic launch (single gpu only)")
logging.info(
"Running in local mode without elastic launch (single gpu only)"
)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
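For context on the launch path above: the local multi-process branch hands runner_wrapper to torch's elastic launcher. A minimal standalone sketch of that pattern, using a placeholder worker and config dict rather than fairchem's actual Runner:

import os

from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def worker(config: dict) -> None:
    # Each spawned process receives the same args; RANK/LOCAL_RANK are set
    # in the environment by the elastic agent.
    print(f"rank {os.environ['RANK']} received {config}")


if __name__ == "__main__":
    launch_config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,  # plays the role of args.num_gpus here
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:29500",
    )
    elastic_launch(launch_config, worker)({"example": "config"})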
16 changes: 13 additions & 3 deletions src/fairchem/core/common/distutils.py
@@ -22,6 +22,7 @@
T = TypeVar("T")
DISTRIBUTED_PORT = 13356


def os_environ_get_or_throw(x: str) -> str:
if x not in os.environ:
raise RuntimeError(f"Could not find {x} in ENV variables")
@@ -68,7 +69,9 @@ def setup(config) -> None:
)

# ensures GPU0 does not have extra context/higher peak memory
logging.info(f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}")
logging.info(
f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}"
)
torch.cuda.set_device(config["local_rank"])

dist.init_process_group(
@@ -104,13 +107,20 @@
)
else:
if not os.environ.get("MASTER_ADDR"):
assert config["world_size"] == 1, "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
assert (
config["world_size"] == 1
), "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
dist.init_process_group(backend=config["distributed_backend"], rank=int(os.environ.get("RANK")), world_size=config["world_size"], timeout=timeout)
dist.init_process_group(
backend=config["distributed_backend"],
rank=int(os.environ.get("RANK")),
world_size=config["world_size"],
timeout=timeout,
)


def cleanup() -> None:
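The single-rank fallback above sets MASTER_ADDR/MASTER_PORT itself before calling dist.init_process_group. A self-contained sketch of that branch, with an illustrative gloo backend and port rather than fairchem's actual defaults:

import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")

# world_size == 1, mirroring the assertion in the fallback branch
dist.init_process_group(backend="gloo", rank=0, world_size=1)
print(dist.get_rank(), dist.get_world_size())  # 0 1
dist.destroy_process_group()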
2 changes: 2 additions & 0 deletions src/fairchem/core/common/logger.py
@@ -60,6 +60,7 @@ def log_summary(self, summary_dict: dict[str, Any]) -> None:
def log_artifact(self, name: str, type: str, file_location: str) -> None:
pass


@registry.register_logger("wandb")
class WandBLogger(Logger):
def __init__(self, config) -> None:
@@ -115,6 +116,7 @@ def log_artifact(self, name: str, type: str, file_location: str) -> None:
art.add_file(file_location)
art.save()


@registry.register_logger("tensorboard")
class TensorboardLogger(Logger):
def __init__(self, config) -> None:
7 changes: 6 additions & 1 deletion src/fairchem/core/common/profiler_utils.py
@@ -10,6 +10,7 @@
if TYPE_CHECKING:
from fairchem.core.common.logger import Logger


def get_default_profiler_handler(run_id: str, output_dir: str, logger: Logger):
"""Get a standard callback handle for the pytorch profiler"""

@@ -20,9 +21,13 @@ def trace_handler(p):
print(f"Saving trace in {output_path}")
p.export_chrome_trace(output_path)
if logger:
logger.log_artifact(name=trace_name, type="profile", file_location=output_path)
logger.log_artifact(
name=trace_name, type="profile", file_location=output_path
)

return trace_handler


def get_profile_schedule(wait: int = 5, warmup: int = 5, active: int = 2):
"""Get a profile schedule and total number of steps to run
check pytorch docs on the meaning of these parameters:
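For reference, a minimal sketch of how a trace handler like get_default_profiler_handler plugs into torch.profiler with a wait/warmup/active schedule; the workload and trace file name below are placeholders:

import torch
from torch.profiler import ProfilerActivity, profile, schedule


def trace_handler(p):
    # Export one Chrome trace per completed profiling cycle.
    p.export_chrome_trace(f"trace_step{p.step_num}.json")


with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=5, warmup=5, active=2),
    on_trace_ready=trace_handler,
) as prof:
    for _ in range(15):  # wait + warmup + active = 12, so one trace is emitted
        torch.matmul(torch.randn(64, 64), torch.randn(64, 64))
        prof.step()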
8 changes: 6 additions & 2 deletions src/fairchem/core/common/slurm.py
@@ -6,13 +6,17 @@
from submitit.core.utils import JobPaths


def add_timestamp_id_to_submission_pickle(slurm_folder: str, slurm_job_id: str, timestamp_id: str):
def add_timestamp_id_to_submission_pickle(
slurm_folder: str, slurm_job_id: str, timestamp_id: str
):
# Try to put the timestamp-id into the original submission pickle's config
# so that if the node crashes, it can pick up the correct run to resume
#
# we need to do this after the job has started because the timestamp-id is generated at runtime
# instead of a priori before the submission starts (i.e., if we had a db to store a globally unique job id)
submission_pickle_path = JobPaths(folder=slurm_folder, job_id=slurm_job_id).submitted_pickle
submission_pickle_path = JobPaths(
folder=slurm_folder, job_id=slurm_job_id
).submitted_pickle
try:
with open(str(submission_pickle_path), "rb") as f:
pkl = pickle.load(f)
1 change: 1 addition & 0 deletions src/fairchem/core/common/test_utils.py
@@ -131,6 +131,7 @@ def spawn_multi_process(

return [mp_output_dict[i] for i in range(config.world_size)]


def init_local_distributed_process_group(backend="nccl"):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
4 changes: 3 additions & 1 deletion src/fairchem/core/datasets/lmdb_dataset.py
@@ -211,7 +211,9 @@ def sample_property_metadata(self, num_samples: int = 100):
}


def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False) -> BaseData | dict[str, torch.Tensor]:
def data_list_collater(
data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False
) -> BaseData | dict[str, torch.Tensor]:
batch = Batch.from_data_list(data_list)

if not otf_graph:
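data_list_collater is a thin wrapper around torch_geometric's Batch.from_data_list. A toy illustration with two made-up graphs (the feature shapes here are arbitrary):

import torch
from torch_geometric.data import Batch, Data

g1 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]]))

batch = Batch.from_data_list([g1, g2])
print(batch.num_graphs)  # 2
print(batch.batch)       # tensor([0, 0, 0, 1, 1]), node-to-graph assignment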
42 changes: 30 additions & 12 deletions src/fairchem/core/models/base.py
@@ -254,9 +254,15 @@ def __init__(
# if finetune_config is provided, then attempt to load the model from the given finetune checkpoint
starting_model = None
if finetune_config is not None:
starting_model: HydraModel = load_model_and_weights_from_checkpoint(finetune_config["starting_checkpoint"])
logging.info(f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)")
assert isinstance(starting_model, HydraModel), "Can only finetune starting from other hydra models!"
starting_model: HydraModel = load_model_and_weights_from_checkpoint(
finetune_config["starting_checkpoint"]
)
logging.info(
f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)"
)
assert isinstance(
starting_model, HydraModel
), "Can only finetune starting from other hydra models!"

if backbone is not None:
backbone = copy.deepcopy(backbone)
@@ -268,17 +274,23 @@
)
elif starting_model is not None:
self.backbone = starting_model.backbone
logging.info(f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}")
logging.info(
f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}"
)
else:
raise RuntimeError("Backbone not specified and not found in the starting checkpoint")
raise RuntimeError(
"Backbone not specified and not found in the starting checkpoint"
)

if heads is not None:
heads = copy.deepcopy(heads)
# Iterate through outputs_cfg and create heads
self.output_heads: dict[str, HeadInterface] = {}

head_names_sorted = sorted(heads.keys())
assert len(set(head_names_sorted)) == len(head_names_sorted), "Head names must be unique!"
assert len(set(head_names_sorted)) == len(
head_names_sorted
), "Head names must be unique!"
for head_name in head_names_sorted:
head_config = heads[head_name]
if "module" not in head_config:
@@ -295,15 +307,23 @@
self.output_heads = torch.nn.ModuleDict(self.output_heads)
elif starting_model is not None:
self.output_heads = starting_model.output_heads
logging.info(f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}")
logging.info(
f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}"
)
else:
raise RuntimeError("Heads not specified and not found in the starting checkpoint")
raise RuntimeError(
"Heads not specified and not found in the starting checkpoint"
)

def forward(self, data: Batch):
# lazily get device from input to use with amp, at least one input must be a tensor to figure out its device
if not self.device:
device_from_tensors = {x.device.type for x in data.values() if isinstance(x, torch.Tensor)}
assert len(device_from_tensors) == 1, f"all inputs must be on the same device, found the following devices {device_from_tensors}"
device_from_tensors = {
x.device.type for x in data.values() if isinstance(x, torch.Tensor)
}
assert (
len(device_from_tensors) == 1
), f"all inputs must be on the same device, found the following devices {device_from_tensors}"
self.device = device_from_tensors.pop()

emb = self.backbone(data)
Expand All @@ -319,5 +339,3 @@ def forward(self, data: Batch):
out[k] = self.output_heads[k](data, emb)

return out


34 changes: 20 additions & 14 deletions src/fairchem/core/models/dimenet_plus_plus.py
@@ -352,13 +352,16 @@ def forward(
)
}
if self.regress_forces:
outputs["forces"] = -1 * (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
outputs["forces"] = (
-1
* (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
)
return outputs

@@ -465,13 +468,16 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = -1 * (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
forces = (
-1
* (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
)
outputs["forces"] = forces

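The reformatted blocks above are the usual conservative-force pattern: forces are the negative gradient of the predicted energy with respect to positions. A toy, self-contained version with a stand-in quadratic energy:

import torch

pos = torch.randn(5, 3, requires_grad=True)
energy = (pos**2).sum()  # stand-in for the model's predicted energy

forces = -1 * torch.autograd.grad(
    energy,
    pos,
    grad_outputs=torch.ones_like(energy),
    create_graph=True,
)[0]
print(torch.allclose(forces, -2 * pos))  # True for this toy energy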
7 changes: 4 additions & 3 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -610,7 +610,7 @@ def no_weight_decay(self) -> set:

@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, reduce: str="sum"):
def __init__(self, backbone, reduce: str = "sum"):
super().__init__()
self.reduce = reduce
self.avg_num_nodes = backbone.avg_num_nodes
@@ -645,8 +645,9 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")

raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)


@registry.register_model("equiformer_v2_force_head")
6 changes: 4 additions & 2 deletions src/fairchem/core/models/escn/escn.py
@@ -537,7 +537,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:

@registry.register_model("escn_energy_head")
class eSCNEnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, reduce = "sum"):
def __init__(self, backbone, reduce="sum"):
super().__init__()
backbone.energy_block = None
self.reduce = reduce
@@ -558,7 +558,9 @@ def forward(
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")
raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)


@registry.register_model("escn_force_head")
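Both energy heads touched here accept reduce="sum" or reduce="mean" over atoms. A small sketch of that pooling, with made-up per-atom energies for two structures of 3 and 4 atoms:

import torch

node_energy = torch.randn(7)
batch_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # structure index for each atom
natoms = torch.tensor([3, 4])

energy_sum = torch.zeros(2).index_add_(0, batch_idx, node_energy)  # reduce="sum"
energy_mean = energy_sum / natoms                                   # reduce="mean"
print(energy_sum, energy_mean)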