# NOTE(review): `amr` is never referenced below. The import is presumably kept
# for its side effects and its position ahead of `torch` (the `--kill` help
# text suggests this script reproduces a threading problem) -- confirm before
# removing or reordering it.
import amrex.space3d as amr
import argparse
import torch
from torch import nn
from urllib import request
import zipfile


def download_and_unzip(url, data_dir):
    """Download a zip archive and extract it into the current directory.

    Parameters
    ----------
    url : str
        Address of the zip archive to fetch.
    data_dir : str
        Local *file* path the archive is saved to. Despite its name this is
        not a directory: it is passed to ``urlretrieve`` as the destination
        file name and then opened as a zip file.

    Notes
    -----
    ``extractall()`` is called without a target path, so the archive members
    are unpacked relative to the current working directory.
    """
    request.urlretrieve(url, data_dir)
    with zipfile.ZipFile(data_dir, "r") as zip_dataset:
        zip_dataset.extractall()


class ConnectedNN(nn.Module):
    """Fully connected neural network built from a list of layers.

    Parameters
    ----------
    layers : iterable of nn.Module
        Layers applied in order; they are wrapped in ``nn.Sequential`` and
        stored as ``self.stack``.
    device : optional
        If given, the module is moved to this device after construction.
    """

    def __init__(self, layers, device=None):
        super().__init__()
        self.stack = nn.Sequential(*layers)
        if device is not None:
            self.to(device)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.stack(x)


# Maps the activation names accepted by OneActNN to their layer constructors.
# A fresh layer is instantiated for every hidden layer (nn.PReLU carries its
# own learnable parameter, so instances must not be shared).
_ACTIVATIONS = {
    'ReLU': nn.ReLU,
    'Tanh': nn.Tanh,
    'PReLU': nn.PReLU,
    'Sigmoid': nn.Sigmoid,
}


def _make_activation(act):
    """Return a new activation layer for ``act`` or raise ``ValueError``."""
    try:
        factory = _ACTIVATIONS[act]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which can never match a name.
        raise ValueError(
            f"Unsupported activation {act!r}; expected one of "
            f"{sorted(_ACTIVATIONS)}"
        ) from None
    return factory()


class OneActNN(ConnectedNN):
    """Fully connected network that uses a single activation function.

    The layer layout is::

        Linear(n_in, n_hidden_nodes)
        [activation, Linear(n_hidden_nodes, n_hidden_nodes)] * (n_hidden_layers - 1)
        activation                      # only if n_hidden_layers > 0
        Linear(n_hidden_nodes, n_out)

    Parameters
    ----------
    n_in : int
        Number of input features.
    n_out : int
        Number of output features.
    n_hidden_nodes : int
        Width of every hidden layer.
    n_hidden_layers : int
        Number of activation layers inserted between the linear layers.
    act : str
        Activation name: ``'ReLU'``, ``'Tanh'``, ``'PReLU'`` or ``'Sigmoid'``.
    device : optional
        Forwarded to ``ConnectedNN``.

    Raises
    ------
    ValueError
        If an activation layer is required (``n_hidden_layers > 0``) and
        ``act`` is not one of the supported names. Previously such a value
        silently produced a network without any non-linearity.
    """

    def __init__(self, n_in, n_out, n_hidden_nodes, n_hidden_layers, act, device=None):
        layers = [nn.Linear(n_in, n_hidden_nodes)]
        for ii in range(n_hidden_layers):
            layers.append(_make_activation(act))
            # The last activation feeds the output layer directly, so no
            # extra hidden-to-hidden linear layer follows it.
            if ii < n_hidden_layers - 1:
                layers.append(nn.Linear(n_hidden_nodes, n_hidden_nodes))
        layers.append(nn.Linear(n_hidden_nodes, n_out))

        # nn.Module must be initialised before attributes are assigned on the
        # instance, so the parent constructor runs first.
        super().__init__(layers, device)

        self.n_in = n_in
        self.n_out = n_out
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_nodes = n_hidden_nodes
        self.act = act


def main():
    """Download the model archive, rebuild the network and load its weights.

    With ``-k/--kill`` the call to ``torch.set_num_threads(1)`` is skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-k',
        '--kill',
        action="store_true",
        help="Don't set num threads and therefore kill, crash or hang the code, "
             "generally just be unhappy",
    )
    args = parser.parse_args()

    data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
    download_and_unzip(data_url, "models.zip")

    if not args.kill:
        torch.set_num_threads(1)

    device = None
    stage_i = 0
    model_file = f"models/beam_stage_{stage_i}_model.pt"

    # Fall back to the CPU when no explicit device is requested.
    map_location = "cpu" if device is None else device
    # SECURITY NOTE: torch.load deserialises with pickle, which can execute
    # arbitrary code, and this file comes from a downloaded archive. Consider
    # passing weights_only=True if the checkpoint contents and the installed
    # torch version allow it.
    model_dict = torch.load(model_file, map_location=map_location)

    n_in = 6
    n_out = 6
    n_hidden_nodes = model_dict["n_hidden_nodes"]
    activation = model_dict["activation"]
    n_hidden_layers = model_dict["n_hidden_layers"]

    print('about to create NN')
    neural_network = OneActNN(
        n_in=n_in,
        n_out=n_out,
        n_hidden_nodes=n_hidden_nodes,
        n_hidden_layers=n_hidden_layers,
        act=activation,
        device=device,
    )
    print('about to load state dict...')
    neural_network.load_state_dict(model_dict["model_state_dict"])
    print('loaded state dict')


if __name__ == "__main__":
    main()